diff --git a/src/lightning/pytorch/loops/evaluation_loop.py b/src/lightning/pytorch/loops/evaluation_loop.py index 6036e57cf59ae..7a0dba61a62dc 100644 --- a/src/lightning/pytorch/loops/evaluation_loop.py +++ b/src/lightning/pytorch/loops/evaluation_loop.py @@ -383,6 +383,8 @@ def _on_evaluation_epoch_end(self) -> None: def _store_dataloader_outputs(self) -> None: trainer = self.trainer trainer._logger_connector.epoch_end_reached() + # Sync on_epoch metrics across ranks and validate all ranks logged the same keys + trainer._logger_connector.sync_on_epoch_metrics() self._logged_outputs.append(trainer._logger_connector.update_eval_epoch_metrics()) def _on_before_fetch(self) -> None: @@ -442,6 +444,9 @@ def _evaluation_step( self.batch_progress.increment_processed() + # Sync on_step metrics across ranks and validate all ranks logged the same keys + trainer._logger_connector.sync_on_step_metrics() + if using_dataloader_iter: # update the hook kwargs now that the step method might have consumed the iterator batch = data_fetcher._batch diff --git a/src/lightning/pytorch/loops/fit_loop.py b/src/lightning/pytorch/loops/fit_loop.py index 8bb123939dc20..31196873508a5 100644 --- a/src/lightning/pytorch/loops/fit_loop.py +++ b/src/lightning/pytorch/loops/fit_loop.py @@ -479,6 +479,10 @@ def on_advance_end(self) -> None: call._call_lightning_module_hook(trainer, "on_train_epoch_end") call._call_callback_hooks(trainer, "on_train_epoch_end", monitoring_callbacks=True) + # Sync on_epoch metrics across ranks and validate all ranks logged the same keys + # Must be called before on_epoch_end() which computes the metrics + trainer._logger_connector.sync_on_epoch_metrics() + trainer._logger_connector.on_epoch_end() if not self.restarting and self.epoch_loop._num_ready_batches_reached(): @@ -489,6 +493,7 @@ def on_advance_end(self) -> None: # we manually decrease here because loggers expect that the same step is used when logging epoch-end metrics # even when the batch loop has finished self.epoch_loop._batches_that_stepped -= 1 + # log epoch metrics trainer._logger_connector.update_train_epoch_metrics() self.epoch_loop._batches_that_stepped += 1 diff --git a/src/lightning/pytorch/loops/training_epoch_loop.py b/src/lightning/pytorch/loops/training_epoch_loop.py index 6212bfe264e6e..1552a5b8b1b04 100644 --- a/src/lightning/pytorch/loops/training_epoch_loop.py +++ b/src/lightning/pytorch/loops/training_epoch_loop.py @@ -355,6 +355,9 @@ def advance(self, data_fetcher: _DataFetcher) -> None: self.batch_progress.increment_processed() + # Sync on_step metrics across ranks and validate all ranks logged the same keys + trainer._logger_connector.sync_on_step_metrics() + # update non-plateau LR schedulers # update epoch-interval ones only when we are at the end of training epoch self.update_lr_schedulers("step", update_plateau_schedulers=False) diff --git a/src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py b/src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py index f6e8885ee050a..6f177d93e0dc7 100644 --- a/src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py +++ b/src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py @@ -190,6 +190,36 @@ def on_batch_start(self, batch: Any, dataloader_idx: Optional[int] = None) -> No results.batch_size = None results.dataloader_idx = dataloader_idx + def sync_on_step_metrics(self) -> None: + """Synchronize on_step metrics across ranks. 
+ + This must be called at a point where ALL ranks are synchronized, typically right after + training_step/validation_step returns. It validates that all ranks logged the same metric keys with + sync_dist=True and performs the sync operations. + + See + https://github.com/Lightning-AI/pytorch-lightning/issues/21409 + + """ + results = self.trainer._results + if results is not None: + results.sync_on_step_metrics() + + def sync_on_epoch_metrics(self) -> None: + """Synchronize on_epoch metrics across ranks. + + This must be called at a point where ALL ranks are synchronized, typically at epoch end before metrics are + consumed. It validates that all ranks logged the same metric keys with sync_dist=True and performs the + compute/sync operations. + + See + https://github.com/Lightning-AI/pytorch-lightning/issues/21409 + + """ + results = self.trainer._results + if results is not None: + results.sync_on_epoch_metrics() + def epoch_end_reached(self) -> None: self._first_loop_iter = None diff --git a/src/lightning/pytorch/trainer/connectors/logger_connector/result.py b/src/lightning/pytorch/trainer/connectors/logger_connector/result.py index 7c9364b5ddfe1..f56dff2c5517e 100644 --- a/src/lightning/pytorch/trainer/connectors/logger_connector/result.py +++ b/src/lightning/pytorch/trainer/connectors/logger_connector/result.py @@ -17,6 +17,7 @@ from typing import Any, Callable, Optional, Union, cast import torch +import torch.distributed as dist from lightning_utilities.core.apply_func import apply_to_collection from torch import Tensor from torchmetrics import Metric @@ -46,6 +47,42 @@ class _METRICS(TypedDict): warning_cache = WarningCache() +def _assert_sync_dist_metric_keys_consistency(keys: list[str], fx: str, group: Optional[Any]) -> None: + """Validate that all ranks have the same metric keys for sync_dist operations. + + This function must be called at a synchronization point where ALL ranks are guaranteed + to participate. It uses all_gather_object to collect keys from all ranks and validates + they are identical. + + Args: + keys: List of metric keys that need to be synchronized on this rank. + fx: The hook name (e.g., 'training_step') for error messages. + group: The process group to use for the collective operation. + + Raises: + MisconfigurationException: If ranks have different metric keys. + + """ + if not _distributed_is_initialized() or not dist.is_available(): + return + world_size = dist.get_world_size(group=group) + if world_size <= 1: + return + + gathered: list[object] = [None] * world_size + dist.all_gather_object(gathered, keys, group=group) + first = gathered[0] + if any(item != first for item in gathered[1:]): + ranks = "\n".join(f" rank={i}: {k}" for i, k in enumerate(gathered)) + raise MisconfigurationException( + "When logging with `sync_dist=True`, all processes must log the same metric keys in the same order " + f"within a given hook. Detected a mismatch during `{fx}`.\n" + f"Synchronized metric keys per rank:\n{ranks}\n" + "Either log the same keys on all ranks (for example by logging dummy values), or set `sync_dist=False` " + "and manually synchronize (for example using `all_gather`)." 
+ ) + + @dataclass class _Sync: fn: Optional[Callable] = None @@ -202,6 +239,7 @@ def __init__(self, metadata: _Metadata, is_tensor: bool) -> None: self.add_state("cumulated_batch_size", torch.tensor(0), dist_reduce_fx=torch.sum) # this is defined here only because upstream is missing the type annotation self._forward_cache: Optional[Any] = None + self._forward_cache_synced: bool = False @override def update(self, value: _VALUE, batch_size: int) -> None: @@ -222,7 +260,10 @@ def update(self, value: _VALUE, batch_size: int) -> None: value = value.to(dtype) if self.meta.on_step: - self._forward_cache = self.meta.sync(value.clone()) # `clone` because `sync` is in-place + # Defer sync to sync_on_step_metrics() which is called at a controlled synchronization point + # This allows validating that all ranks have the same metric keys before syncing + self._forward_cache = value.clone() + self._forward_cache_synced = False # performance: no need to accumulate on values only logged on_step if not self.meta.on_epoch: self.value = self._forward_cache @@ -239,7 +280,7 @@ def update(self, value: _VALUE, batch_size: int) -> None: self.value = self.value + value else: value = cast(Metric, value) - self.value = value + self.value = value # type: ignore[assignment] self._forward_cache = value._forward_cache @override @@ -421,6 +462,103 @@ def update_metrics(self, key: str, value: _VALUE, batch_size: int) -> None: result_metric.forward(value, batch_size) result_metric.has_reset = False + def sync_on_step_metrics(self) -> None: + """Synchronize all on_step metrics that have sync_dist=True. + + This method must be called at a point where ALL ranks are synchronized (e.g., after + training_step/validation_step returns). It: + 1. Gathers all metric keys that need syncing from all ranks + 2. Validates that all ranks have the same keys in the same order + 3. Performs the sync operations in a deterministic order + + This approach prevents the silent data corruption that occurs when ranks log different + metric keys with sync_dist=True. + + See https://github.com/Lightning-AI/pytorch-lightning/issues/21409 + + """ + if not _distributed_is_initialized(): + return + + # Collect all metrics that need on_step sync + items_to_sync: list[tuple[str, _ResultMetric]] = [] + for key, result_metric in self.valid_items(): + if ( + result_metric.meta.on_step + and result_metric.is_tensor + and result_metric.meta.sync.should + and not result_metric.meta.sync.rank_zero_only + and not result_metric._forward_cache_synced + and result_metric._forward_cache is not None + ): + items_to_sync.append((key, result_metric)) + + if not items_to_sync: + return + + # Get keys in order for validation + keys = [key for key, _ in items_to_sync] + fx = items_to_sync[0][1].meta.fx + group = items_to_sync[0][1].meta.sync.group + + # Validate all ranks have the same keys (this is a collective operation) + _assert_sync_dist_metric_keys_consistency(keys, fx, group) + + # Now perform the actual sync for each metric in order + for _, result_metric in items_to_sync: + if result_metric._forward_cache is not None: + synced_value = result_metric.meta.sync(result_metric._forward_cache.clone()) + result_metric._forward_cache = synced_value + result_metric._forward_cache_synced = True + # Also update value if this is on_step only (not accumulated for on_epoch) + if not result_metric.meta.on_epoch: + result_metric.value = synced_value + + def sync_on_epoch_metrics(self) -> None: + """Synchronize all on_epoch metrics that have sync_dist=True. 
+ + This method must be called at a point where ALL ranks are synchronized (e.g., at + epoch end before metrics are consumed). It: + 1. Gathers all metric keys that need syncing from all ranks + 2. Validates that all ranks have the same keys in the same order + 3. Performs the compute() (which includes sync) in a deterministic order + + This approach prevents the silent data corruption that occurs when ranks log different + metric keys with sync_dist=True. + + See https://github.com/Lightning-AI/pytorch-lightning/issues/21409 + + """ + if not _distributed_is_initialized(): + return + + # Collect all metrics that need on_epoch sync (not yet computed) + items_to_sync: list[tuple[str, _ResultMetric]] = [] + for key, result_metric in self.valid_items(): + if ( + result_metric.meta.on_epoch + and result_metric.is_tensor + and result_metric.meta.sync.should + and not result_metric.meta.sync.rank_zero_only + and result_metric._computed is None # Not yet computed/synced + ): + items_to_sync.append((key, result_metric)) + + if not items_to_sync: + return + + # Get keys in order for validation + keys = [key for key, _ in items_to_sync] + fx = items_to_sync[0][1].meta.fx + group = items_to_sync[0][1].meta.sync.group + + # Validate all ranks have the same keys (this is a collective operation) + _assert_sync_dist_metric_keys_consistency(keys, fx, group) + + # Now perform the actual compute (which includes sync) for each metric in order + for _, result_metric in items_to_sync: + result_metric.compute() + @staticmethod def _get_cache(result_metric: _ResultMetric, on_step: bool) -> Optional[Tensor]: cache = None diff --git a/tests/tests_pytorch/core/test_results.py b/tests/tests_pytorch/core/test_results.py index c1d50e8458da6..d5ecdc003d555 100644 --- a/tests/tests_pytorch/core/test_results.py +++ b/tests/tests_pytorch/core/test_results.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from functools import partial +from unittest.mock import patch import pytest import torch @@ -21,7 +22,12 @@ from lightning.pytorch.accelerators import CPUAccelerator, CUDAAccelerator, MPSAccelerator from lightning.pytorch.strategies import DDPStrategy from lightning.pytorch.strategies.launchers import _MultiProcessingLauncher -from lightning.pytorch.trainer.connectors.logger_connector.result import _Sync +from lightning.pytorch.trainer.connectors.logger_connector.result import ( + _assert_sync_dist_metric_keys_consistency, + _ResultCollection, + _Sync, +) +from lightning.pytorch.utilities.exceptions import MisconfigurationException from tests_pytorch.helpers.runif import RunIf from tests_pytorch.models.test_tpu import wrap_launch_function @@ -54,3 +60,664 @@ def result_reduce_ddp_fn(strategy): @RunIf(skip_windows=True) def test_result_reduce_ddp(): spawn_launch(result_reduce_ddp_fn, [torch.device("cpu")] * 2) + + +@patch("lightning.pytorch.trainer.connectors.logger_connector.result._distributed_is_initialized", return_value=False) +def test_assert_sync_dist_metric_keys_consistency_not_initialized(_): + """Test that _assert_sync_dist_metric_keys_consistency returns early when dist is not initialized.""" + # Should not raise even with mismatched keys since dist is not initialized + _assert_sync_dist_metric_keys_consistency(["key_a"], "training_step", None) + + +@patch("torch.distributed.is_available", return_value=False) +@patch("lightning.pytorch.trainer.connectors.logger_connector.result._distributed_is_initialized", return_value=True) +def test_assert_sync_dist_metric_keys_consistency_dist_not_available(_, __): + """Test that _assert_sync_dist_metric_keys_consistency returns early when dist is not available.""" + # Should not raise even with mismatched keys since dist is not available + _assert_sync_dist_metric_keys_consistency(["key_a"], "training_step", None) + + +@patch("torch.distributed.get_world_size", return_value=1) +@patch("torch.distributed.is_available", return_value=True) +@patch("lightning.pytorch.trainer.connectors.logger_connector.result._distributed_is_initialized", return_value=True) +def test_assert_sync_dist_metric_keys_consistency_single_process(_, __, ___): + """Test that _assert_sync_dist_metric_keys_consistency returns early with world_size=1.""" + # Should not raise with single process + _assert_sync_dist_metric_keys_consistency(["key_a"], "training_step", None) + + +def _mock_all_gather_consistent(output_list, obj, group=None): + """Mock all_gather_object that simulates consistent keys across 2 ranks.""" + output_list[0] = obj + output_list[1] = obj + + +def _mock_all_gather_inconsistent(output_list, obj, group=None): + """Mock all_gather_object that simulates inconsistent keys across 2 ranks.""" + output_list[0] = ["training_step.loss", "training_step.metric_a"] + output_list[1] = ["training_step.loss", "training_step.metric_b"] + + +def _mock_all_gather_order_mismatch(output_list, obj, group=None): + """Mock all_gather_object that simulates same keys but different order across 2 ranks.""" + output_list[0] = ["training_step.metric_a", "training_step.metric_b"] + output_list[1] = ["training_step.metric_b", "training_step.metric_a"] + + +@patch("torch.distributed.all_gather_object", side_effect=_mock_all_gather_consistent) +@patch("torch.distributed.get_world_size", return_value=2) +@patch("torch.distributed.is_available", return_value=True) +@patch("lightning.pytorch.trainer.connectors.logger_connector.result._distributed_is_initialized", return_value=True) +def 
test_assert_sync_dist_metric_keys_consistency_matching_keys_mocked(_, __, ___, ____): + """Test key consistency validation with matching keys using mocked distributed functions.""" + keys = ["training_step.loss", "training_step.acc"] + # Should not raise when all ranks have matching keys + _assert_sync_dist_metric_keys_consistency(keys, "training_step", None) + + +@patch("torch.distributed.all_gather_object", side_effect=_mock_all_gather_inconsistent) +@patch("torch.distributed.get_world_size", return_value=2) +@patch("torch.distributed.is_available", return_value=True) +@patch("lightning.pytorch.trainer.connectors.logger_connector.result._distributed_is_initialized", return_value=True) +def test_assert_sync_dist_metric_keys_consistency_mismatched_keys_mocked(_, __, ___, ____): + """Test key consistency validation raises with mismatched keys using mocked distributed functions.""" + keys = ["training_step.loss", "training_step.metric_a"] + with pytest.raises(MisconfigurationException) as excinfo: + _assert_sync_dist_metric_keys_consistency(keys, "training_step", None) + + message = str(excinfo.value) + assert "sync_dist=True" in message + assert "Detected a mismatch" in message + assert "training_step.metric_a" in message + assert "training_step.metric_b" in message + + +@patch("torch.distributed.all_gather_object", side_effect=_mock_all_gather_order_mismatch) +@patch("torch.distributed.get_world_size", return_value=2) +@patch("torch.distributed.is_available", return_value=True) +@patch("lightning.pytorch.trainer.connectors.logger_connector.result._distributed_is_initialized", return_value=True) +def test_assert_sync_dist_metric_keys_consistency_order_mismatch_mocked(_, __, ___, ____): + """Test key consistency validation raises with different key order using mocked distributed functions.""" + keys = ["training_step.metric_a", "training_step.metric_b"] + with pytest.raises(MisconfigurationException) as excinfo: + _assert_sync_dist_metric_keys_consistency(keys, "training_step", None) + + message = str(excinfo.value) + assert "sync_dist=True" in message + assert "Detected a mismatch" in message + + +def _sync_dist_keys_consistency_ddp_fn(strategy): + """Consolidated function to test key consistency validation in DDP mode.""" + rank = dist.get_rank() + + # Test 1: Matching keys should not raise + keys = ["training_step.loss", "training_step.acc"] + _assert_sync_dist_metric_keys_consistency(keys, "training_step", None) + + # Test 2: Empty keys should not raise + empty_keys: list[str] = [] + _assert_sync_dist_metric_keys_consistency(empty_keys, "training_step", None) + + # Test 3: Mismatched keys should raise + mismatched_keys = ["training_step.metric_a"] if rank == 0 else ["training_step.metric_b"] + with pytest.raises(MisconfigurationException) as excinfo: + _assert_sync_dist_metric_keys_consistency(mismatched_keys, "training_step", None) + message = str(excinfo.value) + assert "sync_dist=True" in message + assert "Detected a mismatch" in message + + # Test 4: Different key order should raise + if rank == 0: + order_keys = ["training_step.metric_x", "training_step.metric_y"] + else: + order_keys = ["training_step.metric_y", "training_step.metric_x"] + with pytest.raises(MisconfigurationException) as excinfo: + _assert_sync_dist_metric_keys_consistency(order_keys, "training_step", None) + message = str(excinfo.value) + assert "sync_dist=True" in message + assert "Detected a mismatch" in message + + +@pytest.mark.flaky(reruns=3) +@RunIf(skip_windows=True) +def 
test_assert_sync_dist_metric_keys_consistency_ddp(): + """Test _assert_sync_dist_metric_keys_consistency in DDP mode (consolidated tests).""" + spawn_launch(_sync_dist_keys_consistency_ddp_fn, [torch.device("cpu")] * 2) + + +@patch("lightning.pytorch.trainer.connectors.logger_connector.result._distributed_is_initialized", return_value=False) +def test_sync_on_step_metrics_not_distributed(_): + """Test that sync_on_step_metrics returns early when not in distributed mode.""" + result = _ResultCollection(training=True) + result.log("training_step", "loss", torch.tensor(1.0), on_step=True, on_epoch=False, sync_dist=True) + # Should not raise, just return early + result.sync_on_step_metrics() + + +@patch("lightning.pytorch.trainer.connectors.logger_connector.result._distributed_is_initialized", return_value=False) +def test_sync_on_epoch_metrics_not_distributed(_): + """Test that sync_on_epoch_metrics returns early when not in distributed mode.""" + result = _ResultCollection(training=True) + result.log("training_step", "loss", torch.tensor(1.0), on_step=False, on_epoch=True, sync_dist=True) + # Should not raise, just return early + result.sync_on_epoch_metrics() + + +@patch("lightning.pytorch.trainer.connectors.logger_connector.result._assert_sync_dist_metric_keys_consistency") +@patch("lightning.pytorch.trainer.connectors.logger_connector.result._distributed_is_initialized", return_value=True) +def test_sync_on_step_metrics_no_items_to_sync(_, mock_assert): + """Test that sync_on_step_metrics returns early when no items need syncing.""" + result = _ResultCollection(training=True) + # Log without sync_dist=True + result.log("training_step", "loss", torch.tensor(1.0), on_step=True, on_epoch=False, sync_dist=False) + result.sync_on_step_metrics() + # Should not call the validation since there are no items to sync + mock_assert.assert_not_called() + + +@patch("lightning.pytorch.trainer.connectors.logger_connector.result._assert_sync_dist_metric_keys_consistency") +@patch("lightning.pytorch.trainer.connectors.logger_connector.result._distributed_is_initialized", return_value=True) +def test_sync_on_epoch_metrics_no_items_to_sync(_, mock_assert): + """Test that sync_on_epoch_metrics returns early when no items need syncing.""" + result = _ResultCollection(training=True) + # Log without sync_dist=True + result.log("training_step", "loss", torch.tensor(1.0), on_step=False, on_epoch=True, sync_dist=False) + result.sync_on_epoch_metrics() + # Should not call the validation since there are no items to sync + mock_assert.assert_not_called() + + +@patch("lightning.pytorch.trainer.connectors.logger_connector.result._assert_sync_dist_metric_keys_consistency") +@patch("lightning.pytorch.trainer.connectors.logger_connector.result._distributed_is_initialized", return_value=True) +def test_sync_on_step_metrics_rank_zero_only_skipped(_, mock_assert): + """Test that rank_zero_only metrics are skipped from sync validation.""" + result = _ResultCollection(training=True) + # Log with rank_zero_only=True - should be skipped from sync validation + result.log( + "training_step", "loss", torch.tensor(1.0), on_step=True, on_epoch=False, sync_dist=True, rank_zero_only=True + ) + result.sync_on_step_metrics() + # Should not call the validation since rank_zero_only metrics are skipped + mock_assert.assert_not_called() + + +@patch("lightning.pytorch.trainer.connectors.logger_connector.result._assert_sync_dist_metric_keys_consistency") 
+@patch("lightning.pytorch.trainer.connectors.logger_connector.result._distributed_is_initialized", return_value=True) +def test_sync_on_epoch_metrics_rank_zero_only_skipped(_, mock_assert): + """Test that rank_zero_only metrics are skipped from sync validation.""" + result = _ResultCollection(training=True) + # Log with rank_zero_only=True - should be skipped from sync validation + result.log( + "training_step", "loss", torch.tensor(1.0), on_step=False, on_epoch=True, sync_dist=True, rank_zero_only=True + ) + result.sync_on_epoch_metrics() + # Should not call the validation since rank_zero_only metrics are skipped + mock_assert.assert_not_called() + + +@patch("lightning.pytorch.trainer.connectors.logger_connector.result._distributed_is_initialized", return_value=False) +def test_result_metric_deferred_sync_behavior(_): + """Test that on_step metrics defer sync until sync_on_step_metrics is called.""" + result = _ResultCollection(training=True) + + # Log with on_step=True and sync_dist=True + result.log("training_step", "loss", torch.tensor(1.0), on_step=True, on_epoch=False, sync_dist=True) + + # Before sync_on_step_metrics is called, forward cache should be set but not synced + loss_metric = result["training_step.loss"] + assert loss_metric._forward_cache is not None, "Forward cache should be set" + assert loss_metric._forward_cache_synced is False, "Forward cache should not be synced yet" + + # In non-distributed mode, sync_on_step_metrics returns early + result.sync_on_step_metrics() + # Still not synced because we're not in distributed mode + assert loss_metric._forward_cache_synced is False + + +def _mock_sync_fn(value, reduce_op=None, group=None): + """Mock sync function that returns the value unchanged.""" + return value + + +@patch("torch.distributed.all_gather_object", side_effect=_mock_all_gather_consistent) +@patch("torch.distributed.get_world_size", return_value=2) +@patch("torch.distributed.is_available", return_value=True) +@patch("lightning.pytorch.trainer.connectors.logger_connector.result._distributed_is_initialized", return_value=True) +def test_sync_on_step_metrics_with_mocked_distributed(_, __, ___, ____): + """Test sync_on_step_metrics executes sync logic with mocked distributed functions.""" + result = _ResultCollection(training=True) + result.log("training_step", "loss", torch.tensor(1.0), on_step=True, on_epoch=False, sync_dist=True) + + loss_metric = result["training_step.loss"] + # Override the sync function with our mock + loss_metric.meta._sync.fn = _mock_sync_fn + loss_metric.meta._sync._should = True + + assert loss_metric._forward_cache is not None + assert loss_metric._forward_cache_synced is False + + result.sync_on_step_metrics() + + # After sync, forward cache should be marked as synced + assert loss_metric._forward_cache_synced is True + + +@patch("torch.distributed.all_gather_object", side_effect=_mock_all_gather_consistent) +@patch("torch.distributed.get_world_size", return_value=2) +@patch("torch.distributed.is_available", return_value=True) +@patch("lightning.pytorch.trainer.connectors.logger_connector.result._distributed_is_initialized", return_value=True) +def test_sync_on_epoch_metrics_with_mocked_distributed(_, __, ___, ____): + """Test sync_on_epoch_metrics executes sync logic with mocked distributed functions.""" + result = _ResultCollection(training=True) + result.log("training_step", "loss", torch.tensor(1.0), on_step=False, on_epoch=True, sync_dist=True) + + loss_metric = result["training_step.loss"] + # Override the sync function with our 
mock + loss_metric.meta._sync.fn = _mock_sync_fn + loss_metric.meta._sync._should = True + + assert loss_metric._computed is None + + result.sync_on_epoch_metrics() + + # After sync, compute should have been called + assert loss_metric._computed is not None + + +def _sync_metrics_ddp_fn(strategy): + """Consolidated function to test sync_on_step_metrics and sync_on_epoch_metrics in DDP mode.""" + rank = dist.get_rank() + + # Test 1: sync_on_step_metrics with consistent keys + result1 = _ResultCollection(training=True) + result1.log("training_step", "loss", torch.tensor(1.0), on_step=True, on_epoch=False, sync_dist=True) + result1.log("training_step", "acc", torch.tensor(0.9), on_step=True, on_epoch=False, sync_dist=True) + loss_metric1 = result1["training_step.loss"] + assert loss_metric1._forward_cache is not None, "Forward cache should be set" + assert loss_metric1._forward_cache_synced is False, "Forward cache should not be synced yet" + result1.sync_on_step_metrics() + assert loss_metric1._forward_cache_synced is True + + # Test 2: sync_on_epoch_metrics with consistent keys + result2 = _ResultCollection(training=True) + result2.log("training_step", "loss", torch.tensor(1.0), on_step=False, on_epoch=True, sync_dist=True) + result2.log("training_step", "acc", torch.tensor(0.9), on_step=False, on_epoch=True, sync_dist=True) + result2.sync_on_epoch_metrics() + loss_metric2 = result2["training_step.loss"] + assert loss_metric2._computed is not None + + # Test 3: sync_on_step_metrics with mismatched keys should raise + result3 = _ResultCollection(training=True) + result3.log("training_step", "loss", torch.tensor(1.0), on_step=True, on_epoch=False, sync_dist=True) + if rank == 0: + result3.log("training_step", "metric_a", torch.tensor(1.0), on_step=True, on_epoch=False, sync_dist=True) + else: + result3.log("training_step", "metric_b", torch.tensor(2.0), on_step=True, on_epoch=False, sync_dist=True) + with pytest.raises(MisconfigurationException) as excinfo: + result3.sync_on_step_metrics() + message = str(excinfo.value) + assert "sync_dist=True" in message + assert "Detected a mismatch" in message + + # Test 4: sync_on_epoch_metrics with mismatched keys should raise + result4 = _ResultCollection(training=True) + result4.log("training_step", "loss", torch.tensor(1.0), on_step=False, on_epoch=True, sync_dist=True) + if rank == 0: + result4.log("training_step", "metric_c", torch.tensor(1.0), on_step=False, on_epoch=True, sync_dist=True) + else: + result4.log("training_step", "metric_d", torch.tensor(2.0), on_step=False, on_epoch=True, sync_dist=True) + with pytest.raises(MisconfigurationException) as excinfo: + result4.sync_on_epoch_metrics() + message = str(excinfo.value) + assert "sync_dist=True" in message + assert "Detected a mismatch" in message + + # Test 5: sync_on_step_metrics updates value for on_step only metrics + result5 = _ResultCollection(training=True) + result5.log("training_step", "loss", torch.tensor(1.0), on_step=True, on_epoch=False, sync_dist=True) + loss_metric5 = result5["training_step.loss"] + assert loss_metric5._forward_cache is not None + assert loss_metric5._forward_cache_synced is False + result5.sync_on_step_metrics() + assert loss_metric5._forward_cache_synced is True + assert torch.equal(loss_metric5.value, loss_metric5._forward_cache) + + # Test 6: both on_step and on_epoch sync work together + result6 = _ResultCollection(training=True) + result6.log("training_step", "loss", torch.tensor(1.0), on_step=True, on_epoch=True, sync_dist=True) + 
result6.log("training_step", "acc", torch.tensor(0.9), on_step=True, on_epoch=True, sync_dist=True) + loss_metric6 = result6["training_step.loss"] + acc_metric6 = result6["training_step.acc"] + result6.sync_on_step_metrics() + assert loss_metric6._forward_cache_synced is True + assert acc_metric6._forward_cache_synced is True + result6.sync_on_epoch_metrics() + assert loss_metric6._computed is not None + assert acc_metric6._computed is not None + + +@pytest.mark.flaky(reruns=3) +@RunIf(skip_windows=True) +def test_sync_metrics_ddp(): + """Test sync_on_step_metrics and sync_on_epoch_metrics in DDP mode (consolidated tests).""" + spawn_launch(_sync_metrics_ddp_fn, [torch.device("cpu")] * 2) + + +@patch("lightning.pytorch.trainer.connectors.logger_connector.result._assert_sync_dist_metric_keys_consistency") +@patch("lightning.pytorch.trainer.connectors.logger_connector.result._distributed_is_initialized", return_value=True) +def test_sync_on_step_metrics_non_tensor_skipped(_, mock_assert): + """Test that non-tensor (TorchMetric) metrics are skipped from sync validation.""" + from torchmetrics import Accuracy + + result = _ResultCollection(training=True) + # Log a TorchMetric - these have is_tensor=False and should be skipped + metric = Accuracy(task="binary") + result.log("training_step", "accuracy", metric, on_step=True, on_epoch=False, sync_dist=True) + + accuracy_metric = result["training_step.accuracy"] + # Verify it's not a tensor metric + assert accuracy_metric.is_tensor is False + + result.sync_on_step_metrics() + # Should not call the validation since non-tensor metrics are skipped + mock_assert.assert_not_called() + + +@patch("lightning.pytorch.trainer.connectors.logger_connector.result._assert_sync_dist_metric_keys_consistency") +@patch("lightning.pytorch.trainer.connectors.logger_connector.result._distributed_is_initialized", return_value=True) +def test_sync_on_epoch_metrics_non_tensor_skipped(_, mock_assert): + """Test that non-tensor (TorchMetric) metrics are skipped from sync validation.""" + from torchmetrics import Accuracy + + result = _ResultCollection(training=True) + # Log a TorchMetric - these have is_tensor=False and should be skipped + metric = Accuracy(task="binary") + result.log("training_step", "accuracy", metric, on_step=False, on_epoch=True, sync_dist=True) + + accuracy_metric = result["training_step.accuracy"] + # Verify it's not a tensor metric + assert accuracy_metric.is_tensor is False + + result.sync_on_epoch_metrics() + # Should not call the validation since non-tensor metrics are skipped + mock_assert.assert_not_called() + + +@patch("torch.distributed.all_gather_object", side_effect=_mock_all_gather_consistent) +@patch("torch.distributed.get_world_size", return_value=2) +@patch("torch.distributed.is_available", return_value=True) +@patch("lightning.pytorch.trainer.connectors.logger_connector.result._distributed_is_initialized", return_value=True) +def test_sync_on_step_metrics_on_step_only_updates_value(_, __, ___, ____): + """Test that sync_on_step_metrics updates result_metric.value for on_step only metrics.""" + result = _ResultCollection(training=True) + # Log with on_step=True, on_epoch=False - this is the "on_step only" path + result.log("training_step", "loss", torch.tensor(1.0), on_step=True, on_epoch=False, sync_dist=True) + + loss_metric = result["training_step.loss"] + # Override the sync function with our mock + loss_metric.meta._sync.fn = _mock_sync_fn + loss_metric.meta._sync._should = True + + # Before sync, cache is not synced + assert 
loss_metric._forward_cache_synced is False + loss_metric.value.clone() + + result.sync_on_step_metrics() + + # After sync, value should be updated (for on_step only metrics) + assert loss_metric._forward_cache_synced is True + # The value should now equal the synced forward_cache + assert torch.equal(loss_metric.value, loss_metric._forward_cache) + + +@patch("torch.distributed.all_gather_object", side_effect=_mock_all_gather_consistent) +@patch("torch.distributed.get_world_size", return_value=2) +@patch("torch.distributed.is_available", return_value=True) +@patch("lightning.pytorch.trainer.connectors.logger_connector.result._distributed_is_initialized", return_value=True) +def test_sync_on_step_metrics_on_step_and_epoch_no_value_update(_, __, ___, ____): + """Test that sync_on_step_metrics does NOT update value for on_step+on_epoch metrics.""" + result = _ResultCollection(training=True) + # Log with on_step=True, on_epoch=True - value should NOT be updated + result.log("training_step", "loss", torch.tensor(1.0), on_step=True, on_epoch=True, sync_dist=True) + + loss_metric = result["training_step.loss"] + # Override the sync function with our mock + loss_metric.meta._sync.fn = _mock_sync_fn + loss_metric.meta._sync._should = True + + # Store original value + loss_metric.value.clone() + + result.sync_on_step_metrics() + + # After sync, forward_cache should be synced but value should NOT be updated + # (because on_epoch=True means we accumulate, not replace) + assert loss_metric._forward_cache_synced is True + # Value is accumulated for epoch-level metrics, so it's updated during update() but not in sync + + +@patch("torch.distributed.all_gather_object", side_effect=_mock_all_gather_consistent) +@patch("torch.distributed.get_world_size", return_value=2) +@patch("torch.distributed.is_available", return_value=True) +@patch("lightning.pytorch.trainer.connectors.logger_connector.result._distributed_is_initialized", return_value=True) +def test_sync_on_step_metrics_already_synced_skipped(_, __, ___, ____): + """Test that already-synced metrics are skipped in sync_on_step_metrics.""" + result = _ResultCollection(training=True) + result.log("training_step", "loss", torch.tensor(1.0), on_step=True, on_epoch=False, sync_dist=True) + + loss_metric = result["training_step.loss"] + loss_metric.meta._sync.fn = _mock_sync_fn + loss_metric.meta._sync._should = True + + # Mark as already synced + loss_metric._forward_cache_synced = True + + # Create a mock to track if sync is called + call_count = [0] + original_fn = loss_metric.meta._sync.fn + + def counting_sync(value, reduce_op=None, group=None): + call_count[0] += 1 + return original_fn(value, reduce_op, group) + + loss_metric.meta._sync.fn = counting_sync + + result.sync_on_step_metrics() + + # Sync should not have been called since it was already synced + assert call_count[0] == 0 + + +@patch("torch.distributed.all_gather_object", side_effect=_mock_all_gather_consistent) +@patch("torch.distributed.get_world_size", return_value=2) +@patch("torch.distributed.is_available", return_value=True) +@patch("lightning.pytorch.trainer.connectors.logger_connector.result._distributed_is_initialized", return_value=True) +def test_sync_on_epoch_metrics_already_computed_skipped(_, __, ___, ____): + """Test that already-computed metrics are skipped in sync_on_epoch_metrics.""" + result = _ResultCollection(training=True) + result.log("training_step", "loss", torch.tensor(1.0), on_step=False, on_epoch=True, sync_dist=True) + + loss_metric = result["training_step.loss"] + 
loss_metric.meta._sync.fn = _mock_sync_fn + loss_metric.meta._sync._should = True + + # Mark as already computed + loss_metric._computed = torch.tensor(1.0) + + # Create a mock to track if compute is called + compute_called = [False] + original_compute = loss_metric.compute + + def tracking_compute(): + compute_called[0] = True + return original_compute() + + loss_metric.compute = tracking_compute + + result.sync_on_epoch_metrics() + + # Compute should not have been called since it was already computed + assert compute_called[0] is False + + +@patch("torch.distributed.all_gather_object", side_effect=_mock_all_gather_consistent) +@patch("torch.distributed.get_world_size", return_value=2) +@patch("torch.distributed.is_available", return_value=True) +@patch("lightning.pytorch.trainer.connectors.logger_connector.result._distributed_is_initialized", return_value=True) +def test_sync_on_step_metrics_forward_cache_none_skipped(_, __, ___, ____): + """Test that metrics with None forward_cache are skipped.""" + result = _ResultCollection(training=True) + result.log("training_step", "loss", torch.tensor(1.0), on_step=True, on_epoch=False, sync_dist=True) + + loss_metric = result["training_step.loss"] + loss_metric.meta._sync.fn = _mock_sync_fn + loss_metric.meta._sync._should = True + + # Set forward cache to None + loss_metric._forward_cache = None + + # Should not raise, just skip this metric + result.sync_on_step_metrics() + # Since forward_cache is None, it should remain not synced + assert loss_metric._forward_cache_synced is False + + +@patch("torch.distributed.all_gather_object", side_effect=_mock_all_gather_consistent) +@patch("torch.distributed.get_world_size", return_value=2) +@patch("torch.distributed.is_available", return_value=True) +@patch("lightning.pytorch.trainer.connectors.logger_connector.result._distributed_is_initialized", return_value=True) +def test_sync_on_step_metrics_multiple_metrics(_, __, ___, ____): + """Test sync_on_step_metrics with multiple metrics.""" + result = _ResultCollection(training=True) + result.log("training_step", "loss", torch.tensor(1.0), on_step=True, on_epoch=False, sync_dist=True) + result.log("training_step", "acc", torch.tensor(0.9), on_step=True, on_epoch=False, sync_dist=True) + result.log("training_step", "f1", torch.tensor(0.85), on_step=True, on_epoch=False, sync_dist=True) + + for key in ["training_step.loss", "training_step.acc", "training_step.f1"]: + metric = result[key] + metric.meta._sync.fn = _mock_sync_fn + metric.meta._sync._should = True + + result.sync_on_step_metrics() + + # All metrics should be synced + for key in ["training_step.loss", "training_step.acc", "training_step.f1"]: + metric = result[key] + assert metric._forward_cache_synced is True + + +@patch("torch.distributed.all_gather_object", side_effect=_mock_all_gather_consistent) +@patch("torch.distributed.get_world_size", return_value=2) +@patch("torch.distributed.is_available", return_value=True) +@patch("lightning.pytorch.trainer.connectors.logger_connector.result._distributed_is_initialized", return_value=True) +def test_sync_on_epoch_metrics_multiple_metrics(_, __, ___, ____): + """Test sync_on_epoch_metrics with multiple metrics.""" + result = _ResultCollection(training=True) + result.log("training_step", "loss", torch.tensor(1.0), on_step=False, on_epoch=True, sync_dist=True) + result.log("training_step", "acc", torch.tensor(0.9), on_step=False, on_epoch=True, sync_dist=True) + result.log("training_step", "f1", torch.tensor(0.85), on_step=False, on_epoch=True, 
sync_dist=True) + + for key in ["training_step.loss", "training_step.acc", "training_step.f1"]: + metric = result[key] + metric.meta._sync.fn = _mock_sync_fn + metric.meta._sync._should = True + + result.sync_on_epoch_metrics() + + # All metrics should have _computed set + for key in ["training_step.loss", "training_step.acc", "training_step.f1"]: + metric = result[key] + assert metric._computed is not None + + +# NOTE: The following DDP tests have been consolidated into test_sync_metrics_ddp above +# to reduce the number of process spawns and avoid potential segfaults in CI. +# - test_sync_on_step_metrics_ddp +# - test_sync_on_epoch_metrics_ddp +# - test_sync_on_step_metrics_mismatch_ddp +# - test_sync_on_epoch_metrics_mismatch_ddp +# - test_sync_on_step_metrics_value_update_ddp +# - test_sync_on_step_and_epoch_metrics_ddp + + +def _mock_all_gather_detailed_mismatch(output_list, obj, group=None): + """Mock all_gather_object that simulates detailed inconsistent keys across 3 ranks.""" + output_list[0] = ["training_step.loss", "training_step.metric_a"] + output_list[1] = ["training_step.loss", "training_step.metric_b"] + output_list[2] = ["training_step.loss", "training_step.metric_c"] + + +@patch("torch.distributed.all_gather_object", side_effect=_mock_all_gather_detailed_mismatch) +@patch("torch.distributed.get_world_size", return_value=3) +@patch("torch.distributed.is_available", return_value=True) +@patch("lightning.pytorch.trainer.connectors.logger_connector.result._distributed_is_initialized", return_value=True) +def test_assert_sync_dist_metric_keys_consistency_detailed_error_message(_, __, ___, ____): + """Test that error message contains detailed information about all ranks.""" + keys = ["training_step.loss", "training_step.metric_a"] + with pytest.raises(MisconfigurationException) as excinfo: + _assert_sync_dist_metric_keys_consistency(keys, "training_step", None) + + message = str(excinfo.value) + # Verify the error message contains all the expected components + assert "sync_dist=True" in message + assert "Detected a mismatch during `training_step`" in message + assert "Synchronized metric keys per rank:" in message + assert "rank=0:" in message + assert "rank=1:" in message + assert "rank=2:" in message + assert "training_step.metric_a" in message + assert "training_step.metric_b" in message + assert "training_step.metric_c" in message + # Verify it contains guidance on how to fix + assert "log the same keys on all ranks" in message + assert "sync_dist=False" in message + + +def test_forward_cache_synced_initialization(): + """Test that _forward_cache_synced is initialized to False.""" + result = _ResultCollection(training=True) + result.log("training_step", "loss", torch.tensor(1.0), on_step=True, on_epoch=False, sync_dist=True) + + loss_metric = result["training_step.loss"] + # Verify _forward_cache_synced is initialized to False + assert loss_metric._forward_cache_synced is False + # Verify _forward_cache is set + assert loss_metric._forward_cache is not None + + +def test_forward_cache_set_on_step_metrics(): + """Test that _forward_cache is properly set for on_step metrics.""" + result = _ResultCollection(training=True) + value = torch.tensor(2.5) + result.log("training_step", "loss", value, on_step=True, on_epoch=False, sync_dist=True) + + loss_metric = result["training_step.loss"] + # The forward_cache should be a clone of the value + assert loss_metric._forward_cache is not None + assert torch.equal(loss_metric._forward_cache, value) + assert 
loss_metric._forward_cache_synced is False + + +def test_on_step_only_value_set_to_forward_cache(): + """Test that for on_step only metrics, value is set to forward_cache.""" + result = _ResultCollection(training=True) + value = torch.tensor(3.0) + result.log("training_step", "loss", value, on_step=True, on_epoch=False, sync_dist=True) + + loss_metric = result["training_step.loss"] + # For on_step only metrics, value should equal forward_cache + assert torch.equal(loss_metric.value, loss_metric._forward_cache) + + +def test_on_step_and_epoch_value_accumulated(): + """Test that for on_step+on_epoch metrics, value is accumulated separately.""" + result = _ResultCollection(training=True) + value = torch.tensor(2.0) + result.log("training_step", "loss", value, on_step=True, on_epoch=True, sync_dist=True) + + loss_metric = result["training_step.loss"] + # forward_cache should be set + assert loss_metric._forward_cache is not None + assert loss_metric._forward_cache_synced is False + # Value should be accumulated (not just set to forward_cache for on_epoch metrics) diff --git a/tests/tests_pytorch/trainer/connectors/test_logger_connector.py b/tests/tests_pytorch/trainer/connectors/test_logger_connector.py index 7a89efd133235..9ea3ce631a743 100644 --- a/tests/tests_pytorch/trainer/connectors/test_logger_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_logger_connector.py @@ -59,3 +59,45 @@ def test_uses_batches_that_stepped(mock_convert): ) logger.save.assert_called_once_with() mock_convert.return_value.setdefault.assert_called_once_with("epoch", trainer.current_epoch) + + +def test_sync_on_step_metrics_delegates_to_results(): + """Test that sync_on_step_metrics delegates to results.sync_on_step_metrics.""" + trainer = MagicMock(spec=Trainer) + trainer._results = MagicMock() + connector = _LoggerConnector(trainer) + + connector.sync_on_step_metrics() + + trainer._results.sync_on_step_metrics.assert_called_once() + + +def test_sync_on_step_metrics_handles_none_results(): + """Test that sync_on_step_metrics handles None results gracefully.""" + trainer = MagicMock(spec=Trainer) + trainer._results = None + connector = _LoggerConnector(trainer) + + # Should not raise when results is None + connector.sync_on_step_metrics() + + +def test_sync_on_epoch_metrics_delegates_to_results(): + """Test that sync_on_epoch_metrics delegates to results.sync_on_epoch_metrics.""" + trainer = MagicMock(spec=Trainer) + trainer._results = MagicMock() + connector = _LoggerConnector(trainer) + + connector.sync_on_epoch_metrics() + + trainer._results.sync_on_epoch_metrics.assert_called_once() + + +def test_sync_on_epoch_metrics_handles_none_results(): + """Test that sync_on_epoch_metrics handles None results gracefully.""" + trainer = MagicMock(spec=Trainer) + trainer._results = None + connector = _LoggerConnector(trainer) + + # Should not raise when results is None + connector.sync_on_epoch_metrics() diff --git a/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py b/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py index 6916eae68e9c0..6196db9311a9a 100644 --- a/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py +++ b/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py @@ -24,6 +24,7 @@ import torch from lightning_utilities.test.warning import no_warning_call from torch import Tensor +from torch.multiprocessing import ProcessRaisedException from torch.utils.data import DataLoader from torchmetrics import Accuracy @@ -342,6 +343,110 @@ def 
validation_step(self, batch, batch_idx): return super().validation_step(batch, batch_idx) +class InconsistentSyncDistKeysModel(BoringModel): + """Model that logs different metric keys on different ranks with sync_dist=True.""" + + def validation_step(self, batch, batch_idx): + self.log("val_loss", torch.tensor(0.0), sync_dist=True) + if self.trainer.global_rank == 0: + self.log("val_metric_a", torch.tensor(1.0), sync_dist=True) + else: + self.log("val_metric_b", torch.tensor(2.0), sync_dist=True) + return super().validation_step(batch, batch_idx) + + +class InconsistentSyncDistKeysOnStepModel(BoringModel): + """Model that logs different metric keys on different ranks with on_step=True and sync_dist=True.""" + + def training_step(self, batch, batch_idx): + # make sure we exercise the on_step sync path + self.log("train_loss", torch.tensor(0.0), on_step=True, on_epoch=False, sync_dist=True) + if self.trainer.global_rank == 0: + self.log("train_metric_a", torch.tensor(1.0), on_step=True, on_epoch=False, sync_dist=True) + else: + self.log("train_metric_b", torch.tensor(2.0), on_step=True, on_epoch=False, sync_dist=True) + return super().training_step(batch, batch_idx) + + +class InconsistentSyncDistKeysOrderModel(BoringModel): + """Model that logs the same metric keys but in different order on different ranks.""" + + def validation_step(self, batch, batch_idx): + self.log("val_loss", torch.tensor(0.0), sync_dist=True) + if self.trainer.global_rank == 0: + self.log("val_metric_a", torch.tensor(1.0), sync_dist=True) + self.log("val_metric_b", torch.tensor(2.0), sync_dist=True) + else: + self.log("val_metric_b", torch.tensor(2.0), sync_dist=True) + self.log("val_metric_a", torch.tensor(1.0), sync_dist=True) + return super().validation_step(batch, batch_idx) + + +class InconsistentSyncDistKeysOnEpochTrainingModel(BoringModel): + """Model that logs different metric keys on different ranks with on_epoch=True in training.""" + + def training_step(self, batch, batch_idx): + # Exercise the on_epoch sync path in training loop + self.log("train_loss", torch.tensor(0.0), on_step=False, on_epoch=True, sync_dist=True) + if self.trainer.global_rank == 0: + self.log("train_metric_a", torch.tensor(1.0), on_step=False, on_epoch=True, sync_dist=True) + else: + self.log("train_metric_b", torch.tensor(2.0), on_step=False, on_epoch=True, sync_dist=True) + return super().training_step(batch, batch_idx) + + +class InconsistentSyncDistKeysOnStepValidationModel(BoringModel): + """Model that logs different metric keys on different ranks with on_step=True in validation.""" + + def validation_step(self, batch, batch_idx): + # Exercise the on_step sync path in validation loop + self.log("val_loss", torch.tensor(0.0), on_step=True, on_epoch=False, sync_dist=True) + if self.trainer.global_rank == 0: + self.log("val_metric_a", torch.tensor(1.0), on_step=True, on_epoch=False, sync_dist=True) + else: + self.log("val_metric_b", torch.tensor(2.0), on_step=True, on_epoch=False, sync_dist=True) + return super().validation_step(batch, batch_idx) + + +class InconsistentSyncDistKeysBothStepAndEpochModel(BoringModel): + """Model that logs different metric keys with on_step=True, on_epoch=True.""" + + def training_step(self, batch, batch_idx): + # Exercise both on_step and on_epoch sync paths + self.log("train_loss", torch.tensor(0.0), on_step=True, on_epoch=True, sync_dist=True) + if self.trainer.global_rank == 0: + self.log("train_metric_a", torch.tensor(1.0), on_step=True, on_epoch=True, sync_dist=True) + else: + 
self.log("train_metric_b", torch.tensor(2.0), on_step=True, on_epoch=True, sync_dist=True) + return super().training_step(batch, batch_idx) + + +class ConsistentSyncDistKeysModel(BoringModel): + """Model that logs consistent metric keys on all ranks with sync_dist=True.""" + + def training_step(self, batch, batch_idx): + # All ranks log the same keys + self.log( + "train_loss", torch.tensor(float(self.trainer.global_rank)), on_step=True, on_epoch=True, sync_dist=True + ) + self.log( + "train_acc", + torch.tensor(0.9 + self.trainer.global_rank * 0.05), + on_step=True, + on_epoch=True, + sync_dist=True, + ) + return super().training_step(batch, batch_idx) + + def validation_step(self, batch, batch_idx): + # All ranks log the same keys + self.log("val_loss", torch.tensor(float(self.trainer.global_rank)), on_step=True, on_epoch=True, sync_dist=True) + self.log( + "val_acc", torch.tensor(0.85 + self.trainer.global_rank * 0.05), on_step=True, on_epoch=True, sync_dist=True + ) + return super().validation_step(batch, batch_idx) + + @pytest.mark.parametrize( ("devices", "accelerator"), [ @@ -422,6 +527,202 @@ def validation_step(self, batch, batch_idx): assert trainer.logged_metrics["bar"] == 2 +@RunIf(skip_windows=True) +def test_logging_sync_dist_inconsistent_keys_raises(tmp_path): + """Test that logging different metric keys with sync_dist=True raises an error.""" + model = InconsistentSyncDistKeysModel() + trainer = Trainer( + default_root_dir=tmp_path, + accelerator="cpu", + devices=2, + strategy="ddp_spawn", + max_epochs=1, + limit_train_batches=0, + limit_val_batches=1, + logger=False, + enable_checkpointing=False, + enable_progress_bar=False, + enable_model_summary=False, + ) + + with pytest.raises(ProcessRaisedException) as excinfo: + trainer.validate(model, dataloaders=model.val_dataloader(), verbose=False) + + message = str(excinfo.value) + assert "sync_dist=True" in message + assert "Detected a mismatch" in message + assert "validation_step.val_metric_a" in message + assert "validation_step.val_metric_b" in message + + +@RunIf(skip_windows=True) +def test_logging_sync_dist_inconsistent_keys_on_step_raises(tmp_path): + """Test that logging different metric keys with on_step=True and sync_dist=True raises an error.""" + model = InconsistentSyncDistKeysOnStepModel() + trainer = Trainer( + default_root_dir=tmp_path, + accelerator="cpu", + devices=2, + strategy="ddp_spawn", + max_epochs=1, + limit_train_batches=1, + limit_val_batches=0, + logger=False, + enable_checkpointing=False, + enable_progress_bar=False, + enable_model_summary=False, + log_every_n_steps=1, + ) + + with pytest.raises(ProcessRaisedException) as excinfo: + trainer.fit(model) + + message = str(excinfo.value) + assert "sync_dist=True" in message + assert "Detected a mismatch" in message + assert "training_step.train_metric_a" in message + assert "training_step.train_metric_b" in message + + +@RunIf(skip_windows=True) +def test_logging_sync_dist_inconsistent_keys_order_raises(tmp_path): + """Test that logging same metric keys in different order with sync_dist=True raises an error.""" + model = InconsistentSyncDistKeysOrderModel() + trainer = Trainer( + default_root_dir=tmp_path, + accelerator="cpu", + devices=2, + strategy="ddp_spawn", + max_epochs=1, + limit_train_batches=0, + limit_val_batches=1, + logger=False, + enable_checkpointing=False, + enable_progress_bar=False, + enable_model_summary=False, + ) + + with pytest.raises(ProcessRaisedException) as excinfo: + trainer.validate(model, dataloaders=model.val_dataloader(), 
verbose=False) + + message = str(excinfo.value) + assert "sync_dist=True" in message + assert "Detected a mismatch" in message + assert "validation_step.val_metric_a" in message + assert "validation_step.val_metric_b" in message + + +@RunIf(skip_windows=True) +def test_logging_sync_dist_inconsistent_keys_on_epoch_training_raises(tmp_path): + """Test that logging different metric keys with on_epoch=True in training raises an error.""" + model = InconsistentSyncDistKeysOnEpochTrainingModel() + trainer = Trainer( + default_root_dir=tmp_path, + accelerator="cpu", + devices=2, + strategy="ddp_spawn", + max_epochs=1, + limit_train_batches=2, + limit_val_batches=0, + logger=False, + enable_checkpointing=False, + enable_progress_bar=False, + enable_model_summary=False, + ) + + with pytest.raises(ProcessRaisedException) as excinfo: + trainer.fit(model) + + message = str(excinfo.value) + assert "sync_dist=True" in message + assert "Detected a mismatch" in message + assert "training_step.train_metric_a" in message + assert "training_step.train_metric_b" in message + + +@RunIf(skip_windows=True) +def test_logging_sync_dist_inconsistent_keys_on_step_validation_raises(tmp_path): + """Test that logging different metric keys with on_step=True in validation raises an error.""" + model = InconsistentSyncDistKeysOnStepValidationModel() + trainer = Trainer( + default_root_dir=tmp_path, + accelerator="cpu", + devices=2, + strategy="ddp_spawn", + max_epochs=1, + limit_train_batches=0, + limit_val_batches=1, + logger=False, + enable_checkpointing=False, + enable_progress_bar=False, + enable_model_summary=False, + ) + + with pytest.raises(ProcessRaisedException) as excinfo: + trainer.validate(model, dataloaders=model.val_dataloader(), verbose=False) + + message = str(excinfo.value) + assert "sync_dist=True" in message + assert "Detected a mismatch" in message + assert "validation_step.val_metric_a" in message + assert "validation_step.val_metric_b" in message + + +@RunIf(skip_windows=True) +def test_logging_sync_dist_inconsistent_keys_both_step_and_epoch_raises(tmp_path): + """Test that logging different metric keys with on_step=True and on_epoch=True raises an error.""" + model = InconsistentSyncDistKeysBothStepAndEpochModel() + trainer = Trainer( + default_root_dir=tmp_path, + accelerator="cpu", + devices=2, + strategy="ddp_spawn", + max_epochs=1, + limit_train_batches=1, + limit_val_batches=0, + logger=False, + enable_checkpointing=False, + enable_progress_bar=False, + enable_model_summary=False, + log_every_n_steps=1, + ) + + with pytest.raises(ProcessRaisedException) as excinfo: + trainer.fit(model) + + message = str(excinfo.value) + assert "sync_dist=True" in message + assert "Detected a mismatch" in message + + +@RunIf(skip_windows=True) +def test_logging_sync_dist_consistent_keys_works(tmp_path): + """Test that logging consistent metric keys with sync_dist=True works correctly.""" + model = ConsistentSyncDistKeysModel() + trainer = Trainer( + default_root_dir=tmp_path, + accelerator="cpu", + devices=2, + strategy="ddp_spawn", + max_epochs=1, + limit_train_batches=2, + limit_val_batches=2, + logger=False, + enable_checkpointing=False, + enable_progress_bar=False, + enable_model_summary=False, + ) + + # Should not raise - all ranks have consistent keys + trainer.fit(model) + + # Verify metrics were logged (values are averaged across ranks) + assert "train_loss" in trainer.callback_metrics + assert "train_acc" in trainer.callback_metrics + assert "val_loss" in trainer.callback_metrics + assert "val_acc" in 
trainer.callback_metrics + + def test_progress_bar_metrics_contains_values_on_train_epoch_end(tmp_path: str): class TestModel(BoringModel): def training_step(self, *args):
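For context, a minimal sketch (illustrative, not part of the patch) of the logging pattern this validation targets, mirroring the `InconsistentSyncDistKeys*` models in the tests above. The class name is hypothetical and a 2-process DDP run is assumed:

import torch
from lightning.pytorch.demos.boring_classes import BoringModel


class RankDependentLoggingModel(BoringModel):  # hypothetical, mirrors the test models above
    def validation_step(self, batch, batch_idx):
        # Every rank logs "val_loss" -> allowed.
        self.log("val_loss", torch.tensor(0.0), sync_dist=True)
        # Only rank 0 logs "val_extra" -> with this change, the deferred sync
        # (sync_on_step_metrics / sync_on_epoch_metrics) all-gathers the metric keys,
        # detects the mismatch, and raises MisconfigurationException instead of
        # letting the collective ops silently pair up values from different metrics.
        if self.trainer.global_rank == 0:
            self.log("val_extra", torch.tensor(1.0), sync_dist=True)
        return super().validation_step(batch, batch_idx)

The remedy is the one named in the new error message: log the same keys on every rank (dummy values are acceptable), or drop `sync_dist=True` and reduce manually, for example with `self.all_gather`.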