From 83de6cabf134d0d885d77a8f718e435cdc6b5cec Mon Sep 17 00:00:00 2001
From: Ievgen Khvedchenia
Date: Wed, 26 Jun 2024 18:29:45 +0300
Subject: [PATCH] Added class_weights

---
 pytorch_toolbelt/losses/focal.py      | 11 +++++++++--
 pytorch_toolbelt/losses/functional.py |  8 ++++++++
 2 files changed, 17 insertions(+), 2 deletions(-)

diff --git a/pytorch_toolbelt/losses/focal.py b/pytorch_toolbelt/losses/focal.py
index 35bd758ca..359ae78c0 100644
--- a/pytorch_toolbelt/losses/focal.py
+++ b/pytorch_toolbelt/losses/focal.py
@@ -23,6 +23,7 @@ def __init__(
         reduced_threshold: Optional[float] = None,
         activation: str = "sigmoid",
         softmax_dim: Optional[int] = None,
+        class_weights: Optional[Tensor] = None,
     ):
         """
 
@@ -44,6 +45,10 @@ def __init__(
         self.activation = activation
         self.softmax_dim = softmax_dim
 
+        if class_weights is not None and not torch.is_tensor(class_weights):
+            class_weights = torch.tensor(list(class_weights), dtype=torch.float32)
+        self.register_buffer("class_weights", class_weights, persistent=False)
+
         self.focal_loss_fn = partial(
             focal_loss_with_logits,
             alpha=alpha,
@@ -64,7 +69,9 @@ def __repr__(self):
         repr = f"{self.__class__.__name__}(alpha={self.alpha}, gamma={self.gamma}, "
         repr += f"ignore_index={self.ignore_index}, reduction={self.reduction}, normalized={self.normalized}, "
         repr += f"reduced_threshold={self.reduced_threshold}, activation={self.activation}, "
-        repr += f"softmax_dim={self.softmax_dim})"
+        repr += f"softmax_dim={self.softmax_dim}, "
+        class_weights = self.class_weights.tolist() if self.class_weights is not None else None
+        repr += f"class_weights={class_weights})"
         return repr
 
     def forward(self, inputs: Tensor, targets: Tensor) -> Tensor:
@@ -81,7 +88,7 @@ def forward(self, inputs: Tensor, targets: Tensor) -> Tensor:
         if len(targets.shape) + 1 == len(inputs.shape):
             targets = self.get_one_hot_targets(targets, num_classes=inputs.size(1))
 
-        loss = self.focal_loss_fn(inputs, targets)
+        loss = self.focal_loss_fn(inputs, targets, class_weights=self.class_weights)
         return loss
 
     def _one_hot_targets(self, targets, num_classes):
diff --git a/pytorch_toolbelt/losses/functional.py b/pytorch_toolbelt/losses/functional.py
index 12dc87572..b56ea3e14 100644
--- a/pytorch_toolbelt/losses/functional.py
+++ b/pytorch_toolbelt/losses/functional.py
@@ -29,6 +29,7 @@ def focal_loss_with_logits(
     ignore_index=None,
     activation: str = "sigmoid",
     softmax_dim: Optional[int] = None,
+    class_weights: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     """Compute binary focal loss between target and output logits.
 
@@ -77,6 +78,13 @@
     if alpha is not None:
         loss *= alpha * target + (1 - alpha) * (1 - target)
 
+    if class_weights is not None:
+        # class_weights has shape [C] while loss has shape [B, C, ...];
+        # reshape to [1, C, 1, ...] so the weights broadcast over the
+        # batch and any trailing spatial dimensions.
+        class_weights = class_weights.view(1, -1, *(1 for _ in range(loss.dim() - 2)))
+        loss *= class_weights
+
     if ignore_index is not None:
         ignore_mask = target.eq(ignore_index)
         loss = torch.masked_fill(loss, ignore_mask, 0)
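
Usage sketch (illustrative, not part of the patch): with this change applied,
FocalLoss accepts per-class weights either as a tensor or as any sequence;
non-tensor input is converted to a float32 tensor and registered as a
non-persistent buffer, so it follows the module across devices without being
written to checkpoints. The class count and weight values below are made up
for the example.

    import torch
    from pytorch_toolbelt.losses import FocalLoss

    # Three classes; up-weight the under-represented class 2 (illustrative values).
    criterion = FocalLoss(
        activation="softmax",
        softmax_dim=1,
        class_weights=[1.0, 1.0, 4.0],  # plain sequence is converted internally
    )

    logits = torch.randn(4, 3, 32, 32, requires_grad=True)  # [B, C, H, W] raw logits
    targets = torch.randint(0, 3, (4, 32, 32))              # [B, H, W] integer labels

    loss = criterion(logits, targets)  # weights broadcast internally as [1, C, 1, 1]
    loss.backward()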