
Commit aff4ce9

Author: Seppo Enarvi
Summon all model parameters before updating the average model, when using FSDP
1 parent 04baf7f commit aff4ce9

4 files changed (+78 / -30 lines)

docs/source-pytorch/advanced/training_tricks.rst

Lines changed: 2 additions & 2 deletions
@@ -62,7 +62,7 @@ Lightning provides two callbacks to facilitate weight averaging. :class:`~lightn
 is a generic callback that wraps the
 `AveragedModel <https://pytorch.org/docs/stable/generated/torch.optim.swa_utils.AveragedModel.html>`__ class from
 PyTorch. It allows SWA, EMA, or a custom averaging strategy to be used. By default, it updates the weights after every
-step, but it can be customized to update at specific steps or epochs by overriding the `should_update()` method.
+step, but it can be customized to update at specific steps or epochs by overriding the ``should_update()`` method.
 
 The older :class:`~lightning.pytorch.callbacks.StochasticWeightAveraging` callback is specific to SWA. It starts the SWA
 procedure after a certain number of epochs and always runs on every epoch. Additionally, it switches to a constant
@@ -75,7 +75,7 @@ procedure starts.
 
 .. seealso::
     The :class:`~lightning.pytorch.callbacks.WeightAveraging` callback and
-    :class:`~lightning.pytorch.callbacks.StochasticWeightAveraging` callback
+    :class:`~lightning.pytorch.callbacks.StochasticWeightAveraging` callback.
 
 .. testcode::
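As the updated paragraph notes, update timing is controlled by should_update(). The callback passes a zero-based step_idx after every batch and an epoch_idx after every epoch (see the hooks in weight_averaging.py further down). A minimal sketch of a subclass that applies EMA averaging every tenth step, assuming extra constructor keyword arguments such as avg_fn are forwarded to AveragedModel:

from lightning.pytorch.callbacks import WeightAveraging
from torch.optim.swa_utils import get_ema_avg_fn


class EveryTenStepsAveraging(WeightAveraging):
    """Hypothetical example: update the averaged weights on every 10th optimizer step."""

    def __init__(self):
        # Assumption: extra keyword arguments are forwarded to torch.optim.swa_utils.AveragedModel.
        super().__init__(avg_fn=get_ema_avg_fn(decay=0.999))

    def should_update(self, step_idx=None, epoch_idx=None):
        # step_idx is zero-based, so this updates on steps 9, 19, 29, ...
        return step_idx is not None and (step_idx + 1) % 10 == 0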

src/lightning/pytorch/CHANGELOG.md

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
--
+- Added support for FSDP with `WeightAveraging`, by summoning full parameters before averaged model update ([#21414](https://github.com/Lightning-AI/pytorch-lightning/pull/21414))
 
 ### Changed
 

src/lightning/pytorch/callbacks/weight_averaging.py

Lines changed: 63 additions & 23 deletions
@@ -21,11 +21,14 @@
 from typing import Any, Optional, Union
 
 import torch
+from torch import Tensor, nn
+from torch.distributed.fsdp import FullyShardedDataParallel
 from torch.optim.swa_utils import AveragedModel, get_ema_avg_fn
 from typing_extensions import override
 
 import lightning.pytorch as pl
 from lightning.pytorch.callbacks.callback import Callback
+from lightning.pytorch.strategies.fsdp import FSDPStrategy
 from lightning.pytorch.utilities.model_helpers import is_overridden
 from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn
 from lightning.pytorch.utilities.types import STEP_OUTPUT
@@ -56,10 +59,11 @@ class WeightAveraging(Callback):
     provided by Lightning.
 
     Note:
-        To ensure that the :class:`AveragedModel` will contain all layers, ``setup()`` will call
-        :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` before instantiating the
-        :class:`AveragedModel`. However, that hook is not called in a strategy aware context, sharded models do not work
-        with weight averaging, and a warning will be issued.
+        Sharded models are challenging for weight averaging. To ensure that the :class:`AveragedModel` will contain all
+        parameters, ``setup()`` will call :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` before
+        instantiating the :class:`AveragedModel`. However, that hook is not called in a strategy aware context, meaning
+        that the full model is initialized in CPU memory. Furthermore, every time the averaged model is updated, the
+        full parameters are summoned in GPU memory.
 
     Example::
 
@@ -149,8 +153,9 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: s
             # AveragedModel. However, sharding will not be done and a warning will be issued.
             if is_overridden("configure_model", pl_module):
                 rank_zero_warn(
-                    "You're using the WeightAveraging callback with a model that overrides the configure_model "
-                    "callback. WeightAveraging doesn't support sharding model layers, so you may run out of memory."
+                    "You're using the WeightAveraging callback with a model that overrides the configure_model() hook. "
+                    "WeightAveraging will construct the model and the average model in CPU memory, so you may run out "
+                    "of memory during initialization."
                 )
                 pl_module.configure_model()
 
@@ -178,8 +183,7 @@ def on_train_batch_end(
         # make step_idx consistent with epoch_idx, we'll pass a zero-based index.
         step_idx = trainer.global_step - 1
         if (trainer.global_step > self._latest_update_step) and self.should_update(step_idx=step_idx):
-            assert self._average_model is not None
-            self._average_model.update_parameters(pl_module)
+            self._update_average_model(trainer, pl_module)
             self._latest_update_step = trainer.global_step
 
     @override
@@ -194,8 +198,7 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModu
 
         """
         if (trainer.current_epoch > self._latest_update_epoch) and self.should_update(epoch_idx=trainer.current_epoch):
-            assert self._average_model is not None
-            self._average_model.update_parameters(pl_module)
+            self._update_average_model(trainer, pl_module)
             self._latest_update_epoch = trainer.current_epoch
 
     @override
@@ -210,7 +213,7 @@ def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
 
         """
         assert self._average_model is not None
-        self._copy_average_to_current(pl_module)
+        self._copy_average_to_current(trainer, pl_module)
 
     @override
     def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
@@ -224,7 +227,7 @@ def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.Lightn
 
         """
         if self._average_model is not None:
-            self._swap_models(pl_module)
+            self._swap_models(trainer, pl_module)
 
     @override
     def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
@@ -238,7 +241,7 @@ def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.Lightnin
 
         """
         if self._average_model is not None:
-            self._swap_models(pl_module)
+            self._swap_models(trainer, pl_module)
 
     @override
     def state_dict(self) -> dict[str, Any]:
@@ -334,7 +337,23 @@ def on_load_checkpoint(
             )
         self._average_model.module.load_state_dict(deepcopy(checkpoint["state_dict"]), strict=False)
 
-    def _swap_models(self, pl_module: "pl.LightningModule") -> None:
+    def _update_average_model(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        """Updates the :class:`AveragedModel` parameters.
+
+        Args:
+            trainer: The current :class:`~lightning.pytorch.trainer.trainer.Trainer` instance.
+            pl_module: The current :class:`~lightning.pytorch.core.LightningModule` instance.
+
+        """
+        assert self._average_model is not None
+        if isinstance(trainer.strategy, FSDPStrategy):
+            assert isinstance(trainer.strategy.model, nn.Module)
+            with FullyShardedDataParallel.summon_full_params(trainer.strategy.model):
+                self._average_model.update_parameters(trainer.strategy.model)
+        else:
+            self._average_model.update_parameters(pl_module)
+
+    def _swap_models(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         """Swaps the parameter values of the current model and the :class:`AveragedModel`.
 
         Args:
@@ -343,13 +362,25 @@ def _swap_models(self, pl_module: "pl.LightningModule") -> None:
         """
         assert self._average_model is not None
         average_params = itertools.chain(self._average_model.module.parameters(), self._average_model.module.buffers())
-        current_params = itertools.chain(pl_module.parameters(), pl_module.buffers())
-        for average_param, current_param in zip(average_params, current_params):
-            tmp = average_param.data.clone()
-            average_param.data.copy_(current_param.data)
-            current_param.data.copy_(tmp)
 
-    def _copy_average_to_current(self, pl_module: "pl.LightningModule") -> None:
+        def _swap_param(a: nn.Parameter | Tensor, b: nn.Parameter | Tensor) -> None:
+            tmp = a.data.clone()
+            a.data.copy_(b.data)
+            b.data.copy_(tmp)
+
+        def _swap(model: nn.Module) -> None:
+            current_params = itertools.chain(model.parameters(), model.buffers())
+            for average_param, current_param in zip(average_params, current_params):
+                _swap_param(average_param, current_param)
+
+        if isinstance(trainer.strategy, FSDPStrategy):
+            assert isinstance(trainer.strategy.model, nn.Module)
+            with FullyShardedDataParallel.summon_full_params(trainer.strategy.model):
+                _swap(trainer.strategy.model)
+        else:
+            _swap(pl_module)
+
+    def _copy_average_to_current(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         """Copies the parameter values from the :class:`AveragedModel` to the current model.
 
         Args:
@@ -358,9 +389,18 @@ def _copy_average_to_current(self, pl_module: "pl.LightningModule") -> None:
         """
         assert self._average_model is not None
         average_params = itertools.chain(self._average_model.module.parameters(), self._average_model.module.buffers())
-        current_params = itertools.chain(pl_module.parameters(), pl_module.buffers())
-        for average_param, current_param in zip(average_params, current_params):
-            current_param.data.copy_(average_param.data)
+
+        def _copy(model: nn.Module) -> None:
+            current_params = itertools.chain(model.parameters(), model.buffers())
+            for average_param, current_param in zip(average_params, current_params):
+                current_param.data.copy_(average_param.data)
+
+        if isinstance(trainer.strategy, FSDPStrategy):
+            assert isinstance(trainer.strategy.model, nn.Module)
+            with FullyShardedDataParallel.summon_full_params(trainer.strategy.model):
+                _copy(trainer.strategy.model)
+        else:
+            _copy(pl_module)
 
 
 class EMAWeightAveraging(WeightAveraging):
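The core of this change is the new _update_average_model() helper above: when the training strategy is FSDP, the sharded parameters are temporarily gathered before being fed to the averaged model. A standalone sketch of that pattern, separate from the callback code and assuming the model has already been wrapped in FullyShardedDataParallel:

from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel
from torch.optim.swa_utils import AveragedModel


def update_average(average_model: AveragedModel, model: nn.Module) -> None:
    """Update the averaged weights, unsharding FSDP parameters first when needed."""
    if isinstance(model, FullyShardedDataParallel):
        # summon_full_params() temporarily materializes the full parameters on each rank,
        # so model.parameters() yields unsharded tensors while the average is updated.
        with FullyShardedDataParallel.summon_full_params(model):
            average_model.update_parameters(model)
    else:
        average_model.update_parameters(model)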

tests/tests_pytorch/callbacks/test_weight_averaging.py

Lines changed: 12 additions & 4 deletions
@@ -50,14 +50,14 @@ def configure_optimizers(self) -> None:
 class LargeTestModel(BoringModel):
     def __init__(self):
         super().__init__()
-        self.layer = None
+        self.layer: Optional[nn.Module] = None
 
     def configure_model(self):
-        print("XXX configure_model")
-        self.layer = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2))
+        if self.layer is None:
+            self.layer = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2))
 
     def configure_optimizers(self):
-        return torch.optim.SGD(self.parameters(), lr=0.01)
+        return torch.optim.AdamW(self.parameters(), lr=0.1)
 
 
 class EMAAveragingFunction:
@@ -281,6 +281,14 @@ def test_ema_configure_model(tmp_path, strategy, accelerator, devices):
     assert isinstance(callback._average_model.module.layer, nn.Sequential)
 
 
+@pytest.mark.filterwarnings("ignore::FutureWarning")
+@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
+def test_ema_fsdp(tmp_path):
+    model = LargeTestModel()
+    dataset = RandomIterableDataset(32, 32)
+    _train(model, dataset, tmp_path, EMATestCallback(), strategy="fsdp", accelerator="gpu", devices=2)
+
+
 def _train(
     model: BoringModel,
     dataset: Dataset,
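The new test_ema_fsdp test above exercises roughly the following user-facing setup. This sketch is hypothetical: MyLightningModule and train_loader are placeholders, and it assumes that extra WeightAveraging keyword arguments such as avg_fn are forwarded to AveragedModel. The essential parts are the fsdp strategy on two GPUs combined with an EMA-style WeightAveraging callback, with the model layers created in configure_model() so the averaged copy can be built before sharding:

import lightning.pytorch as pl
from lightning.pytorch.callbacks import WeightAveraging
from torch.optim.swa_utils import get_ema_avg_fn

# MyLightningModule and train_loader are placeholders for a user's module and data.
trainer = pl.Trainer(
    strategy="fsdp",
    accelerator="gpu",
    devices=2,
    max_epochs=1,
    callbacks=[WeightAveraging(avg_fn=get_ema_avg_fn())],  # assumed: avg_fn is forwarded to AveragedModel
)
# trainer.fit(MyLightningModule(), train_dataloaders=train_loader)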
