4 changes: 2 additions & 2 deletions docs/source-pytorch/advanced/training_tricks.rst
@@ -62,7 +62,7 @@ Lightning provides two callbacks to facilitate weight averaging. :class:`~lightn
is a generic callback that wraps the
`AveragedModel <https://pytorch.org/docs/stable/generated/torch.optim.swa_utils.AveragedModel.html>`__ class from
PyTorch. It allows SWA, EMA, or a custom averaging strategy to be used. By default, it updates the weights after every
step, but it can be customized to update at specific steps or epochs by overriding the `should_update()` method.
step, but it can be customized to update at specific steps or epochs by overriding the ``should_update()`` method.

The older :class:`~lightning.pytorch.callbacks.StochasticWeightAveraging` callback is specific to SWA. It starts the SWA
procedure after a certain number of epochs and always runs on every epoch. Additionally, it switches to a constant
@@ -75,7 +75,7 @@ procedure starts.

.. seealso::
The :class:`~lightning.pytorch.callbacks.WeightAveraging` callback and
:class:`~lightning.pytorch.callbacks.StochasticWeightAveraging` callback
:class:`~lightning.pytorch.callbacks.StochasticWeightAveraging` callback.

.. testcode::

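The `.. testcode::` body is collapsed in this view. As a rough sketch of the behaviour described above (not the documentation's own example, and assuming that extra keyword arguments such as `avg_fn` are forwarded to `AveragedModel`), a `WeightAveraging` subclass that applies EMA and updates only every tenth optimizer step could look like this:

```python
# Sketch only: illustrates the behaviour described in the docs hunk above.
from torch.optim.swa_utils import get_ema_avg_fn

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import WeightAveraging


class EveryTenStepsEMA(WeightAveraging):
    def __init__(self):
        # Exponential moving average instead of the default equal-weight average.
        super().__init__(avg_fn=get_ema_avg_fn(decay=0.999))

    def should_update(self, step_idx=None, epoch_idx=None):
        # Called after every optimizer step and at every epoch end; only update
        # the averaged model on every tenth optimizer step.
        return step_idx is not None and step_idx % 10 == 0


trainer = Trainer(callbacks=EveryTenStepsEMA(), max_epochs=10)
```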
2 changes: 1 addition & 1 deletion src/lightning/pytorch/CHANGELOG.md
@@ -10,7 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

-
- Added support for FSDP with `WeightAveraging`, by summoning full parameters before averaged model update ([#21414](https://github.com/Lightning-AI/pytorch-lightning/pull/21414))

### Changed

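The changelog entry refers to summoning full parameters before each averaged-model update. A minimal sketch of that pattern in isolation (`FullyShardedDataParallel.summon_full_params` is the real PyTorch context manager; the wrapping function is purely illustrative):

```python
# Sketch of the core pattern: gather the full, unsharded parameters so that
# AveragedModel.update_parameters sees every weight, then let FSDP reshard them.
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.optim.swa_utils import AveragedModel


def update_with_full_params(average_model: AveragedModel, sharded_module: nn.Module) -> None:
    # Inside the context manager every rank holds the complete parameters,
    # so the running average covers the whole model rather than a shard.
    with FSDP.summon_full_params(sharded_module):
        average_model.update_parameters(sharded_module)
```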
86 changes: 63 additions & 23 deletions src/lightning/pytorch/callbacks/weight_averaging.py
@@ -21,11 +21,14 @@
from typing import Any, Optional, Union

import torch
from torch import Tensor, nn
from torch.distributed.fsdp import FullyShardedDataParallel
from torch.optim.swa_utils import AveragedModel, get_ema_avg_fn
from typing_extensions import override

import lightning.pytorch as pl
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.strategies.fsdp import FSDPStrategy
from lightning.pytorch.utilities.model_helpers import is_overridden
from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn
from lightning.pytorch.utilities.types import STEP_OUTPUT
@@ -56,10 +59,11 @@ class WeightAveraging(Callback):
provided by Lightning.

Note:
To ensure that the :class:`AveragedModel` will contain all layers, ``setup()`` will call
:meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` before instantiating the
:class:`AveragedModel`. However, that hook is not called in a strategy aware context, sharded models do not work
with weight averaging, and a warning will be issued.
Sharded models are challenging for weight averaging. To ensure that the :class:`AveragedModel` will contain all
parameters, ``setup()`` will call :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` before
instantiating the :class:`AveragedModel`. However, that hook is not called in a strategy aware context, meaning
that the full model is initialized in CPU memory. Furthermore, every time the averaged model is updated, the
full parameters are summoned in GPU memory.

Example::

@@ -149,8 +153,9 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: s
# AveragedModel. However, sharding will not be done and a warning will be issued.
if is_overridden("configure_model", pl_module):
rank_zero_warn(
"You're using the WeightAveraging callback with a model that overrides the configure_model "
"callback. WeightAveraging doesn't support sharding model layers, so you may run out of memory."
"You're using the WeightAveraging callback with a model that overrides the configure_model() hook. "
"WeightAveraging will construct the model and the average model in CPU memory, so you may run out "
"of memory during initialization."
)
pl_module.configure_model()

@@ -178,8 +183,7 @@ def on_train_batch_end(
# make step_idx consistent with epoch_idx, we'll pass a zero-based index.
step_idx = trainer.global_step - 1
if (trainer.global_step > self._latest_update_step) and self.should_update(step_idx=step_idx):
assert self._average_model is not None
self._average_model.update_parameters(pl_module)
self._update_average_model(trainer, pl_module)
self._latest_update_step = trainer.global_step

@override
@@ -194,8 +198,7 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModu

"""
if (trainer.current_epoch > self._latest_update_epoch) and self.should_update(epoch_idx=trainer.current_epoch):
assert self._average_model is not None
self._average_model.update_parameters(pl_module)
self._update_average_model(trainer, pl_module)
self._latest_update_epoch = trainer.current_epoch

@override
@@ -210,7 +213,7 @@ def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -

"""
assert self._average_model is not None
self._copy_average_to_current(pl_module)
self._copy_average_to_current(trainer, pl_module)

@override
def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
@@ -224,7 +227,7 @@ def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.Lightn

"""
if self._average_model is not None:
self._swap_models(pl_module)
self._swap_models(trainer, pl_module)

@override
def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
@@ -238,7 +241,7 @@ def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.Lightnin

"""
if self._average_model is not None:
self._swap_models(pl_module)
self._swap_models(trainer, pl_module)

@override
def state_dict(self) -> dict[str, Any]:
@@ -334,7 +337,23 @@ def on_load_checkpoint(
)
self._average_model.module.load_state_dict(deepcopy(checkpoint["state_dict"]), strict=False)

def _swap_models(self, pl_module: "pl.LightningModule") -> None:
def _update_average_model(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Updates the :class:`AveragedModel` parameters.

Args:
trainer: The current :class:`~lightning.pytorch.trainer.trainer.Trainer` instance.
pl_module: The current :class:`~lightning.pytorch.core.LightningModule` instance.

"""
assert self._average_model is not None
if isinstance(trainer.strategy, FSDPStrategy):
assert isinstance(trainer.strategy.model, nn.Module)
with FullyShardedDataParallel.summon_full_params(trainer.strategy.model):
self._average_model.update_parameters(trainer.strategy.model)
else:
self._average_model.update_parameters(pl_module)

def _swap_models(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Swaps the parameter values of the current model and the :class:`AveragedModel`.

Args:
@@ -343,13 +362,25 @@ def _swap_models(self, pl_module: "pl.LightningModule") -> None:
"""
assert self._average_model is not None
average_params = itertools.chain(self._average_model.module.parameters(), self._average_model.module.buffers())
current_params = itertools.chain(pl_module.parameters(), pl_module.buffers())
for average_param, current_param in zip(average_params, current_params):
tmp = average_param.data.clone()
average_param.data.copy_(current_param.data)
current_param.data.copy_(tmp)

def _copy_average_to_current(self, pl_module: "pl.LightningModule") -> None:
def _swap_param(a: nn.Parameter | Tensor, b: nn.Parameter | Tensor) -> None:
tmp = a.data.clone()
a.data.copy_(b.data)
b.data.copy_(tmp)

def _swap(model: nn.Module) -> None:
current_params = itertools.chain(model.parameters(), model.buffers())
for average_param, current_param in zip(average_params, current_params):
_swap_param(average_param, current_param)

if isinstance(trainer.strategy, FSDPStrategy):
assert isinstance(trainer.strategy.model, nn.Module)
with FullyShardedDataParallel.summon_full_params(trainer.strategy.model):
_swap(trainer.strategy.model)
else:
_swap(pl_module)

def _copy_average_to_current(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Copies the parameter values from the :class:`AveragedModel` to the current model.

Args:
@@ -358,9 +389,18 @@ def _copy_average_to_current(self, pl_module: "pl.LightningModule") -> None:
"""
assert self._average_model is not None
average_params = itertools.chain(self._average_model.module.parameters(), self._average_model.module.buffers())
current_params = itertools.chain(pl_module.parameters(), pl_module.buffers())
for average_param, current_param in zip(average_params, current_params):
current_param.data.copy_(average_param.data)

def _copy(model: nn.Module) -> None:
current_params = itertools.chain(model.parameters(), model.buffers())
for average_param, current_param in zip(average_params, current_params):
current_param.data.copy_(average_param.data)

if isinstance(trainer.strategy, FSDPStrategy):
assert isinstance(trainer.strategy.model, nn.Module)
with FullyShardedDataParallel.summon_full_params(trainer.strategy.model):
_copy(trainer.strategy.model)
else:
_copy(pl_module)


class EMAWeightAveraging(WeightAveraging):
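Taken together, the changes above let a sharded model use the callback as long as `configure_model()` tolerates being called more than once (the callback's `setup()` calls it before the strategy does). A hedged end-to-end sketch; the module, layer sizes, and Trainer arguments are illustrative rather than taken from this PR:

```python
# Illustrative only: a LightningModule whose layers are built lazily in
# configure_model(), trained with FSDP and weight averaging.
from typing import Optional

import torch
from torch import nn

import lightning.pytorch as pl
from lightning.pytorch.callbacks import WeightAveraging


class LazyModel(pl.LightningModule):
    def __init__(self) -> None:
        super().__init__()
        self.layer: Optional[nn.Module] = None

    def configure_model(self) -> None:
        # The WeightAveraging callback may call this hook before the strategy
        # does, so make it safe to run more than once.
        if self.layer is None:
            self.layer = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2))

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=0.1)


trainer = pl.Trainer(
    accelerator="gpu",
    devices=2,
    strategy="fsdp",
    callbacks=WeightAveraging(),
    max_epochs=2,
)
# trainer.fit(LazyModel(), train_dataloaders=...)  # dataloader omitted
```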
16 changes: 12 additions & 4 deletions tests/tests_pytorch/callbacks/test_weight_averaging.py
@@ -50,14 +50,14 @@ def configure_optimizers(self) -> None:
class LargeTestModel(BoringModel):
def __init__(self):
super().__init__()
self.layer = None
self.layer: Optional[nn.Module] = None

def configure_model(self):
print("XXX configure_model")
self.layer = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2))
if self.layer is None:
self.layer = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2))

def configure_optimizers(self):
return torch.optim.SGD(self.parameters(), lr=0.01)
return torch.optim.AdamW(self.parameters(), lr=0.1)


class EMAAveragingFunction:
@@ -281,6 +281,14 @@ def test_ema_configure_model(tmp_path, strategy, accelerator, devices):
assert isinstance(callback._average_model.module.layer, nn.Sequential)


@pytest.mark.filterwarnings("ignore::FutureWarning")
@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
def test_ema_fsdp(tmp_path):
model = LargeTestModel()
dataset = RandomIterableDataset(32, 32)
_train(model, dataset, tmp_path, EMATestCallback(), strategy="fsdp", accelerator="gpu", devices=2)


def _train(
model: BoringModel,
dataset: Dataset,
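The bodies of `EMAAveragingFunction`, `EMATestCallback`, and `_train` are collapsed in this diff, so nothing below reproduces them. For orientation only, this is the shape of a custom EMA averaging function that `AveragedModel` accepts (the decay value is arbitrary):

```python
import torch
from torch.optim.swa_utils import AveragedModel


def ema_avg_fn(averaged_param, current_param, num_averaged):
    # Standard exponential moving average: keep most of the running average and
    # mix in a fraction of the current parameter. 0.9 is an arbitrary decay.
    return 0.9 * averaged_param + 0.1 * current_param


model = torch.nn.Linear(4, 4)
ema_model = AveragedModel(model, avg_fn=ema_avg_fn)
ema_model.update_parameters(model)  # typically called after each optimizer step
```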