5 changes: 5 additions & 0 deletions src/lightning/pytorch/loops/evaluation_loop.py
@@ -383,6 +383,8 @@ def _store_dataloader_outputs(self) -> None:
def _store_dataloader_outputs(self) -> None:
trainer = self.trainer
trainer._logger_connector.epoch_end_reached()
# Sync on_epoch metrics across ranks and validate all ranks logged the same keys
trainer._logger_connector.sync_on_epoch_metrics()
self._logged_outputs.append(trainer._logger_connector.update_eval_epoch_metrics())

def _on_before_fetch(self) -> None:
@@ -442,6 +444,9 @@ def _evaluation_step(

self.batch_progress.increment_processed()

# Sync on_step metrics across ranks and validate all ranks logged the same keys
trainer._logger_connector.sync_on_step_metrics()

if using_dataloader_iter:
# update the hook kwargs now that the step method might have consumed the iterator
batch = data_fetcher._batch
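A minimal sketch of the failure mode this evaluation-loop synchronization point guards against (not part of the diff; the module and metric names are illustrative): one rank logs a `sync_dist=True` key that the other ranks never log. Because the reduce is now deferred to `sync_on_step_metrics()`, the mismatch is detected and reported instead of producing mismatched collectives.

```python
import torch
from lightning.pytorch import LightningModule


class MismatchedKeysModule(LightningModule):
    """Illustrative only: logs a sync_dist metric on rank 0 that the other ranks never log."""

    def validation_step(self, batch, batch_idx):
        loss = torch.zeros(())
        self.log("val_loss", loss, sync_dist=True)  # logged on every rank: consistent
        if self.global_rank == 0:
            # rank-0-only key: with the deferred sync this is reported as a
            # MisconfigurationException instead of a mismatched or corrupted reduction
            self.log("val_debug_metric", loss, sync_dist=True)
```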
5 changes: 5 additions & 0 deletions src/lightning/pytorch/loops/fit_loop.py
@@ -479,6 +479,10 @@ def on_advance_end(self) -> None:
call._call_lightning_module_hook(trainer, "on_train_epoch_end")
call._call_callback_hooks(trainer, "on_train_epoch_end", monitoring_callbacks=True)

# Sync on_epoch metrics across ranks and validate all ranks logged the same keys
# Must be called before on_epoch_end() which computes the metrics
trainer._logger_connector.sync_on_epoch_metrics()

trainer._logger_connector.on_epoch_end()

if not self.restarting and self.epoch_loop._num_ready_batches_reached():
@@ -489,6 +493,7 @@ def on_advance_end(self) -> None:
# we manually decrease here because loggers expect that the same step is used when logging epoch-end metrics
# even when the batch loop has finished
self.epoch_loop._batches_that_stepped -= 1

# log epoch metrics
trainer._logger_connector.update_train_epoch_metrics()
self.epoch_loop._batches_that_stepped += 1
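A hedged usage-level illustration of why the ordering above matters (class and metric names are assumptions, not part of the diff): an `on_epoch=True, sync_dist=True` metric accumulates locally during the epoch, and its cross-rank reduction is now triggered by `sync_on_epoch_metrics()` before `on_epoch_end()` and `update_train_epoch_metrics()` read the values, so every rank issues the same collectives in the same order.

```python
import torch
from lightning.pytorch import LightningModule


class EpochMetricModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).mean()
        # accumulated per rank during the epoch; the cross-rank reduction now happens once,
        # at the synchronized point added above, before the epoch-end hooks consume it
        self.log("train_loss_epoch", loss, on_step=False, on_epoch=True, sync_dist=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
```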
3 changes: 3 additions & 0 deletions src/lightning/pytorch/loops/training_epoch_loop.py
@@ -355,6 +355,9 @@ def advance(self, data_fetcher: _DataFetcher) -> None:

self.batch_progress.increment_processed()

# Sync on_step metrics across ranks and validate all ranks logged the same keys
trainer._logger_connector.sync_on_step_metrics()

# update non-plateau LR schedulers
# update epoch-interval ones only when we are at the end of training epoch
self.update_lr_schedulers("step", update_plateau_schedulers=False)
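For the step-level training path, a short hedged sketch (illustrative names, not part of the diff): an `on_step=True, sync_dist=True` metric logged in `training_step` is no longer reduced inside `self.log`; the reduce is issued when the loop reaches `sync_on_step_metrics()` right after the step, so all ranks participate together.

```python
import torch
from lightning.pytorch import LightningModule


class StepMetricModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).mean()
        # the mean across ranks is produced at the deferred sync point, not inside `log`
        self.log("train_loss_step", loss, on_step=True, on_epoch=False, sync_dist=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
```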
src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py
@@ -190,6 +190,36 @@ def on_batch_start(self, batch: Any, dataloader_idx: Optional[int] = None) -> None:
results.batch_size = None
results.dataloader_idx = dataloader_idx

def sync_on_step_metrics(self) -> None:
"""Synchronize on_step metrics across ranks.

This must be called at a point where ALL ranks are synchronized, typically right after
training_step/validation_step returns. It validates that all ranks logged the same metric keys with
sync_dist=True and performs the sync operations.

See
https://github.com/Lightning-AI/pytorch-lightning/issues/21409

"""
results = self.trainer._results
if results is not None:
results.sync_on_step_metrics()

def sync_on_epoch_metrics(self) -> None:
"""Synchronize on_epoch metrics across ranks.

This must be called at a point where ALL ranks are synchronized, typically at epoch end before metrics are
consumed. It validates that all ranks logged the same metric keys with sync_dist=True and performs the
compute/sync operations.

See
https://github.com/Lightning-AI/pytorch-lightning/issues/21409

"""
results = self.trainer._results
if results is not None:
results.sync_on_epoch_metrics()

def epoch_end_reached(self) -> None:
self._first_loop_iter = None

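The connector methods above only delegate to the results collection, so the user-facing remedy for conditionally available metrics is the one the new error message (in `result.py` below) recommends: log the same keys on every rank, or opt out of `sync_dist` and reduce manually. A hedged sketch of both patterns; the helper and metric names are hypothetical.

```python
import torch
from lightning.pytorch import LightningModule


class ConsistentKeysModule(LightningModule):
    def validation_step(self, batch, batch_idx):
        rare = self._rare_metric(batch)  # hypothetical helper, may return None on some ranks

        # Option 1: keep the key set identical on every rank by logging a dummy value
        self.log("val_rare_metric", rare if rare is not None else torch.zeros(()), sync_dist=True)

        # Option 2: skip sync_dist and reduce manually; all_gather is collective,
        # so every rank must reach this call
        gathered = self.all_gather(rare if rare is not None else torch.zeros(()))
        if self.global_rank == 0:
            self.log("val_rare_metric_manual", gathered.mean(), rank_zero_only=True)

    def _rare_metric(self, batch):
        # hypothetical: only computable for some batches/ranks
        return None
```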
142 changes: 140 additions & 2 deletions src/lightning/pytorch/trainer/connectors/logger_connector/result.py
@@ -17,6 +17,7 @@
from typing import Any, Callable, Optional, Union, cast

import torch
import torch.distributed as dist
from lightning_utilities.core.apply_func import apply_to_collection
from torch import Tensor
from torchmetrics import Metric
@@ -46,6 +47,42 @@ class _METRICS(TypedDict):
warning_cache = WarningCache()


def _assert_sync_dist_metric_keys_consistency(keys: list[str], fx: str, group: Optional[Any]) -> None:
"""Validate that all ranks have the same metric keys for sync_dist operations.

This function must be called at a synchronization point where ALL ranks are guaranteed
to participate. It uses all_gather_object to collect keys from all ranks and validates
they are identical.

Args:
keys: List of metric keys that need to be synchronized on this rank.
fx: The hook name (e.g., 'training_step') for error messages.
group: The process group to use for the collective operation.

Raises:
MisconfigurationException: If ranks have different metric keys.

"""
if not _distributed_is_initialized() or not dist.is_available():
return
world_size = dist.get_world_size(group=group)
if world_size <= 1:
return

gathered: list[object] = [None] * world_size
dist.all_gather_object(gathered, keys, group=group)
first = gathered[0]
if any(item != first for item in gathered[1:]):
ranks = "\n".join(f" rank={i}: {k}" for i, k in enumerate(gathered))
raise MisconfigurationException(
"When logging with `sync_dist=True`, all processes must log the same metric keys in the same order "
f"within a given hook. Detected a mismatch during `{fx}`.\n"
f"Synchronized metric keys per rank:\n{ranks}\n"
"Either log the same keys on all ranks (for example by logging dummy values), or set `sync_dist=False` "
"and manually synchronize (for example using `all_gather`)."
)


@dataclass
class _Sync:
fn: Optional[Callable] = None
@@ -202,6 +239,7 @@ def __init__(self, metadata: _Metadata, is_tensor: bool) -> None:
self.add_state("cumulated_batch_size", torch.tensor(0), dist_reduce_fx=torch.sum)
# this is defined here only because upstream is missing the type annotation
self._forward_cache: Optional[Any] = None
self._forward_cache_synced: bool = False

@override
def update(self, value: _VALUE, batch_size: int) -> None:
@@ -222,7 +260,10 @@ def update(self, value: _VALUE, batch_size: int) -> None:
value = value.to(dtype)

if self.meta.on_step:
self._forward_cache = self.meta.sync(value.clone()) # `clone` because `sync` is in-place
# Defer sync to sync_on_step_metrics() which is called at a controlled synchronization point
# This allows validating that all ranks have the same metric keys before syncing
self._forward_cache = value.clone()
self._forward_cache_synced = False
# performance: no need to accumulate on values only logged on_step
if not self.meta.on_epoch:
self.value = self._forward_cache
@@ -239,7 +280,7 @@ def update(self, value: _VALUE, batch_size: int) -> None:
self.value = self.value + value
else:
value = cast(Metric, value)
self.value = value
self.value = value # type: ignore[assignment]
self._forward_cache = value._forward_cache

@override
@@ -421,6 +462,103 @@ def update_metrics(self, key: str, value: _VALUE, batch_size: int) -> None:
result_metric.forward(value, batch_size)
result_metric.has_reset = False

def sync_on_step_metrics(self) -> None:
"""Synchronize all on_step metrics that have sync_dist=True.

This method must be called at a point where ALL ranks are synchronized (e.g., after
training_step/validation_step returns). It:
1. Gathers all metric keys that need syncing from all ranks
2. Validates that all ranks have the same keys in the same order
3. Performs the sync operations in a deterministic order

This approach prevents the silent data corruption that occurs when ranks log different
metric keys with sync_dist=True.

See https://github.com/Lightning-AI/pytorch-lightning/issues/21409

"""
if not _distributed_is_initialized():
return

# Collect all metrics that need on_step sync
items_to_sync: list[tuple[str, _ResultMetric]] = []
for key, result_metric in self.valid_items():
if (
result_metric.meta.on_step
and result_metric.is_tensor
and result_metric.meta.sync.should
and not result_metric.meta.sync.rank_zero_only
and not result_metric._forward_cache_synced
and result_metric._forward_cache is not None
):
items_to_sync.append((key, result_metric))

if not items_to_sync:
return

# Get keys in order for validation
keys = [key for key, _ in items_to_sync]
fx = items_to_sync[0][1].meta.fx
group = items_to_sync[0][1].meta.sync.group

# Validate all ranks have the same keys (this is a collective operation)
_assert_sync_dist_metric_keys_consistency(keys, fx, group)

# Now perform the actual sync for each metric in order
for _, result_metric in items_to_sync:
if result_metric._forward_cache is not None:
synced_value = result_metric.meta.sync(result_metric._forward_cache.clone())
result_metric._forward_cache = synced_value
result_metric._forward_cache_synced = True
# Also update value if this is on_step only (not accumulated for on_epoch)
if not result_metric.meta.on_epoch:
result_metric.value = synced_value

def sync_on_epoch_metrics(self) -> None:
"""Synchronize all on_epoch metrics that have sync_dist=True.

This method must be called at a point where ALL ranks are synchronized (e.g., at
epoch end before metrics are consumed). It:
1. Gathers all metric keys that need syncing from all ranks
2. Validates that all ranks have the same keys in the same order
3. Performs the compute() (which includes sync) in a deterministic order

This approach prevents the silent data corruption that occurs when ranks log different
metric keys with sync_dist=True.

See https://github.com/Lightning-AI/pytorch-lightning/issues/21409

"""
if not _distributed_is_initialized():
return

# Collect all metrics that need on_epoch sync (not yet computed)
items_to_sync: list[tuple[str, _ResultMetric]] = []
for key, result_metric in self.valid_items():
if (
result_metric.meta.on_epoch
and result_metric.is_tensor
and result_metric.meta.sync.should
and not result_metric.meta.sync.rank_zero_only
and result_metric._computed is None # Not yet computed/synced
):
items_to_sync.append((key, result_metric))

if not items_to_sync:
return

# Get keys in order for validation
keys = [key for key, _ in items_to_sync]
fx = items_to_sync[0][1].meta.fx
group = items_to_sync[0][1].meta.sync.group

# Validate all ranks have the same keys (this is a collective operation)
_assert_sync_dist_metric_keys_consistency(keys, fx, group)

# Now perform the actual compute (which includes sync) for each metric in order
for _, result_metric in items_to_sync:
result_metric.compute()

@staticmethod
def _get_cache(result_metric: _ResultMetric, on_step: bool) -> Optional[Tensor]:
cache = None
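As a standalone illustration of the validation primitive the new `_assert_sync_dist_metric_keys_consistency` helper is built on (a self-contained sketch under assumed single-node gloo settings; not part of the diff): `torch.distributed.all_gather_object` lets every rank see every other rank's key list, so a mismatch can be reported deterministically instead of surfacing as a hang or a silently wrong reduction.

```python
import os

import torch.distributed as dist
import torch.multiprocessing as mp


def _worker(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29511"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # rank 1 "forgets" a key, which is exactly the situation the assertion reports
    keys = ["loss", "acc"] if rank == 0 else ["loss"]
    gathered: list[object] = [None] * world_size
    dist.all_gather_object(gathered, keys)

    if any(item != gathered[0] for item in gathered[1:]):
        print(f"rank {rank}: mismatched sync_dist keys per rank: {gathered}")

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2)
```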