diff --git a/docs/source-pytorch/advanced/training_tricks.rst b/docs/source-pytorch/advanced/training_tricks.rst
index 2df53b216a654..c48d06849c02f 100644
--- a/docs/source-pytorch/advanced/training_tricks.rst
+++ b/docs/source-pytorch/advanced/training_tricks.rst
@@ -62,7 +62,7 @@ Lightning provides two callbacks to facilitate weight averaging. :class:`~lightn
 is a generic callback that wraps the `AveragedModel
 `__ class from PyTorch. It allows
 SWA, EMA, or a custom averaging strategy to be used. By default, it updates the weights after every
-step, but it can be customized to update at specific steps or epochs by overriding the `should_update()` method.
+step, but it can be customized to update at specific steps or epochs by overriding the ``should_update()`` method.

 The older :class:`~lightning.pytorch.callbacks.StochasticWeightAveraging` callback is specific to SWA. It starts the SWA
 procedure after a certain number of epochs and always runs on every epoch. Additionally, it switches to a constant
@@ -75,7 +75,7 @@ procedure starts.

 .. seealso::
     The :class:`~lightning.pytorch.callbacks.WeightAveraging` callback and
-    :class:`~lightning.pytorch.callbacks.StochasticWeightAveraging` callback
+    :class:`~lightning.pytorch.callbacks.StochasticWeightAveraging` callback.

 .. testcode::

diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 591d258bcdd0b..b87d41296cf23 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -10,7 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Added

--
+- Added support for FSDP with `WeightAveraging`, by summoning full parameters before averaged model update ([#21414](https://github.com/Lightning-AI/pytorch-lightning/pull/21414))


 ### Changed
diff --git a/src/lightning/pytorch/callbacks/weight_averaging.py b/src/lightning/pytorch/callbacks/weight_averaging.py
index 0640efed3d87b..94d14d8718a74 100644
--- a/src/lightning/pytorch/callbacks/weight_averaging.py
+++ b/src/lightning/pytorch/callbacks/weight_averaging.py
@@ -21,11 +21,14 @@
 from typing import Any, Optional, Union

 import torch
+from torch import Tensor, nn
+from torch.distributed.fsdp import FullyShardedDataParallel
 from torch.optim.swa_utils import AveragedModel, get_ema_avg_fn
 from typing_extensions import override

 import lightning.pytorch as pl
 from lightning.pytorch.callbacks.callback import Callback
+from lightning.pytorch.strategies.fsdp import FSDPStrategy
 from lightning.pytorch.utilities.model_helpers import is_overridden
 from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn
 from lightning.pytorch.utilities.types import STEP_OUTPUT
@@ -56,10 +59,11 @@ class WeightAveraging(Callback):
     provided by Lightning.

     Note:
-        To ensure that the :class:`AveragedModel` will contain all layers, ``setup()`` will call
-        :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` before instantiating the
-        :class:`AveragedModel`. However, that hook is not called in a strategy aware context, sharded models do not work
-        with weight averaging, and a warning will be issued.
+        Sharded models are challenging for weight averaging. To ensure that the :class:`AveragedModel` will contain all
+        parameters, ``setup()`` will call :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` before
+        instantiating the :class:`AveragedModel`. However, that hook is not called in a strategy aware context, meaning
+        that the full model is initialized in CPU memory. Furthermore, every time the averaged model is updated, the
+        full parameters are summoned in GPU memory.

     Example::

@@ -149,8 +153,9 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: s
             # AveragedModel. However, sharding will not be done and a warning will be issued.
             if is_overridden("configure_model", pl_module):
                 rank_zero_warn(
-                    "You're using the WeightAveraging callback with a model that overrides the configure_model "
-                    "callback. WeightAveraging doesn't support sharding model layers, so you may run out of memory."
+                    "You're using the WeightAveraging callback with a model that overrides the configure_model() hook. "
+                    "WeightAveraging will construct the model and the average model in CPU memory, so you may run out "
+                    "of memory during initialization."
                 )
                 pl_module.configure_model()

@@ -178,8 +183,7 @@ def on_train_batch_end(
         # make step_idx consistent with epoch_idx, we'll pass a zero-based index.
         step_idx = trainer.global_step - 1
         if (trainer.global_step > self._latest_update_step) and self.should_update(step_idx=step_idx):
-            assert self._average_model is not None
-            self._average_model.update_parameters(pl_module)
+            self._update_average_model(trainer, pl_module)
             self._latest_update_step = trainer.global_step

     @override
@@ -194,8 +198,7 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModu

         """
         if (trainer.current_epoch > self._latest_update_epoch) and self.should_update(epoch_idx=trainer.current_epoch):
-            assert self._average_model is not None
-            self._average_model.update_parameters(pl_module)
+            self._update_average_model(trainer, pl_module)
             self._latest_update_epoch = trainer.current_epoch

     @override
@@ -210,7 +213,7 @@ def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -

         """
         assert self._average_model is not None
-        self._copy_average_to_current(pl_module)
+        self._copy_average_to_current(trainer, pl_module)

     @override
     def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
@@ -224,7 +227,7 @@ def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.Lightn

         """
         if self._average_model is not None:
-            self._swap_models(pl_module)
+            self._swap_models(trainer, pl_module)

     @override
     def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
@@ -238,7 +241,7 @@ def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.Lightnin

         """
         if self._average_model is not None:
-            self._swap_models(pl_module)
+            self._swap_models(trainer, pl_module)

     @override
     def state_dict(self) -> dict[str, Any]:
@@ -334,7 +337,23 @@ def on_load_checkpoint(
             )
             self._average_model.module.load_state_dict(deepcopy(checkpoint["state_dict"]), strict=False)

-    def _swap_models(self, pl_module: "pl.LightningModule") -> None:
+    def _update_average_model(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        """Updates the :class:`AveragedModel` parameters.
+
+        Args:
+            trainer: The current :class:`~lightning.pytorch.trainer.trainer.Trainer` instance.
+            pl_module: The current :class:`~lightning.pytorch.core.LightningModule` instance.
+
+        """
+        assert self._average_model is not None
+        if isinstance(trainer.strategy, FSDPStrategy):
+            assert isinstance(trainer.strategy.model, nn.Module)
+            with FullyShardedDataParallel.summon_full_params(trainer.strategy.model):
+                self._average_model.update_parameters(trainer.strategy.model)
+        else:
+            self._average_model.update_parameters(pl_module)
+
+    def _swap_models(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         """Swaps the parameter values of the current model and the :class:`AveragedModel`.

         Args:
@@ -343,13 +362,25 @@ def _swap_models(self, pl_module: "pl.LightningModule") -> None:
         """
         assert self._average_model is not None
         average_params = itertools.chain(self._average_model.module.parameters(), self._average_model.module.buffers())
-        current_params = itertools.chain(pl_module.parameters(), pl_module.buffers())
-        for average_param, current_param in zip(average_params, current_params):
-            tmp = average_param.data.clone()
-            average_param.data.copy_(current_param.data)
-            current_param.data.copy_(tmp)

-    def _copy_average_to_current(self, pl_module: "pl.LightningModule") -> None:
+        def _swap_param(a: nn.Parameter | Tensor, b: nn.Parameter | Tensor) -> None:
+            tmp = a.data.clone()
+            a.data.copy_(b.data)
+            b.data.copy_(tmp)
+
+        def _swap(model: nn.Module) -> None:
+            current_params = itertools.chain(model.parameters(), model.buffers())
+            for average_param, current_param in zip(average_params, current_params):
+                _swap_param(average_param, current_param)
+
+        if isinstance(trainer.strategy, FSDPStrategy):
+            assert isinstance(trainer.strategy.model, nn.Module)
+            with FullyShardedDataParallel.summon_full_params(trainer.strategy.model):
+                _swap(trainer.strategy.model)
+        else:
+            _swap(pl_module)
+
+    def _copy_average_to_current(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         """Copies the parameter values from the :class:`AveragedModel` to the current model.

         Args:
@@ -358,9 +389,18 @@ def _copy_average_to_current(self, pl_module: "pl.LightningModule") -> None:
         """
         assert self._average_model is not None
         average_params = itertools.chain(self._average_model.module.parameters(), self._average_model.module.buffers())
-        current_params = itertools.chain(pl_module.parameters(), pl_module.buffers())
-        for average_param, current_param in zip(average_params, current_params):
-            current_param.data.copy_(average_param.data)
+
+        def _copy(model: nn.Module) -> None:
+            current_params = itertools.chain(model.parameters(), model.buffers())
+            for average_param, current_param in zip(average_params, current_params):
+                current_param.data.copy_(average_param.data)
+
+        if isinstance(trainer.strategy, FSDPStrategy):
+            assert isinstance(trainer.strategy.model, nn.Module)
+            with FullyShardedDataParallel.summon_full_params(trainer.strategy.model):
+                _copy(trainer.strategy.model)
+        else:
+            _copy(pl_module)


 class EMAWeightAveraging(WeightAveraging):
diff --git a/tests/tests_pytorch/callbacks/test_weight_averaging.py b/tests/tests_pytorch/callbacks/test_weight_averaging.py
index cfb066f023af0..54d0288839fc5 100644
--- a/tests/tests_pytorch/callbacks/test_weight_averaging.py
+++ b/tests/tests_pytorch/callbacks/test_weight_averaging.py
@@ -50,14 +50,14 @@ def configure_optimizers(self) -> None:
 class LargeTestModel(BoringModel):
     def __init__(self):
         super().__init__()
-        self.layer = None
+        self.layer: Optional[nn.Module] = None

     def configure_model(self):
-        print("XXX configure_model")
-        self.layer = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2))
+        if self.layer is None:
+            self.layer = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2))

     def configure_optimizers(self):
-        return torch.optim.SGD(self.parameters(), lr=0.01)
+        return torch.optim.AdamW(self.parameters(), lr=0.1)


 class EMAAveragingFunction:
@@ -281,6 +281,14 @@ def test_ema_configure_model(tmp_path, strategy, accelerator, devices):
     assert isinstance(callback._average_model.module.layer, nn.Sequential)


+@pytest.mark.filterwarnings("ignore::FutureWarning")
+@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
+def test_ema_fsdp(tmp_path):
+    model = LargeTestModel()
+    dataset = RandomIterableDataset(32, 32)
+    _train(model, dataset, tmp_path, EMATestCallback(), strategy="fsdp", accelerator="gpu", devices=2)
+
+
 def _train(
     model: BoringModel,
     dataset: Dataset,
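
Usage sketch (not part of the patch): the snippet below shows the combination this change enables, i.e. training under the ``fsdp`` strategy while a ``WeightAveraging`` subclass maintains an EMA of the weights, mirroring the new ``test_ema_fsdp`` test. It assumes that ``WeightAveraging`` forwards extra keyword arguments such as ``avg_fn`` to ``AveragedModel`` and exposes the ``should_update()`` hook mentioned in the docs hunk above; ``EMACallback`` and ``MyModel`` are placeholder names, not symbols from this PR.

    import torch
    from torch import nn
    from torch.optim.swa_utils import get_ema_avg_fn

    from lightning.pytorch import LightningModule, Trainer
    from lightning.pytorch.callbacks import WeightAveraging


    class EMACallback(WeightAveraging):
        """Placeholder callback that keeps an exponential moving average of the weights."""

        def __init__(self):
            # Extra keyword arguments are assumed to be forwarded to AveragedModel.
            super().__init__(avg_fn=get_ema_avg_fn(decay=0.999))

        def should_update(self, step_idx=None, epoch_idx=None):
            # Start averaging only after the first 100 optimizer steps.
            return (step_idx is not None) and step_idx >= 100


    class MyModel(LightningModule):
        """Placeholder module that builds its layers in configure_model(), like LargeTestModel above."""

        def __init__(self):
            super().__init__()
            self.layer = None

        def configure_model(self):
            # Guarded so that repeated calls stay idempotent; FSDP shards these layers.
            if self.layer is None:
                self.layer = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2))

        def training_step(self, batch, batch_idx):
            # Dummy loss, just enough to exercise the training loop.
            return self.layer(batch).sum()

        def configure_optimizers(self):
            return torch.optim.AdamW(self.parameters(), lr=0.1)


    trainer = Trainer(accelerator="gpu", devices=2, strategy="fsdp", max_epochs=2, callbacks=[EMACallback()])
    # trainer.fit(MyModel(), train_dataloaders=...)  # requires at least 2 CUDA devices

With this patch, the averaged-model update, the validation-time swap, and the final copy back into the current model all summon the full FSDP parameters first, so the EMA weights stay consistent across shards.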