@@ -332,6 +332,7 @@ def __init__(
         offload_config: Optional[OffloadConfig] = None,
         state_dict_on_rank_0_only: bool = False,
         gradient_predivide_factor: Optional[float] = None,
+        zero2_process_group: Optional[ProcessGroup] = None,
     ):
         try:
             import torch._C
@@ -380,6 +381,9 @@ def __init__(
380381 "parameter uses all the available ranks for the optimal performance."
381382 )
382383 self .reshard_after_forward = self ._orig_reshard_after_forward = reshard_after_forward
384+
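+        # When a zero2 process group is given, parameters are resharded to this
+        # (assumed smaller) group after the forward pass instead of being freed,
+        # and are re-gathered within this group before the backward pass.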
+        self.zero2_process_group = zero2_process_group
+
         self.disable_reshard_on_root = disable_reshard_on_root
         self.mixed_precision = mixed_precision
         self.fp32_reduce_scatter = fp32_reduce_scatter
@@ -518,6 +522,9 @@ def __init__(
             if isinstance(m, FullyShardedDataParallel):
                 m._free_ssd_offload()

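+        # Note: the zero2 path keeps a GPU-resident shard of the full parameters
+        # between the forward and backward passes, so it is not compatible with
+        # offloading params to CPU.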
+        if self.zero2_process_group is not None:
+            assert not self.move_params_to_cpu
+
     def _get_gradient_predivide_factor(self, world_size: int) -> float:
         factor: int = 1
         while world_size % factor == 0 and world_size / factor > factor:
@@ -1419,7 +1426,10 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
         outputs = self.module(*args, **kwargs)

         if self.reshard_after_forward:
-            self._free_full_params()
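+            # With a zero2 process group, reshard the full params to the smaller
+            # group (keeping a local shard on GPU) instead of freeing them outright;
+            # see _zero2_shard_to_smaller_group below.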
+            if self.zero2_process_group is not None:
+                self._zero2_shard_to_smaller_group()
+            else:
+                self._free_full_params()
         if self.mixed_precision or self.move_params_to_cpu:
             self._free_fp16_param_shard()

@@ -1499,7 +1509,10 @@ def _pre_backward_hook(*unused: Any) -> None:
             # idempotent. So in case they are called unnecessarily, they don't incur much
             # overhead.
             if self.reshard_after_forward:
-                self._rebuild_full_params()
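+                # The zero2 rebuild only needs an all-gather within the smaller
+                # zero2 group, since each rank kept its zero2 shard after the forward.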
+                if self.zero2_process_group is not None:
+                    self._zero2_rebuild_full_params()
+                else:
+                    self._rebuild_full_params()
             if (
                 self.reshard_after_forward
                 and self._fsdp_forward_ordering is not None
@@ -2006,6 +2019,126 @@ def update_p_data(custom_output_tensor: Optional[torch.Tensor] = None) -> None:
             torch.cuda.current_stream().wait_stream(self._streams["all_gather"])
         return output_tensors

+
+    @torch.no_grad()
+    def _zero2_rebuild_full_params(self, force_full_precision: bool = False, wait_for_all_gather: bool = True) -> Optional[List[Tuple[torch.Tensor, bool]]]:
+        """
+        Gather all shards of params.
+
+        Note, this is idempotent if full params are already gathered. Callers
+        assume the idempotency. So please keep it that way.
+
+        Args:
+            force_full_precision (bool, Optional): by default params will be gathered
+                in ``compute_dtype`` (e.g., FP16), unless *force_full_precision* is
+                ``True``, in which case they will be gathered in full precision
+                (e.g., FP32), possibly in fresh storage. The parameter that's being
+                rebuilt will end up in full precision as well.
+
+        Returns:
+            A list of tuples, where the first element is the full-sized param
+            and the second element is a bool indicating if it's safe for the
+            caller to free the full-sized param. This will be ``None`` if
+            ``force_full_precision=False`` and the full params are already gathered.
+        """
+        output_tensors: List[Tuple[torch.Tensor, bool]] = []
+
+        def update_p_data(custom_output_tensor: Optional[torch.Tensor] = None) -> None:
+            """
+            Helper function to update p.data pointer.
+
+            Args:
+                custom_output_tensor (torch.Tensor, Optional): if not None, this
+                    tensor contains the data we just gathered.
+            """
+            if custom_output_tensor is not None:
+                assert p._is_sharded
+                p.data = custom_output_tensor
+                output_tensors.append((p.data, True))
+            elif not p._is_sharded:
+                if (self.mixed_precision or self.move_params_to_cpu) and not force_full_precision:
+                    assert p._fp16_shard is not None
+                    p.data = p._fp16_shard
+                    output_tensors.append((p.data, True))
+                else:
+                    # Here p.data == p._fp32_shard, so it's not safe to free.
+                    output_tensors.append((p.data, False))
+            else:
+                p.data = p._full_param_padded
+                output_tensors.append((p.data, True))
+            # Trim any padding and reshape to match original size.
+            p.data = p.data[: p._orig_size.numel()].view(p._orig_size)
+
+        if self._has_shared_params:
+            # self.has_full_params flag can be out of sync if a shared param is
+            # sharded by another FSDP instance. An example is the eval case where
+            # this instance has reshard_after_forward=False but the sharing instance
+            # has reshard_after_forward=True. Then, on the second forward, the other
+            # instance can shard the shared param, but this instance can mistakenly
+            # think the full param is already gathered based on the has_full_params
+            # flag.
+            #
+            # Therefore, we update the flag accordingly here.
+            self.has_full_params = not any(p._full_param_padded.storage().size() == 0 for p in self.params)
+
+        # Early exit if we already have full params and don't need full precision.
+        if self.has_full_params and not force_full_precision:
+            if wait_for_all_gather:
+                torch.cuda.current_stream().wait_stream(self._streams["all_gather"])
+            for p in self.params:
+                update_p_data()
+            return output_tensors
+
+        self.has_full_params = True
+
+        with torch.cuda.stream(self._streams["all_gather"]):
+
+            for p in self.params:
+                if not p._is_sharded:  # e.g., when world_size == 1
+                    update_p_data()
+                else:
+                    # Skip if already built. Only shared param can be rebuilt multiple times.
+                    # A corner case is p._orig_size = (1,), which means the shape equality is
+                    # not a perfect check. But we assume we don't share a param with shape (1,).
+                    if p.data.shape == p._orig_size and hasattr(p, "_is_shared") and p._is_shared:
+                        continue
+                    # If self.move_params_to_cpu and force_full_precision, we need to cast
+                    # the FP32 CPU param to CUDA for the all-gather.
+                    p_data = p.data.to(p._full_param_padded.device, non_blocking=True)
+
+                    p_size = p._full_param_padded.size()
+                    assert p_size.numel() % self.world_size == 0
+                    if self.mixed_precision and force_full_precision:
+                        # Allocate fresh tensor in full precision since we are in
+                        # mixed precision and full precision rebuild is asked.
+                        output_tensor = p_data.new_zeros(p_size)
+                    else:
+                        if p._full_param_padded.storage().size() != p_size.numel():
+                            # Allocate based on full size from all shards.
+                            alloc_storage_(p._full_param_padded, size=p_size)
+                        output_tensor = p._full_param_padded
+
+                    # Fill output_tensor with (p.data for each shard in the zero2 group)
+                    if hasattr(dist, "_all_gather_base") and enable_nccl_base_collectives:
+                        # New version of PyTorch has all_gather_base, which is faster than chunk and then all_gather.
+                        dist._all_gather_base(output_tensor, p._zero2_fp16_shard, group=self.zero2_process_group)
+                    else:
+                        # Chunk by the zero2 group size, since the all-gather below runs
+                        # over self.zero2_process_group rather than the full world.
+                        chunks = list(output_tensor.chunk(dist.get_world_size(self.zero2_process_group)))
+                        dist.all_gather(chunks, p._zero2_fp16_shard, group=self.zero2_process_group)
+
+                    # Set p.data = output_tensor (with padding trimmed)
+                    update_p_data(output_tensor)
+
+                    if (self.mixed_precision or self.move_params_to_cpu) and not force_full_precision:
+                        self._free_zero2_param_shard([p])
+
+                    if self.move_params_to_cpu and (self.params[0].dtype == self.compute_dtype):
+                        self._free_zero2_param_shard([p])
+        if wait_for_all_gather:
+            torch.cuda.current_stream().wait_stream(self._streams["all_gather"])
+        return output_tensors
+
+
     @torch.no_grad()
     def _use_full_params(self) -> None:
         """
@@ -2074,6 +2207,38 @@ def _free_full_params(self, params: Optional[List[Parameter]] = None) -> None:
             free_storage_(p._full_param_padded)
         torch.cuda.current_stream().synchronize()

+
+    def _zero2_shard_to_smaller_group(self, params: Optional[List[Parameter]] = None) -> None:
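+        """Reshard the gathered full params across the (smaller) zero2 process group.
+
+        Each rank keeps its zero2 shard in GPU memory so that the backward pass only
+        needs an all-gather within the zero2 group to rebuild the full params.
+        """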
+        if params is None:
+            params = self.params
+        self.has_full_params = False
+        current_stream = torch.cuda.current_stream()
+        for p in params:
+            if not p._is_sharded:  # e.g., world_size == 1
+                if self.mixed_precision or self.move_params_to_cpu:
+                    self._free_fp16_param_shard([p])
+                continue
+            # Handle the case where the zero2 world size is > 1 but smaller than the
+            # full (zero3) world size.
+            zero2_world_size = dist.get_world_size(self.zero2_process_group)
+            zero2_rank = dist.get_rank(self.zero2_process_group)
+            chunks = p._full_param_padded.chunk(zero2_world_size)
+
+            p._zero2_fp16_shard = torch.empty_like(chunks[zero2_rank])
+            p._zero2_fp16_shard.copy_(chunks[zero2_rank])
+
+            # Don't let PyTorch reuse this memory until all work in the current
+            # stream is complete.
+            p._full_param_padded.record_stream(current_stream)
+            # There may be external references to the Tensor Storage that we
+            # can't modify, such as references that are created by
+            # ctx.save_for_backward in the forward pass. Thus when we
+            # unshard parameters, we should reuse the original Tensor
+            # Storage object and unshard it in-place. For now, just resize
+            # the Storage to 0 to save memory.
+            free_storage_(p._full_param_padded)
+        torch.cuda.current_stream().synchronize()
+
+
     def local_metadata_dict(self) -> Dict[str, Any]:
         """
         Get the information needed to reconstruct the model from shards offline.
@@ -2238,6 +2403,19 @@ def _free_fp16_param_shard(self, params: Optional[List[Parameter]] = None) -> No
                 p._fp16_shard.record_stream(current_stream)
                 free_storage_(p._fp16_shard)

+    @torch.no_grad()
+    def _free_zero2_param_shard(self, params: Optional[List[Parameter]] = None) -> None:
+        """Free storage for the zero2 FP16 shards of a list of params."""
+        if params is None:
+            params = self.params
+        current_stream = torch.cuda.current_stream()
+        for p in params:
+            if p._zero2_fp16_shard is not None:
+                # _zero2_fp16_shard may still be consumed by the all-gather queued on
+                # the "all_gather" stream, so don't let its memory be reused until the
+                # work in the current stream completes.
+                p._zero2_fp16_shard.record_stream(current_stream)
+                free_storage_(p._zero2_fp16_shard)
+
     def assert_state(self, state: Union[TrainingState, List[TrainingState]]) -> None:
         """Assert we are in the given state."""
         # Since assert can be turned off and this error checking
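
A minimal usage sketch for the new `zero2_process_group` argument. Assumptions: `my_module`, the 4-GPUs-per-node layout, and the subgroup construction below are illustrative only; what comes from this diff is the `zero2_process_group` constructor argument, the `reshard_after_forward` interaction, and the restriction that `move_params_to_cpu` must stay off.

    import torch.distributed as dist
    from fairscale.nn import FullyShardedDataParallel as FSDP

    dist.init_process_group(backend="nccl")
    rank, world_size = dist.get_rank(), dist.get_world_size()

    # Assumed topology: 4 GPUs per node; the zero2 group is the local node's ranks,
    # so after the forward pass the params stay sharded only across the node and are
    # re-gathered within the node before the backward pass.
    gpus_per_node = 4  # illustrative
    groups = [
        dist.new_group(ranks=list(range(n * gpus_per_node, (n + 1) * gpus_per_node)))
        for n in range(world_size // gpus_per_node)
    ]  # new_group must be called identically on every rank
    zero2_pg = groups[rank // gpus_per_node]

    sharded_module = FSDP(
        my_module,  # hypothetical module to wrap
        reshard_after_forward=True,
        zero2_process_group=zero2_pg,
        # move_params_to_cpu must remain False when zero2_process_group is set
        # (enforced by the assert added in __init__ above).
    )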