Commit 83de6ca: Added class_weights
BloodAxe committed Jun 26, 2024 (1 parent: d73827d)
Showing 2 changed files with 17 additions and 2 deletions.

pytorch_toolbelt/losses/focal.py (9 additions, 2 deletions)

@@ -23,6 +23,7 @@ def __init__(
        reduced_threshold: Optional[float] = None,
        activation: str = "sigmoid",
        softmax_dim: Optional[int] = None,
+       class_weights: Optional[Tensor] = None,
    ):
        """

@@ -44,6 +45,10 @@ def __init__(
        self.activation = activation
        self.softmax_dim = softmax_dim

+       if class_weights is not None and not torch.is_tensor(class_weights):
+           class_weights = torch.tensor(list(class_weights), dtype=torch.float32)
+       self.register_buffer('class_weights', class_weights, persistent=False)
+
        self.focal_loss_fn = partial(
            focal_loss_with_logits,
            alpha=alpha,

@@ -64,7 +69,9 @@ def __repr__(self):
        repr = f"{self.__class__.__name__}(alpha={self.alpha}, gamma={self.gamma}, "
        repr += f"ignore_index={self.ignore_index}, reduction={self.reduction}, normalized={self.normalized}, "
        repr += f"reduced_threshold={self.reduced_threshold}, activation={self.activation}, "
-       repr += f"softmax_dim={self.softmax_dim})"
+       repr += f"softmax_dim={self.softmax_dim}, "
+       repr += f"class_weights={self.class_weights.tolist() if self.class_weights is not None else None}"
+       repr += ")"
        return repr

    def forward(self, inputs: Tensor, targets: Tensor) -> Tensor:

@@ -81,7 +88,7 @@ def forward(self, inputs: Tensor, targets: Tensor) -> Tensor:
        if len(targets.shape) + 1 == len(inputs.shape):
            targets = self.get_one_hot_targets(targets, num_classes=inputs.size(1))

-       loss = self.focal_loss_fn(inputs, targets)
+       loss = self.focal_loss_fn(inputs, targets, class_weights=self.class_weights)
        return loss

    def _one_hot_targets(self, targets, num_classes):
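
For context, a minimal usage sketch of the new argument (not part of the commit; it assumes focal.py's class is exported as pytorch_toolbelt.losses.FocalLoss, and that targets may be given as class indices, which forward() one-hot encodes per the hunk above):

import torch
from pytorch_toolbelt.losses import FocalLoss

num_classes = 3
# Up-weight the under-represented class 2; a plain list also works,
# since __init__ converts non-tensors via torch.tensor(list(...), dtype=torch.float32).
criterion = FocalLoss(class_weights=torch.tensor([1.0, 1.0, 5.0]))

logits = torch.randn(4, num_classes, 8, 8)          # [B, C, H, W]
targets = torch.randint(0, num_classes, (4, 8, 8))  # [B, H, W] class indices
loss = criterion(logits, targets)

Because the weights are registered as a non-persistent buffer, they follow the module across .to(device) calls but are not written to the state dict.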

pytorch_toolbelt/losses/functional.py (8 additions, 0 deletions)

@@ -29,6 +29,7 @@ def focal_loss_with_logits(
    ignore_index=None,
    activation: str = "sigmoid",
    softmax_dim: Optional[int] = None,
+   class_weights: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Compute binary focal loss between target and output logits.

@@ -77,6 +78,13 @@
    if alpha is not None:
        loss *= alpha * target + (1 - alpha) * (1 - target)

+   if class_weights is not None:
+       # class_weights has shape [C]; loss has shape [B, C, ...].
+       # View class_weights as [1, C, 1, ..., 1] so it broadcasts over
+       # the batch and any trailing spatial dimensions.
+       class_weights = class_weights.view(1, -1, *(1 for _ in range(loss.dim() - 2)))
+       loss *= class_weights
+
    if ignore_index is not None:
        ignore_mask = target.eq(ignore_index)
        loss = torch.masked_fill(loss, ignore_mask, 0)
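
The broadcasting step above is easy to verify in isolation; the following standalone sketch (illustrative only, not from the repository) shows how a [C] weight vector scales a [B, C, H, W] loss map per channel:

import torch

loss = torch.ones(2, 3, 4, 4)                  # [B, C, H, W]
class_weights = torch.tensor([1.0, 2.0, 0.5])  # [C]

# Same reshape as in focal_loss_with_logits: [C] -> [1, C, 1, 1]
w = class_weights.view(1, -1, *(1 for _ in range(loss.dim() - 2)))
weighted = loss * w

assert weighted[:, 1].eq(2.0).all()  # channel 1 scaled by its weight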
