5 changes: 5 additions & 0 deletions docs/api/losses.rst
@@ -11,6 +11,7 @@ Losses
ctc_loss
ctc_loss_with_forward_probs
dice_loss
dice_plus_ce_loss
hinge_loss
huber_loss
kl_divergence
@@ -60,6 +61,10 @@ Dice loss
.. autofunction:: multiclass_generalized_dice_loss
.. autofunction:: binary_dice_loss

Dice plus Cross-Entropy loss
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: dice_plus_ce_loss

Fenchel Young loss
~~~~~~~~~~~~~~~~~~
.. autofunction:: make_fenchel_young_loss
1 change: 1 addition & 0 deletions optax/losses/__init__.py
@@ -46,6 +46,7 @@
from optax.losses._segmentation import binary_dice_loss
from optax.losses._segmentation import dice_loss
from optax.losses._segmentation import dice_plus_ce_loss
from optax.losses._segmentation import multiclass_generalized_dice_loss
from optax.losses._self_supervised import ntxent
from optax.losses._self_supervised import triplet_margin_loss
from optax.losses._smoothing import smooth_labels
172 changes: 145 additions & 27 deletions optax/losses/_segmentation.py
@@ -20,38 +20,77 @@
import jax
import jax.numpy as jnp

from optax.losses._classification import kl_divergence
from optax.losses._classification import softmax_cross_entropy


def _reduce_loss(
loss: chex.Array, reduction: str, axis: Optional[int] = None
) -> chex.Array:
if reduction == "mean":
return jnp.mean(loss, axis=axis)
elif reduction == "sum":
return jnp.sum(loss, axis=axis)
elif reduction == "none":
return loss
else:
raise ValueError(f"Unsupported reduction: {reduction}")


def dice_loss(
predictions: chex.Array,
targets: chex.Array,
*,
class_weights: Optional[chex.Array] = None,
smooth: float = 1.0,
alpha: float = 0.5,
beta: float = 0.5,
apply_softmax: bool = True,
reduction: str = "mean",
ignore_background: bool = False,
axis: Optional[chex.Array] = None,
) -> chex.Array:
r"""Computes the Dice Loss for multi-class segmentation.

Computes the Soft Dice Loss for segmentation tasks. This implementation
includes parameters to weigh false positives and false negatives, making it
a generalization of the standard Dice Loss. Works for both binary and
multi-class segmentation. For binary segmentation, use targets with shape
[..., 1] or [...] and predictions with corresponding logits.

The loss is computed per class and then averaged (or summed) across classes.
For class c:

.. math::
intersection_c = \sum_i^{N} p_{i,c} \cdot t_{i,c}
\\
dice_c = 1 - \frac{
intersection_c + smooth
}{
intersection_c +
\alpha \cdot (P_c - intersection_c) +
\beta \cdot (T_c - intersection_c) +
smooth
}

where:
- :math:`p_{i,c}`: predicted probability for class c at pixel i.
- :math:`t_{i,c}`: target value (0 or 1) for class c at pixel i.
- :math:`N`: total number of pixels.
- :math:`P_c = \sum_i p_{i,c}`: sum of predicted probabilities for class c.
- :math:`T_c = \sum_i t_{i,c}`: sum of target values for class c.
- :math:`\alpha`: weight for false positives
  (:math:`FP_c = P_c - intersection_c`).
- :math:`\beta`: weight for false negatives
  (:math:`FN_c = T_c - intersection_c`).

Note: This formulation differs from the traditional Dice coefficient.
When :math:`\alpha = \beta = 0.5`, this gives:
:math:`(intersection + smooth) / (0.5 \cdot (P_c + T_c) + smooth)`,
which equals :math:`(2 \cdot intersection + 2 \cdot smooth) /
(P_c + T_c + 2 \cdot smooth)`.
To match the traditional Dice coefficient formula
:math:`(2 \cdot intersection + smooth) / (P_c + T_c + smooth)`,
use :math:`\alpha = \beta = 0.5` and pass half the smoothing value
(:math:`smooth / 2`).

Args:
predictions: Logits of shape [..., num_classes] for multi-class or
@@ -63,6 +102,8 @@ def dice_loss(
If None, all classes weighted equally.
smooth: Smoothing parameter to avoid division by zero and improve
gradient stability.
alpha: Weight for false positives. Defaults to 0.5 (standard Dice).
beta: Weight for false negatives. Defaults to 0.5 (standard Dice).
apply_softmax: Whether to apply softmax to predictions. Set False if
predictions are already probabilities.
reduction: How to reduce across classes: 'mean', 'sum', or 'none'.
@@ -81,7 +122,7 @@ def dice_loss(
- 'none': [..., num_classes] (includes class dimension)

Examples:
Binary segmentation (standard Dice):

>>> import jax.numpy as jnp
>>> from optax.losses import dice_loss
@@ -91,14 +132,16 @@
>>> loss.shape
(2,)

Multi-class Dice with custom weighting for false positives/negatives:

>>> import jax
>>> key = jax.random.PRNGKey(0)
>>> logits = jax.random.normal(key, (2, 4, 4, 3))  # 2 samples, 3 classes
>>> labels = jax.random.randint(key, (2, 4, 4), 0, 3)  # Random labels
>>> targets = jax.nn.one_hot(labels, 3)  # One-hot encoded
>>> loss = dice_loss(
...     logits, targets, alpha=0.3, beta=0.7
... )
>>> loss.shape
(2,)
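
As a quick check of the note above (plain arithmetic with made-up sums,
not a call into the library): with :math:`\alpha = \beta = 0.5` and half
the smoothing, the generalized form reproduces the traditional Dice
coefficient.

>>> intersection, p_sum, t_sum, smooth = 3.0, 5.0, 4.0, 1.0
>>> generalized = (intersection + smooth / 2) / (
...     intersection
...     + 0.5 * (p_sum - intersection)
...     + 0.5 * (t_sum - intersection)
...     + smooth / 2
... )
>>> traditional = (2 * intersection + smooth) / (p_sum + t_sum + smooth)
>>> abs(generalized - traditional) < 1e-9
True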

@@ -151,9 +194,16 @@ def dice_loss(
pred_sum = jnp.sum(probs, axis=axis)
target_sum = jnp.sum(targets, axis=axis)

# Generalized Dice calculation
numerator = intersection + smooth
denominator = (
intersection
+ alpha * (pred_sum - intersection)
+ beta * (target_sum - intersection)
+ smooth
)
coeff = numerator / denominator
dice_l = 1.0 - coeff # [..., classes]

# Apply class weights if provided
if class_weights is not None:
@@ -167,16 +217,7 @@
dice_l = dice_l[..., 1:]

# Reduce across classes according to reduction parameter
if reduction == "mean":
dice_l = jnp.mean(dice_l, axis=-1)
elif reduction == "sum":
dice_l = jnp.sum(dice_l, axis=-1)
elif reduction == "none":
pass # Keep per-class losses
else:
raise ValueError(
f"reduction must be 'mean', 'sum', or 'none', got {reduction}"
)
dice_l = _reduce_loss(dice_l, reduction, axis=-1)

return dice_l

@@ -236,6 +277,83 @@ def multiclass_generalized_dice_loss(
)


def dice_plus_ce_loss(
Review comment (Collaborator): In the spirit of modularity, maybe this can be removed? Can the user add the dice and ce losses themselves? Do you see an advantage to exposing this as a single function?
(A sketch of composing the dice and cross-entropy losses manually appears after this function.)

predictions: chex.Array,
targets: chex.Array,
*,
dice_weight: float = 1.0,
ce_weight: float = 1.0,
apply_softmax: bool = True,
dice_kwargs: Optional[dict] = None,
ce_kwargs: Optional[dict] = None,
) -> chex.Array:
"""Computes a combined Dice and Cross-Entropy loss.

This loss is frequently used in segmentation tasks, as it combines the
strengths of both Dice Loss (good for class imbalance) and Cross-Entropy
(good for pixel-level accuracy).

The final loss is `dice_weight * dice_loss + ce_weight * ce_loss`.

Args:
predictions: Logits or probabilities, depending on `apply_softmax`.
Shape: [..., num_classes].
targets: One-hot encoded targets of shape [..., num_classes].
dice_weight: Weight for the Dice loss component.
ce_weight: Weight for the Cross-Entropy loss component.
apply_softmax: Whether to apply softmax to predictions. If True,
predictions are assumed to be logits; if False, they are assumed to
be probabilities.
dice_kwargs: Optional dictionary of keyword arguments passed to
:func:`~optax.losses.dice_loss`.
ce_kwargs: Optional dictionary of keyword arguments passed to
:func:`~optax.losses.softmax_cross_entropy` (if `apply_softmax=True`)
or :func:`~optax.losses.kl_divergence` (if `apply_softmax=False`).

Returns:
Combined loss value, with shape determined by the reduction of the
underlying losses (typically `[...]` for batch dimensions).

Examples:
>>> import jax
>>> import jax.numpy as jnp
>>> from optax.losses import dice_plus_ce_loss
>>> key = jax.random.PRNGKey(0)
>>> logits = jax.random.normal(key, (2, 4, 4, 3))
>>> labels = jax.random.randint(key, (2, 4, 4), 0, 3)
>>> targets = jax.nn.one_hot(labels, 3)
>>> loss = dice_plus_ce_loss(logits, targets)
>>> loss.shape
(2,)
"""
# Copy to avoid mutating caller-provided keyword dictionaries.
dice_kwargs = dict(dice_kwargs or {})
ce_kwargs = dict(ce_kwargs or {})

# Pass `apply_softmax` to dice_loss unless explicitly overridden
if "apply_softmax" not in dice_kwargs:
dice_kwargs["apply_softmax"] = apply_softmax

d_loss = dice_loss(predictions, targets, **dice_kwargs)

if apply_softmax:
# Cross-entropy on logits
ce_loss_per_pixel = softmax_cross_entropy(predictions, targets, **ce_kwargs)
else:
# Cross-entropy on probabilities, using KL divergence
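# For one-hot targets (as documented above), KL(targets || predictions)
# coincides with the cross-entropy, since the target entropy term is zero.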
# Add a small epsilon to prevent log(0)
log_predictions = jnp.log(jnp.clip(predictions, 1e-7, 1.0))
ce_loss_per_pixel = kl_divergence(log_predictions, targets, **ce_kwargs)

# Reduce cross-entropy loss to have the same shape as the dice loss
axis_to_reduce = tuple(range(1, ce_loss_per_pixel.ndim))
if axis_to_reduce:
ce_loss = jnp.mean(ce_loss_per_pixel, axis=axis_to_reduce)
else:
ce_loss = ce_loss_per_pixel # Already per-batch

return dice_weight * d_loss + ce_weight * ce_loss
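
Regarding the reviewer's modularity question above: a minimal sketch (illustrative only, not part of this PR's code, assuming the public optax.losses namespace and arbitrary example shapes) of how a user could compose the two terms by hand, which is essentially what dice_plus_ce_loss wraps with weighting and kwarg plumbing:

import jax
import jax.numpy as jnp
from optax import losses

key = jax.random.PRNGKey(0)
logits = jax.random.normal(key, (2, 4, 4, 3))
targets = jax.nn.one_hot(jax.random.randint(key, (2, 4, 4), 0, 3), 3)

# Dice term: reduced over spatial dims and classes, shape (2,).
d = losses.dice_loss(logits, targets)
# Cross-entropy term: per-pixel (2, 4, 4), averaged over spatial dims.
ce = losses.softmax_cross_entropy(logits, targets).mean(axis=(1, 2))
combined = 1.0 * d + 1.0 * ce  # matches dice_plus_ce_loss defaults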


def binary_dice_loss(
predictions: chex.Array,
targets: chex.Array,