2626"""
2727
2828import concurrent .futures
29+ import gc
2930import logging
3031import os
3132import socket
@@ -185,6 +186,7 @@ def __init__(
         init_sync: bool = True,
         max_retries: Optional[int] = None,
         quorum_retries: int = 0,
+        dataloader_fn: Optional[Callable[[int, int, int], None]] = None,
     ) -> None:
         """
         Args:
@@ -365,6 +367,17 @@ def __init__(

         self._update_fr_path()

+        # The number of batches committed in the current epoch. Unlike _batches_committed,
+        # _current_batches_committed is reset to 0 when the next epoch starts.
+        self._current_batches_committed = 0
+        self._epoch = 0
+        self._loaded_epoch = 0
+        self._loaded_current_batches_committed = 0
+        self._dataloader_fn = dataloader_fn
+        self._dataloader_dirty = False
+        self._dataloader_iter = None
+        self._accumulation_steps = 1
+
     def allow_state_dict_read(self) -> None:
         if self._is_state_dict_read_allowed:
             return
@@ -438,6 +451,12 @@ def allreduce(
             return _DummyWork(tensor)

         self.wait_quorum()
+
+        # If dirty, the result will not be committed, so return a zeroed tensor instead.
+        if self._dataloader_dirty:
+            work = _DummyWork(torch.zeros_like(tensor))
+            return _ManagedWork(self, work, tensor)
+
         num_participants: int = self.num_participants()

         if not self.is_participating():
@@ -678,6 +697,8 @@ def _async_quorum(
             if self._use_async_quorum or not allow_heal
             else (replica_rank, replica_world_size)
         )
+        self._replica_rank = replica_rank
+        self._replica_world_size = replica_world_size

         # For fixed with spares we need to ensure that we don't have more
         # participating replicas than the min replica size.
@@ -691,6 +712,7 @@ def _async_quorum(
         ):
             self._participating_replica_rank = None

+        quorum_changed = False
         if quorum_id != self._quorum_id:
             self.quorum_logger.info(
                 "",
@@ -737,6 +759,7 @@ def _async_quorum(
                 self._logger.exception(f"got exception in pg configure: {e}")
                 self.report_error(e)
                 return
+            quorum_changed = True

         if allow_heal:
             # run recovery on the recovery stream if available
@@ -807,6 +830,38 @@ def _async_quorum(
                 else None
             )

+        # Reconfigure the dataloader after healing so that we can get the offset from the other replica group.
+        if quorum_changed and self._dataloader_fn:
+            self.reconfigure_dataloader()
+            self._dataloader_dirty = True
+
+    def get_batch_samples(self, epoch=0, num_batches=None, batch_size=None, total_batch_size=None):
+        # In general, `start_quorum` might not have been called during the first loop,
+        # and the dataloader might not have been initialized yet. In that case, return
+        # immediately and set the dirty flag to skip computation and the commit.
+        if not self._dataloader_iter:
+            self._dataloader_dirty = True
+            return []
+        # If the recovery worker is behind the current epoch, skip computation and the commit.
+        if epoch < self._loaded_epoch:
+            return None
+
+        if total_batch_size is not None and batch_size is not None:
+            num_batches = total_batch_size // (batch_size * self._replica_world_size)
+
+        assert num_batches is not None, ("num_batches must be specified or "
+                                         "total_batch_size and batch_size must be specified")
+
+        batch_samples = []
+        for _ in range(num_batches):
+            try:
+                batch_samples.append(next(self._dataloader_iter))
+            except StopIteration:
+                break
+        self._dataloader_dirty = False
+        self._accumulation_steps = len(batch_samples)
+        return batch_samples if batch_samples else None
+
     def _update_fr_path(self) -> None:
         """
         Update the path that flight recorder will dump the traces to.
@@ -921,9 +976,14 @@ def should_commit(self, timeout: Optional[timedelta] = None) -> bool:

         # decide whether we're in a healthy state to increase the step count
         if should_commit:
-            self._step += 1
-            self._batches_committed += self.num_participants()
             self._commit_failures = 0  # Reset failure counter on success
+            if not self._dataloader_dirty:
+                self._step += 1
+                self._batches_committed += self.num_participants() * self._accumulation_steps
+                self._current_batches_committed += self.num_participants() * self._accumulation_steps
+                return True
+            else:
+                return False
         else:
             self._commit_failures += 1
             # Check if we've hit max retries
@@ -934,8 +994,7 @@ def should_commit(self, timeout: Optional[timedelta] = None) -> bool:
                 msg = f"should_commit failed {self._commit_failures} times consecutively, exceeding max_retries={self._max_retries}"
                 self._logger.exception(msg)
                 raise RuntimeError(msg)
-
-        return should_commit
+        return False

     def load_state_dict(self, state_dict: Dict[str, int]) -> None:
         """
@@ -948,6 +1007,11 @@ def load_state_dict(self, state_dict: Dict[str, int]) -> None:
9481007 """
9491008 self ._step = state_dict ["step" ]
9501009 self ._batches_committed = state_dict ["batches_committed" ]
1010+ self ._loaded_epoch = state_dict ["epoch" ]
1011+ self ._loaded_current_batches_committed = state_dict ["current_batches_committed" ]
1012+ if self ._loaded_epoch == 0 :
1013+ self ._epoch = 0
1014+ self ._current_batches_committed = self ._loaded_current_batches_committed
9511015
9521016 def _manager_state_dict (self ) -> Dict [str , object ]:
9531017 with self ._state_dict_lock .r_lock ():
@@ -969,7 +1033,8 @@ def state_dict(self) -> Dict[str, int]:
         Returns:
             the state dict for this manager
         """
-        return {"step": self._step, "batches_committed": self._batches_committed}
+        return {"step": self._step, "batches_committed": self._batches_committed,
+                "epoch": self._epoch, "current_batches_committed": self._current_batches_committed}

     def current_step(self) -> int:
         """
@@ -1047,6 +1112,23 @@ def is_participating(self) -> bool:
             return False
         return True

+    def reconfigure_dataloader(self):
+        dataloader = self._dataloader_fn(self._replica_world_size,
+                                         self._replica_rank, self._current_batches_committed)
+        dataloader.sampler.set_epoch(self._epoch)
+        self._dataloader_iter = iter(dataloader)
+        # Clean up the old dataloader.
+        gc.collect()
+
+    def next_epoch(self):
+        self._epoch += 1
+        if self._loaded_epoch == self._epoch:
+            self._current_batches_committed = self._loaded_current_batches_committed
+        else:
+            self._current_batches_committed = 0
+        if self._dataloader_fn:
+            self.reconfigure_dataloader()
+        self._dataloader_dirty = False

 class _ManagerLogger:
     def __init__(self, manager: Manager, replica_id: str, group_rank: int) -> None:
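For context, a minimal sketch of how the new hooks might be driven from a training loop. This is not part of the patch: `train_dataset`, `local_batch_size`, `global_batch_size`, `num_epochs`, `model`, `optimizer`, and `loss_fn` are placeholders, and the `Manager` constructor arguments are elided; only `dataloader_fn`, `get_batch_samples`, `should_commit`, and `next_epoch` correspond to the changes above, and the exact call sequence is an assumption.

```python
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from torchft import Manager


def make_dataloader(replica_world_size: int, replica_rank: int, batches_committed: int) -> DataLoader:
    # Shard the dataset across replica groups. `batches_committed` is how many batches this
    # shard has already consumed in the current epoch, so a resumable sampler could skip
    # ahead by that amount (omitted here for brevity).
    sampler = DistributedSampler(train_dataset, num_replicas=replica_world_size, rank=replica_rank)
    return DataLoader(train_dataset, batch_size=local_batch_size, sampler=sampler)


manager = Manager(
    # ... the usual Manager arguments (process group, state_dict hooks, ...) ...
    dataloader_fn=make_dataloader,
)

for epoch in range(num_epochs):
    while True:
        manager.start_quorum()
        batch_samples = manager.get_batch_samples(
            epoch=epoch, batch_size=local_batch_size, total_batch_size=global_batch_size
        )
        if batch_samples is None:
            break  # epoch exhausted, or this replica is still behind the loaded epoch
        optimizer.zero_grad()
        for batch in batch_samples:  # gradient accumulation over the returned micro-batches
            loss = loss_fn(model, batch)
            loss.backward()
        if manager.should_commit():  # returns False while the dataloader is dirty
            optimizer.step()
    manager.next_epoch()
```

The property this exercises: after a quorum change, `get_batch_samples` and `should_commit` keep a replica from committing work computed against a stale data order until `reconfigure_dataloader` has re-sharded the data.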