@@ -668,12 +668,31 @@ def allreduce_fusion(
 
     # Flatten all tensors to 1D for legacy trtllm_allreduce_fusion API
     # The legacy API expects flattened tensors and explicit token_num/hidden_dim
-    input_flat = input.flatten()
-    output_flat = output.flatten()
-    residual_in_flat = residual_in.flatten() if residual_in is not None else None
-    residual_out_flat = residual_out.flatten() if residual_out is not None else None
-    norm_out_flat = norm_out.flatten() if norm_out is not None else None
-    quant_out_flat = quant_out.flatten() if quant_out is not None else None
+    # We require contiguous tensors so that view(-1) creates a view (not a copy),
+    # ensuring writes to the flattened tensors are reflected in the original 2D tensors
+    def _flatten_checked(t, name):
+        if not t.is_contiguous():
+            raise ValueError(f"{name} must be contiguous")
+        return t.view(-1)
+
+    input_flat = _flatten_checked(input, "input")
+    output_flat = _flatten_checked(output, "output")
+    residual_in_flat = (
+        _flatten_checked(residual_in, "residual_in")
+        if residual_in is not None
+        else None
+    )
+    residual_out_flat = (
+        _flatten_checked(residual_out, "residual_out")
+        if residual_out is not None
+        else None
+    )
+    norm_out_flat = (
+        _flatten_checked(norm_out, "norm_out") if norm_out is not None else None
+    )
+    quant_out_flat = (
+        _flatten_checked(quant_out, "quant_out") if quant_out is not None else None
+    )
 
     # Call legacy API with flattened tensors
     # Note: pattern and layout_code are ints but legacy API uses pseudo-type hints
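
For context, a minimal standalone sketch (not part of this change) of why the contiguity check is load-bearing: `Tensor.flatten()` silently returns a copy when the input is non-contiguous, so a kernel writing through the flattened copy would never update the caller's 2D buffers, while `view(-1)` refuses to copy and raises instead, surfacing the bug early.

```python
import torch

x = torch.zeros(4, 8)
nc = x.t()  # transposed view of x: non-contiguous

# flatten() silently copies a non-contiguous tensor, so in-place
# writes land in the copy and the original buffer stays untouched.
f = nc.flatten()
f.fill_(1.0)
assert x.sum() == 0  # writes were lost; a kernel's output would vanish

# view(-1) raises on non-contiguous input instead of copying.
try:
    nc.view(-1)
except RuntimeError as e:
    print(e)  # "view size is not compatible with input tensor's size and stride ..."

# On a contiguous tensor, view(-1) aliases the same storage,
# so writes through the flat view are visible in the 2D tensor.
v = x.view(-1)
v.fill_(2.0)
assert x.sum() == 4 * 8 * 2
```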