@@ -1714,39 +1714,33 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
         # Switch to FP32 shard after backward.
         self._use_fp32_param_shard([param])
 
-        if self.fp32_reduce_scatter:
-            if param.grad is not None:
-                if param.main_grad is not None:
-                    param.main_grad.add_(param.grad.data.float())
-                else:
-                    param.main_grad = param.grad.data.float()
-                param.grad = None
-
         if not self._require_backward_grad_sync:
             return
 
         # Wait for all work in the current stream to finish, then start the
         # reductions in post_backward stream.
         self._streams["post_backward"].wait_stream(torch.cuda.current_stream())
         with torch.cuda.stream(self._streams["post_backward"]):
-            # orig_grad_data = param.main_grad.data
+            if param.main_grad is not None:
+                orig_grad_data = param.main_grad
+            else:
+                orig_grad_data = param.grad
 
             if self.fp32_reduce_scatter:
-                # Cast grad to FP32. with .main_grad params are already in FP32.
-                if param.main_grad is not None:
-                    orig_grad_data = param.main_grad.data
-                else:
-                    orig_grad_data = param.grad.data.to(torch.float32)
-            else:
-                orig_grad_data = param.grad.data
+                if param.grad is not None:
+                    if param.main_grad is not None:
+                        param.main_grad.copy_(param.grad.float())
+                    else:
+                        param.main_grad = param.grad.float()
+                    param.grad = None
 
             if self.gradient_predivide_factor > 1:
                 # Average grad by world_size for consistency with PyTorch DDP.
                 # param.grad.data.div_(self.gradient_predivide_factor)
                 if param.main_grad is not None:
-                    param.main_grad.data.div_(self.gradient_predivide_factor)
+                    param.main_grad.div_(self.gradient_predivide_factor)
                 else:
-                    param.grad.data.div_(self.gradient_predivide_factor)
+                    param.grad.div_(self.gradient_predivide_factor)
 
             if param._is_sharded:
                 assert self._reducer is not None
@@ -1755,10 +1749,10 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
                 # gradient reductions will happen in an undefined order. But addition commutes, so this order doesn't
                 # matter, neglecting rounding.
                 if param.main_grad is not None:
-                    grad = param.main_grad.data
+                    grad = param.main_grad
                     param.main_grad = None
                 else:
-                    grad = param.grad.data
+                    grad = param.grad
                 # Clear grad on the tensor, so any repeated gradient computations do not interfere with this reduction.
                 #
                 # The effect on memory consumption is not usually significant. No extra memory is allocated if this
@@ -1781,9 +1775,9 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
                 # case grads should be all-reduced here.
                 assert self.world_size == 1
                 if param.main_grad is not None:
-                    self._post_reduction_hook(param, param.main_grad.data)
+                    self._post_reduction_hook(param, param.main_grad)
                 else:
-                    self._post_reduction_hook(param, param.grad.data)
+                    self._post_reduction_hook(param, param.grad)
 
             # After _post_backward_hook returns, orig_grad_data will eventually
             # go out of scope, at which point it could otherwise be freed for
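For context, here is a minimal, self-contained sketch of the gradient handoff the new hunks implement: the low-precision `param.grad` is copied or cast into a persistent FP32 `main_grad` buffer inside the post_backward work, `param.grad` is cleared so repeated backwards do not interfere, and the pre-divide is applied to whichever tensor will be reduced. The `ToyParam` class, `post_backward` function, and `predivide` argument below are hypothetical stand-ins for the real FSDP parameter attributes and `gradient_predivide_factor`; the CUDA stream handling and `record_stream` pinning from the real hook are omitted.

```python
import torch


class ToyParam:
    """Hypothetical stand-in for an FSDP-managed parameter's grad attributes."""

    def __init__(self, data):
        self.data = data
        self.grad = None       # low-precision grad produced by autograd
        self.main_grad = None  # persistent FP32 accumulation buffer (optional)


def post_backward(param, fp32_reduce_scatter, predivide):
    # Cast/copy the low-precision grad into the FP32 buffer, mirroring the
    # new `if self.fp32_reduce_scatter:` hunk in the diff above.
    if fp32_reduce_scatter and param.grad is not None:
        if param.main_grad is not None:
            # Reuse the preallocated FP32 buffer instead of allocating a new one.
            param.main_grad.copy_(param.grad.float())
        else:
            param.main_grad = param.grad.float()
        # Drop the low-precision grad so repeated backwards start from scratch.
        param.grad = None

    # Pre-divide before the reduction, for parity with PyTorch DDP averaging.
    target = param.main_grad if param.main_grad is not None else param.grad
    if predivide > 1:
        target.div_(predivide)
    return target


# Example: a bf16 gradient is cast into a fresh FP32 buffer and pre-divided.
p = ToyParam(torch.zeros(4))
p.grad = torch.ones(4, dtype=torch.bfloat16)
reduced = post_backward(p, fp32_reduce_scatter=True, predivide=2.0)
print(reduced.dtype, reduced)  # torch.float32, tensor([0.5000, 0.5000, 0.5000, 0.5000])
```

The sketch also shows why the diff prefers `main_grad` when selecting the tensor to reduce: once the FP32 copy exists, all later operations (pre-divide, reduce-scatter, post-reduction hook) should act on that buffer rather than on the freed low-precision grad.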