Skip to content

Commit 437c7df

Browse files
committed
Ensured that we can flattend the I/O tensors.
1 parent 30b30a0 commit 437c7df

File tree

1 file changed

+25
-6
lines changed

1 file changed

+25
-6
lines changed

flashinfer/comm/allreduce.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -668,12 +668,31 @@ def allreduce_fusion(
668668

669669
# Flatten all tensors to 1D for legacy trtllm_allreduce_fusion API
670670
# The legacy API expects flattened tensors and explicit token_num/hidden_dim
671-
input_flat = input.flatten()
672-
output_flat = output.flatten()
673-
residual_in_flat = residual_in.flatten() if residual_in is not None else None
674-
residual_out_flat = residual_out.flatten() if residual_out is not None else None
675-
norm_out_flat = norm_out.flatten() if norm_out is not None else None
676-
quant_out_flat = quant_out.flatten() if quant_out is not None else None
671+
# We require contiguous tensors so that view(-1) creates a view (not a copy),
672+
# ensuring writes to the flattened tensors are reflected in the original 2D tensors
673+
def _flatten_checked(t, name):
674+
if not t.is_contiguous():
675+
raise ValueError(f"{name} must be contiguous")
676+
return t.view(-1)
677+
678+
input_flat = _flatten_checked(input, "input")
679+
output_flat = _flatten_checked(output, "output")
680+
residual_in_flat = (
681+
_flatten_checked(residual_in, "residual_in")
682+
if residual_in is not None
683+
else None
684+
)
685+
residual_out_flat = (
686+
_flatten_checked(residual_out, "residual_out")
687+
if residual_out is not None
688+
else None
689+
)
690+
norm_out_flat = (
691+
_flatten_checked(norm_out, "norm_out") if norm_out is not None else None
692+
)
693+
quant_out_flat = (
694+
_flatten_checked(quant_out, "quant_out") if quant_out is not None else None
695+
)
677696

678697
# Call legacy API with flattened tensors
679698
# Note: pattern and layout_code are ints but legacy API uses pseudo-type hints

0 commit comments

Comments
 (0)