 from typing import Any, Optional, Union
 
 import torch
+from torch import Tensor, nn
+from torch.distributed.fsdp import FullyShardedDataParallel
 from torch.optim.swa_utils import AveragedModel, get_ema_avg_fn
 from typing_extensions import override
 
 import lightning.pytorch as pl
 from lightning.pytorch.callbacks.callback import Callback
+from lightning.pytorch.strategies.fsdp import FSDPStrategy
 from lightning.pytorch.utilities.model_helpers import is_overridden
 from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn
 from lightning.pytorch.utilities.types import STEP_OUTPUT
@@ -56,10 +59,11 @@ class WeightAveraging(Callback):
     provided by Lightning.
 
     Note:
-        To ensure that the :class:`AveragedModel` will contain all layers, ``setup()`` will call
-        :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` before instantiating the
-        :class:`AveragedModel`. However, that hook is not called in a strategy aware context, sharded models do not work
-        with weight averaging, and a warning will be issued.
+        Sharded models are challenging for weight averaging. To ensure that the :class:`AveragedModel` will contain
+        all parameters, ``setup()`` will call :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` before
+        instantiating the :class:`AveragedModel`. However, that hook is not called in a strategy-aware context,
+        meaning that the full model is initialized in CPU memory. Furthermore, every time the averaged model is
+        updated, the full parameters are summoned in GPU memory.
 
     Example::
 
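For orientation, here is a minimal usage sketch of the behavior the note describes: the callback combined with an FSDP strategy, where the averaged copy lives unsharded on the CPU while the training copy stays sharded across GPUs. The `LitModel` class, the `train_loader`, and the Trainer settings are illustrative assumptions, not part of this change.

```python
# Minimal usage sketch; LitModel and train_loader are made up for illustration.
import torch
from torch import nn
import lightning.pytorch as pl
from lightning.pytorch.callbacks import WeightAveraging
from lightning.pytorch.strategies import FSDPStrategy


class LitModel(pl.LightningModule):
    def __init__(self) -> None:
        super().__init__()
        self.layer = nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)


trainer = pl.Trainer(
    max_epochs=10,
    accelerator="gpu",
    devices=2,
    strategy=FSDPStrategy(),        # training copy is sharded across the GPUs
    callbacks=[WeightAveraging()],  # averaged copy is kept unsharded, on the CPU by default
)
trainer.fit(LitModel(), train_loader)  # train_loader: a DataLoader, assumed to exist
```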
@@ -149,8 +153,9 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: s
         # AveragedModel. However, sharding will not be done and a warning will be issued.
         if is_overridden("configure_model", pl_module):
             rank_zero_warn(
-                "You're using the WeightAveraging callback with a model that overrides the configure_model "
-                "callback. WeightAveraging doesn't support sharding model layers, so you may run out of memory."
+                "You're using the WeightAveraging callback with a model that overrides the configure_model() hook. "
+                "WeightAveraging will construct the model and the average model in CPU memory, so you may run out "
+                "of memory during initialization."
             )
             pl_module.configure_model()
 
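The warning concerns modules that create their layers lazily in `configure_model()`, the hook that sharding strategies use to build layers directly in sharded form. A sketch of such a module, with a made-up name, to show the situation the warning and the early `configure_model()` call are handling:

```python
from typing import Optional

import torch
from torch import nn
import lightning.pytorch as pl


class LitBigModel(pl.LightningModule):  # hypothetical module that defers layer creation
    def __init__(self) -> None:
        super().__init__()
        self.net: Optional[nn.Module] = None

    def configure_model(self) -> None:
        # Normally called by the strategy so that e.g. FSDP can shard layers as they are built.
        # WeightAveraging.setup() calls it earlier, outside the strategy context, so that the
        # AveragedModel sees all parameters; the layers then start out in CPU memory.
        if self.net is None:
            self.net = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024))

    def training_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
        assert self.net is not None
        return self.net(batch).pow(2).mean()

    def configure_optimizers(self) -> torch.optim.Optimizer:
        return torch.optim.AdamW(self.parameters(), lr=1e-4)
```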
@@ -178,8 +183,7 @@ def on_train_batch_end(
         # make step_idx consistent with epoch_idx, we'll pass a zero-based index.
         step_idx = trainer.global_step - 1
         if (trainer.global_step > self._latest_update_step) and self.should_update(step_idx=step_idx):
-            assert self._average_model is not None
-            self._average_model.update_parameters(pl_module)
+            self._update_average_model(trainer, pl_module)
             self._latest_update_step = trainer.global_step
 
     @override
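`should_update()` is the scheduling hook consulted here with a zero-based `step_idx` (and at epoch ends with `epoch_idx`). A small sketch of overriding it in a subclass to average only every 10th step after a warm-up; the subclass name and thresholds are made up:

```python
from typing import Optional

from lightning.pytorch.callbacks import WeightAveraging


class PeriodicWeightAveraging(WeightAveraging):  # hypothetical subclass
    def should_update(self, step_idx: Optional[int] = None, epoch_idx: Optional[int] = None) -> bool:
        # Called with step_idx after every optimizer step and with epoch_idx after every epoch.
        # Update the running average every 10th step once a 100-step warm-up has passed.
        return step_idx is not None and step_idx >= 100 and step_idx % 10 == 0
```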
@@ -194,8 +198,7 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModu
 
         """
         if (trainer.current_epoch > self._latest_update_epoch) and self.should_update(epoch_idx=trainer.current_epoch):
-            assert self._average_model is not None
-            self._average_model.update_parameters(pl_module)
+            self._update_average_model(trainer, pl_module)
             self._latest_update_epoch = trainer.current_epoch
 
     @override
@@ -210,7 +213,7 @@ def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
 
         """
         assert self._average_model is not None
-        self._copy_average_to_current(pl_module)
+        self._copy_average_to_current(trainer, pl_module)
 
     @override
     def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
@@ -224,7 +227,7 @@ def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.Lightn
 
         """
         if self._average_model is not None:
-            self._swap_models(pl_module)
+            self._swap_models(trainer, pl_module)
 
     @override
     def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
@@ -238,7 +241,7 @@ def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.Lightnin
 
         """
         if self._average_model is not None:
-            self._swap_models(pl_module)
+            self._swap_models(trainer, pl_module)
 
     @override
     def state_dict(self) -> dict[str, Any]:
@@ -334,7 +337,23 @@ def on_load_checkpoint(
             )
             self._average_model.module.load_state_dict(deepcopy(checkpoint["state_dict"]), strict=False)
 
-    def _swap_models(self, pl_module: "pl.LightningModule") -> None:
+    def _update_average_model(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        """Updates the :class:`AveragedModel` parameters.
+
+        Args:
+            trainer: The current :class:`~lightning.pytorch.trainer.trainer.Trainer` instance.
+            pl_module: The current :class:`~lightning.pytorch.core.LightningModule` instance.
+
+        """
+        assert self._average_model is not None
+        if isinstance(trainer.strategy, FSDPStrategy):
+            assert isinstance(trainer.strategy.model, nn.Module)
+            with FullyShardedDataParallel.summon_full_params(trainer.strategy.model):
+                self._average_model.update_parameters(trainer.strategy.model)
+        else:
+            self._average_model.update_parameters(pl_module)
+
+    def _swap_models(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         """Swaps the parameter values of the current model and the :class:`AveragedModel`.
 
         Args:
@@ -343,13 +362,25 @@ def _swap_models(self, pl_module: "pl.LightningModule") -> None:
         """
         assert self._average_model is not None
         average_params = itertools.chain(self._average_model.module.parameters(), self._average_model.module.buffers())
-        current_params = itertools.chain(pl_module.parameters(), pl_module.buffers())
-        for average_param, current_param in zip(average_params, current_params):
-            tmp = average_param.data.clone()
-            average_param.data.copy_(current_param.data)
-            current_param.data.copy_(tmp)
 
-    def _copy_average_to_current(self, pl_module: "pl.LightningModule") -> None:
+        def _swap_param(a: Union[nn.Parameter, Tensor], b: Union[nn.Parameter, Tensor]) -> None:
+            tmp = a.data.clone()
+            a.data.copy_(b.data)
+            b.data.copy_(tmp)
+
+        def _swap(model: nn.Module) -> None:
+            current_params = itertools.chain(model.parameters(), model.buffers())
+            for average_param, current_param in zip(average_params, current_params):
+                _swap_param(average_param, current_param)
+
+        if isinstance(trainer.strategy, FSDPStrategy):
+            assert isinstance(trainer.strategy.model, nn.Module)
+            with FullyShardedDataParallel.summon_full_params(trainer.strategy.model):
+                _swap(trainer.strategy.model)
+        else:
+            _swap(pl_module)
+
+    def _copy_average_to_current(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         """Copies the parameter values from the :class:`AveragedModel` to the current model.
 
         Args:
@@ -358,9 +389,18 @@ def _copy_average_to_current(self, pl_module: "pl.LightningModule") -> None:
         """
         assert self._average_model is not None
         average_params = itertools.chain(self._average_model.module.parameters(), self._average_model.module.buffers())
-        current_params = itertools.chain(pl_module.parameters(), pl_module.buffers())
-        for average_param, current_param in zip(average_params, current_params):
-            current_param.data.copy_(average_param.data)
+
+        def _copy(model: nn.Module) -> None:
+            current_params = itertools.chain(model.parameters(), model.buffers())
+            for average_param, current_param in zip(average_params, current_params):
+                current_param.data.copy_(average_param.data)
+
+        if isinstance(trainer.strategy, FSDPStrategy):
+            assert isinstance(trainer.strategy.model, nn.Module)
+            with FullyShardedDataParallel.summon_full_params(trainer.strategy.model):
+                _copy(trainer.strategy.model)
+        else:
+            _copy(pl_module)
 
 
 class EMAWeightAveraging(WeightAveraging):
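All three new FSDP branches follow the same pattern: inside `FullyShardedDataParallel.summon_full_params()` each rank temporarily sees the full, unsharded parameter tensors, so iterating `parameters()` and `buffers()` covers the whole model, and in-place writes are written back to the local shards when the context exits (the default `writeback=True`). A standalone sketch of that pattern, outside the callback; the helper name is made up:

```python
import torch
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel


def snapshot_full_parameters(fsdp_module: nn.Module) -> dict[str, torch.Tensor]:
    """Copy the full (unsharded) parameters of an FSDP-wrapped module to CPU.

    Inside summon_full_params(), named_parameters() yields complete tensors instead
    of the local shards, which is what lets WeightAveraging pass the model to
    AveragedModel.update_parameters() or swap weights in place.
    """
    with FullyShardedDataParallel.summon_full_params(fsdp_module):
        return {name: param.detach().cpu().clone() for name, param in fsdp_module.named_parameters()}
```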