diff --git a/flashinfer/comm/__init__.py b/flashinfer/comm/__init__.py index 98d77550a8..5f186002dc 100644 --- a/flashinfer/comm/__init__.py +++ b/flashinfer/comm/__init__.py @@ -39,6 +39,17 @@ from .vllm_ar import register_buffer as vllm_register_buffer from .vllm_ar import register_graph_buffers as vllm_register_graph_buffers +# Unified AllReduce Fusion API +from .allreduce import AllReduceFusionWorkspace as AllReduceFusionWorkspace +from .trtllm_mnnvl_ar import ( + MNNVLAllReduceFusionWorkspace as MNNVLAllReduceFusionWorkspace, +) +from .allreduce import TRTLLMAllReduceFusionWorkspace as TRTLLMAllReduceFusionWorkspace +from .allreduce import allreduce_fusion as allreduce_fusion +from .allreduce import ( + create_allreduce_fusion_workspace as create_allreduce_fusion_workspace, +) + # MNNVL A2A (Throughput Backend) from .trtllm_moe_alltoall import MoeAlltoAll as MoeAlltoAll from .trtllm_moe_alltoall import moe_a2a_combine as moe_a2a_combine diff --git a/flashinfer/comm/allreduce.py b/flashinfer/comm/allreduce.py new file mode 100644 index 0000000000..cfafa99220 --- /dev/null +++ b/flashinfer/comm/allreduce.py @@ -0,0 +1,702 @@ +""" +Copyright (c) 2025 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +""" +Unified AllReduce Fusion API + +This module provides a unified interface for AllReduce + RMSNorm fusion operations +across different backends (TensorRT-LLM, MNNVL). + +Example usage: + >>> # Auto-select best backend based on topology + >>> workspace = create_allreduce_fusion_workspace( + ... backend="auto", + ... world_size=8, + ... rank=0, + ... max_token_num=2048, + ... hidden_dim=4096, + ... dtype=torch.bfloat16, + ... topology="single_node" + ... ) + >>> + >>> # Perform AllReduce + RMSNorm fusion + >>> prenorm = torch.empty_like(hidden_states) + >>> normed = torch.empty_like(hidden_states) + >>> output = allreduce_fusion( + ... input=hidden_states, + ... workspace=workspace, + ... launch_with_pdl=True, + ... residual_out=prenorm, + ... norm_out=normed, + ... residual_in=residual, + ... rms_gamma=norm_weight + ... 
) + >>> + >>> workspace.destroy() +""" + +from typing import Union, Literal, Optional, Tuple, List, cast, Any +from .workspace_base import AllReduceFusionWorkspace + +import torch + +from .trtllm_ar import trtllm_allreduce_fusion +from .trtllm_ar import trtllm_create_ipc_workspace_for_all_reduce_fusion +from .trtllm_ar import trtllm_destroy_ipc_workspace_for_all_reduce_fusion +from .trtllm_ar import check_trtllm_allreduce_fusion_workspace_metadata + +from .mapping import Mapping + +from .mnnvl import CommBackend + +# Note: AllReduceFusionPattern and QuantizationSFLayout are pseudo-types (classes with int constants) +# Import them for runtime use but type hint as int for mypy compatibility +from .trtllm_ar import AllReduceFusionPattern +from .trtllm_mnnvl_ar import MNNVLAllReduceFusionWorkspace +from .trtllm_mnnvl_ar import trtllm_mnnvl_allreduce +from .trtllm_mnnvl_ar import trtllm_mnnvl_fused_allreduce_add_rmsnorm + +# ============================================================================ +# WORKSPACE IMPLEMENTATIONS +# ============================================================================ +# +# Workspace classes wrap the underlying backend workspace implementations: +# - TRTLLMAllReduceFusionWorkspace: Wraps trtllm_create_ipc_workspace_for_all_reduce_fusion +# - MNNVLAllReduceFusionWorkspace: Wraps MNNVL workspace (see trtllm_mnnvl_ar.py) +# +# Each workspace: +# 1. Calls the backend-specific workspace creation function in __init__ +# 2. Stores the internal workspace as _internal_workspace +# 3. Exposes essential attributes for the unified API +# 4. Can be destroyed using workspace.destroy() +# ============================================================================ + + +class TRTLLMAllReduceFusionWorkspace(AllReduceFusionWorkspace): + """TensorRT-LLM workspace for AllReduce fusion.""" + + def __init__( + self, + tp_size: int, + tp_rank: int, + max_token_num: int, + hidden_dim: int, + dtype: torch.dtype = torch.float16, + process_group: Optional["torch.distributed.ProcessGroup"] = None, + ): + """ + Create TensorRT-LLM AllReduce fusion workspace. 
+ + Args: + tp_size: Tensor parallel size (world size) + tp_rank: Tensor parallel rank + max_token_num: Maximum number of tokens + hidden_dim: Hidden dimension size + dtype: Data type + process_group: PyTorch distributed process group + **kwargs: Additional arguments for workspace creation + """ + super().__init__(tp_size, tp_rank) + + # Call the actual workspace creation function + self._internal_workspace = trtllm_create_ipc_workspace_for_all_reduce_fusion( + tp_rank=tp_rank, + tp_size=tp_size, + max_token_num=max_token_num, + hidden_dim=hidden_dim, + group=process_group, + create_metadata=True, + use_fp32_lamport=dtype == torch.float32, + ) + + # Store essential attributes for easy access + # Cast to 3-tuple to make linter happy, since we always call with create_metadata=True + workspace_tuple = cast( + Tuple[List[List[int]], torch.Tensor, dict], self._internal_workspace + ) + self.ipc_handles = workspace_tuple[0] + self.workspace_tensor = workspace_tuple[1] + self.metadata = workspace_tuple[2] + + @property + def backend(self) -> str: + return "trtllm" + + def __getattr__(self, name): + """Delegate attribute access to internal workspace if not found.""" + if name.startswith("_"): + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{name}'" + ) + return getattr(self._internal_workspace, name) + + def is_buffer_size_sufficient( + self, + tp_size: int, + num_tokens: int, + hidden_dim: int, + dtype: torch.dtype, + use_oneshot: Optional[Any] = None, + ) -> bool: + try: + check_trtllm_allreduce_fusion_workspace_metadata( + num_tokens, hidden_dim, tp_size, dtype, self.metadata + ) + return True + except ValueError as e: + print(f"Workspace is insufficient for problem size. {e}") + return False + + def destroy(self) -> None: + """Destroy workspace and free resources.""" + if getattr(self, "_destroyed", False): + return # Already destroyed, nothing to do + + trtllm_destroy_ipc_workspace_for_all_reduce_fusion(self.ipc_handles) + self._destroyed = True + + +# ============================================================================ +# BACKEND CHECKS - Hard requirements for backend selection +# ============================================================================ + + +def _trtllm_workspace_check( + backend: str, + world_size: int, + rank: int, + max_token_num: int, + hidden_dim: int, + dtype: torch.dtype, + topology: Literal["single_node", "multi_node"], +) -> bool: + """ + Check if trtllm backend CAN be used for workspace creation. + + Hard requirements: + - Single-node topology (multi-node not supported) + + """ + # trtllm is optimized for single-node + if topology == "multi_node": + return False + + return True + + +def _mnnvl_workspace_check( + backend: str, + world_size: int, + rank: int, + max_token_num: int, + hidden_dim: int, + dtype: torch.dtype, + topology: Literal["single_node", "multi_node"], +) -> bool: + """ + Check if mnnvl backend CAN be used for workspace creation. 
+ + """ + + if topology == "multi_node": + return True + + return True + + +# ============================================================================ +# HEURISTIC - Performance-based backend selection +# ============================================================================ + + +def _workspace_creation_heuristic( + suitable_backends: list[str], + backend: str, + world_size: int, + rank: int, + max_token_num: int, + hidden_dim: int, + dtype: torch.dtype, + topology: Literal["single_node", "multi_node"], +) -> list[str]: + """ + Select best backend for workspace creation based on performance. + + Called by decorator after checking which backends pass requirements. + Uses benchmarking data to pick fastest option. + + Args: + suitable_backends: List of backends that passed hard requirement checks + backend: Requested backend ("auto", "trtllm", or "mnnvl") + world_size: Number of ranks + rank: Current rank + max_token_num: Maximum number of tokens + hidden_dim: Hidden dimension size + dtype: Data type + topology: Network topology ("single_node" or "multi_node") + **kwargs: Additional arguments + + Note that at this point, the backend selection does not take "runtime parameters" into account, such as layout_code, and fusion pattern. + + Returns: + List containing the selected backend (single element) + """ + if not suitable_backends: + return [] + + if len(suitable_backends) == 1: + return suitable_backends + + # Decision tree based on benchmark data + + # Multi-node: MNNVL is designed for this + if topology == "multi_node": + if "mnnvl" in suitable_backends: + return ["mnnvl"] + else: + return [suitable_backends[0]] + + # Single-node scenarios + # From benchmarking data, we can see that MNNVL is either on par (smaller problem sizes) or significantly faster than TRTLLM (larger problem sizes such as hidden_dim=8192, token_num=64 for TP=4), for single-node scenarios. + # However, trtllm has a larger support surface (more fusion patterns, more quantization support, etc.) + if "mnnvl" in suitable_backends: + return ["mnnvl"] + else: + return [suitable_backends[0]] + + +# ============================================================================ +# WORKSPACE CREATION +# ============================================================================ + + +def create_allreduce_fusion_workspace( + backend: Literal["trtllm", "mnnvl", "auto"] = "auto", + world_size: int = None, + rank: int = None, + max_token_num: int = None, + hidden_dim: int = None, + dtype: torch.dtype = None, + topology: Literal["single_node", "multi_node"] = "single_node", + process_group: Optional["torch.distributed.ProcessGroup"] = None, + gpus_per_node: int = None, + comm_backend: Optional[CommBackend] = None, +) -> AllReduceFusionWorkspace: + """ + Create workspace for AllReduce fusion operations. + + Backend selection uses topology-based checks and heuristics. + + **Important: Workspace Reusability** + The workspace is allocated based on the total size (max_token_num * hidden_dim * dtype_size). + You can reuse the same workspace with different shapes as long as the total size fits: + + - Workspace(max_token_num=2048, hidden_dim=4096) can handle: + - (token_num=2048, hidden_dim=4096) ✓ + - (token_num=1024, hidden_dim=4096) ✓ + - (token_num=4096, hidden_dim=2048) ✓ (same total size) + - (token_num=1024, hidden_dim=8192) ✓ (same total size) + - (token_num=4096, hidden_dim=4096) ✗ (too large) + + Use `workspace.is_buffer_size_sufficient(token_num, hidden_dim, dtype)` to check before use. 
+
+    Args:
+        backend: Backend to use ("trtllm", "mnnvl", or "auto")
+            "auto" uses heuristic to select best backend based on topology
+            and problem size
+        world_size: Number of ranks in the process group
+        rank: Current rank ID
+        max_token_num: Maximum number of tokens to support
+        hidden_dim: Hidden dimension size
+        dtype: Data type for communication tensors
+        topology: Network topology hint for backend selection
+            "single_node" - All ranks on one node (default)
+            "multi_node" - Ranks span multiple nodes
+        process_group: PyTorch distributed process group (for trtllm backend).
+        gpus_per_node: Number of GPUs per node (for multi-node topology).
+        comm_backend: Communication backend to use (for multi-node topology).
+
+    Returns:
+        Workspace object (TRTLLMAllReduceFusionWorkspace or MNNVLAllReduceFusionWorkspace).
+        The workspace type determines which backend will be used in allreduce_fusion()
+
+    Raises:
+        ValueError: If no suitable backend is available for the configuration,
+            or if the problem size is not supported by the specified backend
+
+    Examples:
+        >>> # Auto-select best backend based on topology
+        >>> workspace = create_allreduce_fusion_workspace(
+        ...     backend="auto",
+        ...     world_size=8,
+        ...     rank=0,
+        ...     max_token_num=2048,
+        ...     hidden_dim=4096,
+        ...     dtype=torch.bfloat16,
+        ...     topology="single_node"
+        ... )
+        >>> print(workspace.backend)  # "trtllm"
+        >>> print(workspace.get_workspace_capacity())  # 8388608 elements
+
+        >>> # Check if workspace can handle different problem sizes
+        >>> workspace.is_buffer_size_sufficient(tp_size=8, num_tokens=1024, hidden_dim=4096, dtype=torch.bfloat16)  # True
+        >>> workspace.is_buffer_size_sufficient(tp_size=8, num_tokens=4096, hidden_dim=2048, dtype=torch.bfloat16)  # True (same total)
+
+        >>> # Explicit backend selection
+        >>> workspace = create_allreduce_fusion_workspace(
+        ...     backend="mnnvl",
+        ...     world_size=16,
+        ...     rank=0,
+        ...     max_token_num=2048,
+        ...     hidden_dim=4096,
+        ...     dtype=torch.bfloat16,
+        ...     topology="multi_node"
+        ... )
+        >>> print(workspace.backend)  # "mnnvl"
+    """
+    if gpus_per_node is None:
+        gpus_per_node = min(torch.cuda.device_count(), world_size)
+    # Determine the actual backend to use
+    if backend == "auto":
+        # Find suitable backends based on topology (any compute-capability check is
+        # deferred to kernel launch time, since no tensors are available at this point)
+        suitable_backends = []
+        if _trtllm_workspace_check(
+            backend=backend,
+            world_size=world_size,
+            rank=rank,
+            max_token_num=max_token_num,
+            hidden_dim=hidden_dim,
+            dtype=dtype,
+            topology=topology,
+        ):
+            suitable_backends.append("trtllm")
+        if _mnnvl_workspace_check(
+            backend=backend,
+            world_size=world_size,
+            rank=rank,
+            max_token_num=max_token_num,
+            hidden_dim=hidden_dim,
+            dtype=dtype,
+            topology=topology,
+        ):
+            suitable_backends.append("mnnvl")
+
+        if not suitable_backends:
+            raise ValueError(
+                f"No suitable backend found for topology={topology}. "
+                f"trtllm requires single_node topology, mnnvl works with both."
+ ) + + # Apply heuristic to select best backend + selected = _workspace_creation_heuristic( + suitable_backends=suitable_backends, + backend=backend, + world_size=world_size, + rank=rank, + max_token_num=max_token_num, + hidden_dim=hidden_dim, + dtype=dtype, + topology=topology, + ) + actual_backend = selected[0] + else: + actual_backend = backend + + # Create workspace for selected backend using workspace constructors + if actual_backend == "trtllm": + return TRTLLMAllReduceFusionWorkspace( + tp_size=world_size, + tp_rank=rank, + max_token_num=max_token_num, + hidden_dim=hidden_dim, + dtype=dtype, + process_group=process_group, + ) + + elif actual_backend == "mnnvl": + mapping = Mapping( + world_size=world_size, + rank=rank, + gpus_per_node=gpus_per_node, + tp_size=world_size, + ) + return MNNVLAllReduceFusionWorkspace( + mapping=mapping, + max_num_tokens=max_token_num, + hidden_dim=hidden_dim, + dtype=dtype, + comm_backend=comm_backend, + ) + else: + raise RuntimeError(f"Unknown backend: {actual_backend}") + + +# ============================================================================ +# MAIN API - NO backend parameter, infers from workspace type +# ============================================================================ + + +def allreduce_fusion( + input: torch.Tensor, + workspace: AllReduceFusionWorkspace, + pattern: int, + launch_with_pdl: bool = False, + # ===== OUTPUT tensors (pre-allocated, will be filled) ===== + output: Optional[torch.Tensor] = None, + residual_out: Optional[torch.Tensor] = None, + norm_out: Optional[torch.Tensor] = None, + quant_out: Optional[torch.Tensor] = None, + scale_out: Optional[torch.Tensor] = None, + # ===== INPUT parameters ===== + residual_in: Optional[torch.Tensor] = None, + rms_gamma: Optional[torch.Tensor] = None, + rms_eps: float = 1e-6, + scale_factor: Optional[Union[torch.Tensor, float]] = None, + layout_code: Optional[int] = None, + # ===== Control parameters ===== + use_oneshot: Optional[bool] = None, + fp32_acc: bool = False, +) -> torch.Tensor: + """ + AllReduce + RMSNorm fusion operation. + + Backend is automatically determined from workspace type. If you need another backend, create the workspace for the desired backend. + + Supports multiple fusion patterns: + - AllReduce only + - AllReduce + Residual + RMSNorm + - AllReduce + Residual + RMSNorm + Quantization (FP8/FP4) + + **Note on Workspace Reusability:** + You can reuse the same workspace with different (token_num, hidden_dim) combinations + as long as `workspace.is_buffer_size_sufficient(token_num, hidden_dim, tp_size, dtype)` returns True. 
+
+    Args:
+        input: Input tensor [token_num, hidden_dim]
+        workspace: Workspace object (type determines backend, see create_allreduce_fusion_workspace)
+        pattern: Fusion pattern (AllReduceFusionPattern constant, 0-5)
+            - kAllReduce = 0
+            - kARResidualRMSNorm = 1
+            - kARResidualRMSNormFP8Quant = 2
+            - kARResidualRMSNormFP4Quant = 3
+            - kARResidualRMSNormOutFP8Quant = 4
+            - kARResidualRMSNormOutFP4Quant = 5
+            Note: MNNVL only supports patterns 0 and 1
+        launch_with_pdl: Use Programmatic Dependent Launch (PDL)
+
+        # ===== OUTPUT tensors (pre-allocated, filled by function) =====
+        output: AllReduce output [token_num, hidden_dim]
+        residual_out: Prenorm output (after residual add, before norm) [token_num, hidden_dim]
+        norm_out: Normalized output [token_num, hidden_dim]
+        quant_out: Quantized output [token_num, hidden_dim] [trtllm only]
+        scale_out: Quantization scale factors [trtllm only]
+
+        # ===== INPUT parameters =====
+        residual_in: Residual tensor to ADD [token_num, hidden_dim]
+        rms_gamma: RMSNorm weight [hidden_dim]
+        rms_eps: RMSNorm epsilon for numerical stability
+        scale_factor: Input scale factor for quantization [trtllm only]
+        layout_code: Scale factor layout (QuantizationSFLayout) [trtllm only]
+
+        # ===== Control parameters =====
+        use_oneshot: Use the one-shot strategy instead of two-shot.
+            If None, internal heuristics decide.
+            Note that the MNNVL backend must be initialized with a sufficiently large workspace if use_oneshot is enabled.
+        fp32_acc: [trtllm only] Use FP32 accumulation for AllReduce
+
+    Returns:
+        Output tensor (typically norm_out for fusion cases, output otherwise)
+
+    Examples:
+        >>> # Basic AllReduce + Residual + RMSNorm
+        >>> workspace = create_allreduce_fusion_workspace(
+        ...     backend="auto",
+        ...     world_size=8,
+        ...     rank=0,
+        ...     max_token_num=2048,
+        ...     hidden_dim=4096,
+        ...     dtype=torch.bfloat16,
+        ...     topology="single_node"
+        ... )
+        >>>
+        >>> # Pre-allocate output tensors
+        >>> prenorm = torch.empty_like(hidden_states)
+        >>> normed = torch.empty_like(hidden_states)
+        >>>
+        >>> # Call fusion - backend inferred from workspace type
+        >>> output = allreduce_fusion(
+        ...     input=hidden_states,
+        ...     workspace=workspace,
+        ...     pattern=AllReduceFusionPattern.kARResidualRMSNorm,
+        ...     launch_with_pdl=True,
+        ...     residual_out=prenorm,
+        ...     norm_out=normed,
+        ...     residual_in=residual,
+        ...     rms_gamma=norm_weight
+        ... )
+        >>> # output == normed (final result)
+
+        >>> # With FP8 quantization
+        >>> quant = torch.empty_like(hidden_states, dtype=torch.float8_e4m3fn)
+        >>> scales = torch.empty(token_num * hidden_dim // 16, dtype=torch.float16)
+        >>>
+        >>> output = allreduce_fusion(
+        ...     input=hidden_states,
+        ...     workspace=workspace,
+        ...     pattern=AllReduceFusionPattern.kARResidualRMSNormFP8Quant,
+        ...     norm_out=normed,
+        ...     quant_out=quant,
+        ...     scale_out=scales,
+        ...     residual_in=residual,
+        ...     rms_gamma=norm_weight,
+        ...     scale_factor=scale_tensor
+        ...
) + """ + # Dispatch based on workspace type + if isinstance(workspace, TRTLLMAllReduceFusionWorkspace): + # TensorRT-LLM backend implementation + # Extract shape from 2D input + token_num, hidden_dim = input.shape + + # Allocate output if needed (keep 2D shape) + if output is None: + output = torch.empty_like(input) + + # Flatten all tensors to 1D for legacy trtllm_allreduce_fusion API + # The legacy API expects flattened tensors and explicit token_num/hidden_dim + # We require contiguous tensors so that view(-1) creates a view (not a copy), + # ensuring writes to the flattened tensors are reflected in the original 2D tensors + def _flatten_checked(t, name): + if not t.is_contiguous(): + raise ValueError(f"{name} must be contiguous") + return t.view(-1) + + input_flat = _flatten_checked(input, "input") + output_flat = _flatten_checked(output, "output") + residual_in_flat = ( + _flatten_checked(residual_in, "residual_in") + if residual_in is not None + else None + ) + residual_out_flat = ( + _flatten_checked(residual_out, "residual_out") + if residual_out is not None + else None + ) + norm_out_flat = ( + _flatten_checked(norm_out, "norm_out") if norm_out is not None else None + ) + quant_out_flat = ( + _flatten_checked(quant_out, "quant_out") if quant_out is not None else None + ) + + # Call legacy API with flattened tensors + # Note: pattern and layout_code are ints but legacy API uses pseudo-type hints + trtllm_allreduce_fusion( + allreduce_in=input_flat, + world_size=workspace.world_size, + world_rank=workspace.rank, + token_num=token_num, + hidden_dim=hidden_dim, + workspace_ptrs=workspace.workspace_tensor, + launch_with_pdl=launch_with_pdl, + trigger_completion_at_end=launch_with_pdl, # Same meaning + fp32_acc=fp32_acc, + pattern_code=pattern, # type: ignore[arg-type] + use_oneshot=use_oneshot, + allreduce_out=output_flat, + residual_in=residual_in_flat, + residual_out=residual_out_flat, + norm_out=norm_out_flat, + quant_out=quant_out_flat, + scale_out=scale_out, # scale_out is not reshaped + rms_gamma=rms_gamma, # 1D tensor, no reshape needed + rms_eps=rms_eps, + scale_factor=scale_factor, + layout_code=layout_code, # type: ignore[arg-type] + metadata=workspace.metadata, + ) + + # Return the most downstream output (already in 2D shape from input views) + if norm_out is not None: + return norm_out + elif quant_out is not None: + return quant_out + else: + return output + + elif isinstance(workspace, MNNVLAllReduceFusionWorkspace): + if ( + pattern != AllReduceFusionPattern.kARResidualRMSNorm + and pattern != AllReduceFusionPattern.kAllReduce + ): + raise ValueError( + f"MNNVL AllReduce+RMS fusion does not support pattern {pattern}. Please try the TRTLLM backend instead." 
+ ) + + if layout_code is not None: + raise ValueError( + "MNNVL AllReduce does not support quantization fusion and thus no layout_code" + ) + + # MNNVL backend implementation + if pattern == AllReduceFusionPattern.kAllReduce: + # AllReduce only + if output is None: + output = torch.empty_like(input) + trtllm_mnnvl_allreduce( + input=input, + workspace=workspace, + launch_with_pdl=launch_with_pdl, + output=output, + ) + return output + + elif pattern == AllReduceFusionPattern.kARResidualRMSNorm: + # AllReduce + Residual + RMSNorm fusion + # Validate required parameters + if residual_in is None: + raise ValueError("MNNVL AllReduce+RMS fusion requires residual_in") + if rms_gamma is None: + raise ValueError("MNNVL AllReduce+RMS fusion requires rms_gamma") + + # Allocate output tensors if not provided + if norm_out is None: + norm_out = torch.empty_like(input) + if residual_out is None: + residual_out = torch.empty_like(input) + + # Call the MNNVL fusion function + norm_result, residual_result = trtllm_mnnvl_fused_allreduce_add_rmsnorm( + input=input, + residual_in=residual_in, + gamma=rms_gamma, + workspace=workspace, + epsilon=rms_eps, + output=norm_out, + residual_out=residual_out, + launch_with_pdl=launch_with_pdl, + ) + return norm_result + + else: + raise ValueError(f"Unsupported pattern for MNNVL backend: {pattern}") + + else: + raise TypeError( + f"Unknown workspace type: {type(workspace)}. " + f"Expected TRTLLMAllReduceFusionWorkspace or MNNVLAllReduceFusionWorkspace" + ) diff --git a/flashinfer/comm/trtllm_ar.py b/flashinfer/comm/trtllm_ar.py index 33bb7ac97b..85e953c766 100644 --- a/flashinfer/comm/trtllm_ar.py +++ b/flashinfer/comm/trtllm_ar.py @@ -123,7 +123,7 @@ def trtllm_lamport_initialize_all( ) @deprecated( - "trtllm_create_ipc_workspace_for_all_reduce and trtllm_custom_all_reduce are deprecated, use trtllm_create_ipc_workspace_for_all_reduce_fusion and trtllm_allreduce_fusion instead" + "trtllm_create_ipc_workspace_for_all_reduce and trtllm_custom_all_reduce are deprecated and will be removed in the next major bump, use allreduce.py instead." ) @register_custom_op( "flashinfer::trtllm_custom_all_reduce", @@ -398,7 +398,7 @@ def trtllm_moe_finalize_allreduce_fusion( @deprecated( - "trtllm_create_ipc_workspace_for_all_reduce and trtllm_custom_all_reduce are deprecated, use trtllm_create_ipc_workspace_for_all_reduce_fusion and trtllm_allreduce_fusion instead" + "trtllm_create_ipc_workspace_for_all_reduce and trtllm_custom_all_reduce are deprecated and will be removed in the next major bump, use allreduce.py instead." ) def trtllm_create_ipc_workspace_for_all_reduce( rank: int, @@ -500,7 +500,9 @@ def trtllm_destroy_ipc_workspace_for_all_reduce( MAX_COMM_SIZE = 2147483647 & ~((1 << 21) - 1) # MAX_INT32 rounded down to 2MB -# @TODO(nvmbreughe): on a next major bump, remove create_metadata and make create_metadata=True the default behavior +@deprecated( + "use the unified API allreduce.py instead. It will internally call trtllm_create_ipc_workspace_for_all_reduce_fusion." 
+) def trtllm_create_ipc_workspace_for_all_reduce_fusion( tp_rank: int, tp_size: int, @@ -804,6 +806,54 @@ def _should_use_oneshot( return comm_size_mb <= _use_oneshot_heuristics[world_size] +def check_trtllm_allreduce_fusion_workspace_metadata( + token_num: int, + hidden_dim: int, + world_size: int, + dtype: torch.dtype, + metadata: dict, +) -> None: + errors = [] + required_keys = ["max_token_num", "tp_size", "hidden_dim", "use_fp32_lamport"] + for key in required_keys: + if key not in metadata: + errors.append(f"Workspace metadata is missing required key: {key}") + if errors: + error_msg = "Workspace metadata validation failed:\n" + "\n".join( + f" - {e}" for e in errors + ) + raise ValueError(error_msg) + + # world_size must match tp_size (flag size depends on it) + if world_size != metadata["tp_size"]: + errors.append( + f"world_size ({world_size}) does not match workspace tp_size ({metadata['tp_size']}). " + f"Workspace was created for tp_size={metadata['tp_size']}." + ) + + # token_num * hidden_dim must not exceed max_token_num * hidden_dim + if token_num * hidden_dim > metadata["max_token_num"] * metadata["hidden_dim"]: + errors.append( + f"token_num ({token_num}) * hidden_dim ({hidden_dim}) exceeds workspace max_token_num ({metadata['max_token_num']}) * hidden_dim ({metadata['hidden_dim']}). " + f"This may cause Illegal Memory Access." + ) + + # use_fp32_lamport must match + if metadata["use_fp32_lamport"] != (dtype == torch.float32): + errors.append( + f"use_fp32_lamport ({metadata['use_fp32_lamport']}) does not match allreduce_in.dtype ({dtype}). " + f"Workspace was created for use_fp32_lamport={metadata['use_fp32_lamport']}." + ) + if errors: + error_msg = "Workspace validation failed:\n" + "\n".join( + f" - {e}" for e in errors + ) + raise ValueError(error_msg) + + +@deprecated( + "use the unified API allreduce.py instead. It will internally call trtllm_allreduce_fusion." +) def trtllm_allreduce_fusion( allreduce_in: torch.Tensor, world_size: int, @@ -858,50 +908,9 @@ def trtllm_allreduce_fusion( # Validate against workspace metadata if provided if metadata is not None: - errors = [] - required_keys = ["max_token_num", "tp_size", "hidden_dim", "use_fp32_lamport"] - for key in required_keys: - if key not in metadata: - errors.append(f"Workspace metadata is missing required key: {key}") - if errors: - error_msg = "Workspace metadata validation failed:\n" + "\n".join( - f" - {e}" for e in errors - ) - raise ValueError(error_msg) - - # Check 1: token_num must not exceed max_token_num - if token_num > metadata["max_token_num"]: - errors.append( - f"token_num ({token_num}) exceeds workspace max_token_num ({metadata['max_token_num']}). " - f"This may cause Illegal Memory Access." - ) - - # Check 2: world_size must match tp_size - if world_size != metadata["tp_size"]: - errors.append( - f"world_size ({world_size}) does not match workspace tp_size ({metadata['tp_size']}). " - f"Workspace was created for tp_size={metadata['tp_size']}." - ) - - # Check 3: hidden_dim must match - if hidden_dim != metadata["hidden_dim"]: - errors.append( - f"hidden_dim ({hidden_dim}) does not match workspace hidden_dim ({metadata['hidden_dim']}). " - f"Workspace was created for hidden_dim={metadata['hidden_dim']}." - ) - - # Check 4: use_fp32_lamport must match - if metadata["use_fp32_lamport"] != (allreduce_in.dtype == torch.float32): - errors.append( - f"use_fp32_lamport ({metadata['use_fp32_lamport']}) does not match allreduce_in.dtype ({allreduce_in.dtype}). 
" - f"Workspace was created for use_fp32_lamport={metadata['use_fp32_lamport']}." - ) - - if errors: - error_msg = "Workspace validation failed:\n" + "\n".join( - f" - {e}" for e in errors - ) - raise ValueError(error_msg) + check_trtllm_allreduce_fusion_workspace_metadata( + token_num, hidden_dim, world_size, allreduce_in.dtype, metadata + ) if use_oneshot is None: use_oneshot = _should_use_oneshot( diff --git a/flashinfer/comm/trtllm_mnnvl_ar.py b/flashinfer/comm/trtllm_mnnvl_ar.py index 308ee4bda6..dfcb8317e5 100644 --- a/flashinfer/comm/trtllm_mnnvl_ar.py +++ b/flashinfer/comm/trtllm_mnnvl_ar.py @@ -18,6 +18,7 @@ from ..jit import gen_trtllm_mnnvl_comm_module from ..utils import register_custom_op from .mnnvl import McastGPUBuffer, CommBackend, MPIBackend +from .workspace_base import AllReduceFusionWorkspace def mpi_barrier(): @@ -47,7 +48,7 @@ def select_strategy( MNNVL_ONE_SHOT_THRESHOLD = 64 * 1024 * 8 * 2 -class MNNVLAllreduceFusionWorkspace: +class MNNVLAllReduceFusionWorkspace(AllReduceFusionWorkspace): NUM_LAMPORT_BUFFERS = 3 def __init__( @@ -75,6 +76,7 @@ def __init__( dtype: The data type of the tensors to be reduced. buffer_size_in_bytes: The requested size in bytes for each lamport buffer. The actual allocation size may be larger due to alignment requirements. The actual usable size will be NUM_LAMPORT_BUFFERS * actual_buffer_size_per_lamport_buffer. """ + super().__init__(mapping.world_size, mapping.rank) if buffer_size_in_bytes is None: assert ( @@ -222,6 +224,22 @@ def get_required_buffer_size_bytes( ) return buffer_size + @property + def backend(self) -> str: + return "mnnvl" + + def destroy(self) -> None: + """Destroy workspace and free resources.""" + if getattr(self, "_destroyed", False): + return # Already destroyed, nothing to do + + del self.mcast_buffer_handle + del self.buffer_flags + del self.uc_ptrs_dev + del self.uc_ptr_local + del self.mc_ptr + self._destroyed = True + @functools.cache def get_trtllm_mnnvl_comm_module(): @@ -307,7 +325,7 @@ def trtllm_mnnvl_allreduce_fusion( def trtllm_mnnvl_allreduce( input: torch.Tensor, - workspace: MNNVLAllreduceFusionWorkspace, + workspace: MNNVLAllReduceFusionWorkspace, launch_with_pdl: bool, output: Optional[torch.Tensor] = None, strategy: MNNVLAllreduceFusionStrategy = MNNVLAllreduceFusionStrategy.AUTO, @@ -327,7 +345,7 @@ def trtllm_mnnvl_allreduce( Args: input: Local Input Shard [num_tokens, hidden_dim] - workspace: MNNVLAllreduceFusionWorkspace + workspace: MNNVLAllReduceFusionWorkspace launch_with_pdl: Whether to launch with PDL output: Output tensor to store the result, empty tensor will be created if not provided. strategy: MNNVLAllreduceFusionStrategy. Internal heuristics will be used if not provided. @@ -387,7 +405,7 @@ def trtllm_mnnvl_fused_allreduce_add_rmsnorm( input: torch.Tensor, residual_in: torch.Tensor, gamma: torch.Tensor, - workspace: MNNVLAllreduceFusionWorkspace, + workspace: MNNVLAllReduceFusionWorkspace, epsilon: Optional[float] = None, output: Optional[torch.Tensor] = None, residual_out: Optional[torch.Tensor] = None, @@ -404,7 +422,7 @@ def trtllm_mnnvl_fused_allreduce_add_rmsnorm( input: Input tensor [num_tokens, hidden_dim] residual_in: Residual input tensor [num_tokens, hidden_dim] gamma: Gamma tensor [hidden_dim] - workspace: MNNVLAllreduceFusionWorkspace + workspace: MNNVLAllReduceFusionWorkspace epsilon: The epsilon parameter for RMSNorm, torch.finfo.eps will be used if not provided. 
output: Output tensor for normalized results [num_tokens, hidden_dim], empty tensor will be created if not provided. residual_out: Residual output tensor [num_tokens, hidden_dim], empty tensor will be created if not provided. @@ -479,7 +497,7 @@ def trtllm_mnnvl_fused_allreduce_add_rmsnorm( # Legacy API that has been deprecated; Left for backward compatibility @deprecated( - "get_allreduce_mnnvl_workspace is deprecated, use MNNVLAllreduceFusionWorkspace class to manage the workspace instead" + "get_allreduce_mnnvl_workspace is deprecated, use MNNVLAllReduceFusionWorkspace class to manage the workspace instead" ) def get_allreduce_mnnvl_workspace( mapping: Mapping, @@ -522,7 +540,7 @@ def get_allreduce_mnnvl_workspace( ) * (lcm_hidden_dim * stride) # Redirect to the new workspace allocation logic. The new kernel needs the new flag buffer layout. - workspace = MNNVLAllreduceFusionWorkspace( + workspace = MNNVLAllReduceFusionWorkspace( mapping, buffer_size_in_bytes=buffer_size_in_bytes, comm_backend=comm_backend_for_handle_transfer, diff --git a/flashinfer/comm/workspace_base.py b/flashinfer/comm/workspace_base.py new file mode 100644 index 0000000000..5de8d07483 --- /dev/null +++ b/flashinfer/comm/workspace_base.py @@ -0,0 +1,89 @@ +""" +Copyright (c) 2025 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from abc import ABC, abstractmethod +from typing import Optional, Any + +import torch + + +class AllReduceFusionWorkspace(ABC): + """Base class for AllReduce fusion workspaces.""" + + # Explicit type annotations for mypy (needed due to __getattr__ in subclasses) + world_size: int + rank: int + _destroyed: bool + + def __init__(self, world_size: int, rank: int): + self.world_size = world_size + self.rank = rank + self._destroyed = False + + @property + @abstractmethod + def backend(self) -> str: + """Return backend name.""" + pass + + @abstractmethod + def destroy(self) -> None: + """ + Destroy workspace and free resources. + + This should be called explicitly when done using the workspace. + Prefer using AllReduceFusionContext context manager for automatic cleanup. + """ + pass + + @abstractmethod + def is_buffer_size_sufficient( + self, + tp_size: int, + num_tokens: int, + hidden_dim: int, + dtype: torch.dtype, + use_oneshot: Optional[Any] = None, + ) -> bool: + pass + + def __del__(self): + """ + Destructor - safety net if destroy() wasn't called explicitly. + + Warns if cleanup wasn't done properly. Not recommended to rely on this + as __del__ timing is non-deterministic and can cause issues with + distributed/CUDA resources. + """ + if not self._destroyed: + import warnings + + warnings.warn( + f"{self.__class__.__name__} was not explicitly destroyed. 
" + f"Call workspace.destroy() or use AllReduceFusionContext to ensure " + f"proper cleanup of distributed/CUDA resources.", + ResourceWarning, + stacklevel=2, + ) + try: + self.destroy() + except Exception as e: + # Can't raise in __del__, just warn + warnings.warn( + f"Error during automatic cleanup of {self.__class__.__name__}: {e}", + ResourceWarning, + stacklevel=2, + ) diff --git a/tests/comm/test_allreduce_negative.py b/tests/comm/test_allreduce_negative.py new file mode 100644 index 0000000000..ff518893e2 --- /dev/null +++ b/tests/comm/test_allreduce_negative.py @@ -0,0 +1,272 @@ +# Negative tests for unified AllReduce API +# Run with: mpirun -np pytest tests/comm/test_allreduce_negative.py -vv -s + +import pytest +import torch +import torch.distributed as dist + +import flashinfer.comm.trtllm_mnnvl_ar as trtllm_mnnvl_ar + +from flashinfer.comm import ( + create_allreduce_fusion_workspace, + allreduce_fusion, + AllReduceFusionPattern, + QuantizationSFLayout, +) + +# Test helpers +from tests.test_helpers.comm import ( + setup_mpi_and_cuda, + init_torch_distributed_from_mpi, + cleanup_torch_distributed, +) + + +class TestMNNVLUnsupportedPatterns: + """Test that MNNVL backend properly rejects unsupported fusion patterns.""" + + @pytest.fixture(autouse=True) + def setup(self): + """Setup workspace for each test.""" + self.rank, self.world_size, self.gpus_per_node = setup_mpi_and_cuda() + + # Create MNNVL workspace + self.workspace = create_allreduce_fusion_workspace( + backend="mnnvl", + world_size=self.world_size, + rank=self.rank, + max_token_num=128, + hidden_dim=2880, + dtype=torch.float16, + topology="single_node", + gpus_per_node=self.gpus_per_node, + ) + + yield + + # Cleanup + if self.workspace is not None: + self.workspace.destroy() + trtllm_mnnvl_ar.mpi_barrier() + + def _create_test_tensors(self, seq_len: int = 16, hidden_dim: int = 2880): + """Create test tensors for allreduce operations.""" + input_tensor = torch.randn( + seq_len, hidden_dim, dtype=torch.float16, device="cuda" + ) + residual = torch.randn(seq_len, hidden_dim, dtype=torch.float16, device="cuda") + rms_gamma = torch.randn(hidden_dim, dtype=torch.float16, device="cuda") + return input_tensor, residual, rms_gamma + + @pytest.mark.parametrize( + "pattern", + [ + AllReduceFusionPattern.kARResidualRMSNormFP8Quant, + AllReduceFusionPattern.kARResidualRMSNormFP4Quant, + AllReduceFusionPattern.kARResidualRMSNormOutFP8Quant, + AllReduceFusionPattern.kARResidualRMSNormOutFP4Quant, + ], + ) + def test_unsupported_quantization_patterns(self, pattern): + """Test that MNNVL rejects quantization fusion patterns.""" + input_tensor, residual, rms_gamma = self._create_test_tensors() + + with pytest.raises(ValueError, match="does not support pattern"): + allreduce_fusion( + input=input_tensor, + workspace=self.workspace, + pattern=pattern, + launch_with_pdl=True, + residual_in=residual, + rms_gamma=rms_gamma, + ) + + @pytest.mark.parametrize( + "layout_code", + [ + QuantizationSFLayout.LINEAR, + QuantizationSFLayout.SWIZZLED_128x4, + QuantizationSFLayout.SWIZZLED_8x4, + ], + ) + def test_layout_code_not_supported(self, layout_code): + """Test that MNNVL rejects any layout_code specification.""" + input_tensor, residual, rms_gamma = self._create_test_tensors() + + # Test with kAllReduce pattern + with pytest.raises(ValueError, match="does not support quantization fusion"): + allreduce_fusion( + input=input_tensor, + workspace=self.workspace, + pattern=AllReduceFusionPattern.kAllReduce, + launch_with_pdl=True, + 
layout_code=layout_code, + ) + + # Test with kARResidualRMSNorm pattern + with pytest.raises(ValueError, match="does not support quantization fusion"): + allreduce_fusion( + input=input_tensor, + workspace=self.workspace, + pattern=AllReduceFusionPattern.kARResidualRMSNorm, + launch_with_pdl=True, + residual_in=residual, + rms_gamma=rms_gamma, + layout_code=layout_code, + ) + + +class TestMNNVLMissingRequiredParameters: + """Test that MNNVL backend properly validates required parameters.""" + + @pytest.fixture(autouse=True) + def setup(self): + """Setup workspace for each test.""" + self.rank, self.world_size, self.gpus_per_node = setup_mpi_and_cuda() + + # Create MNNVL workspace + self.workspace = create_allreduce_fusion_workspace( + backend="mnnvl", + world_size=self.world_size, + rank=self.rank, + max_token_num=128, + hidden_dim=2880, + dtype=torch.float16, + topology="single_node", + gpus_per_node=self.gpus_per_node, + ) + + yield + + # Cleanup + if self.workspace is not None: + self.workspace.destroy() + trtllm_mnnvl_ar.mpi_barrier() + + def test_rmsnorm_missing_residual_in(self): + """Test that kARResidualRMSNorm requires residual_in.""" + input_tensor = torch.randn(16, 2880, dtype=torch.float16, device="cuda") + rms_gamma = torch.randn(2880, dtype=torch.float16, device="cuda") + + with pytest.raises(ValueError, match="requires residual_in"): + allreduce_fusion( + input=input_tensor, + workspace=self.workspace, + pattern=AllReduceFusionPattern.kARResidualRMSNorm, + launch_with_pdl=True, + rms_gamma=rms_gamma, + # residual_in is missing + ) + + def test_rmsnorm_missing_rms_gamma(self): + """Test that kARResidualRMSNorm requires rms_gamma.""" + input_tensor = torch.randn(16, 2880, dtype=torch.float16, device="cuda") + residual = torch.randn(16, 2880, dtype=torch.float16, device="cuda") + + with pytest.raises(ValueError, match="requires rms_gamma"): + allreduce_fusion( + input=input_tensor, + workspace=self.workspace, + pattern=AllReduceFusionPattern.kARResidualRMSNorm, + launch_with_pdl=True, + residual_in=residual, + # rms_gamma is missing + ) + + +@pytest.mark.parametrize("backend", ["mnnvl", "trtllm"]) +class TestBufferSizeSufficient: + """Test is_buffer_size_sufficient method for different backends.""" + + @pytest.fixture(autouse=True) + def setup(self, backend): + """Setup workspace with small buffer for testing.""" + self.backend = backend + self.rank, self.world_size, self.gpus_per_node = setup_mpi_and_cuda() + + # Initialize torch.distributed for trtllm backend + self.process_group = None + if backend == "trtllm": + init_torch_distributed_from_mpi() + self.process_group = dist.group.WORLD + + # Create workspace with small max_token_num to test buffer limits + self.max_token_num = 64 + self.hidden_dim = 2880 + self.dtype = torch.float16 + + self.workspace = create_allreduce_fusion_workspace( + backend=backend, + world_size=self.world_size, + rank=self.rank, + max_token_num=self.max_token_num, + hidden_dim=self.hidden_dim, + dtype=self.dtype, + topology="single_node", + gpus_per_node=self.gpus_per_node, + process_group=self.process_group, + ) + + yield + + # Cleanup + if self.workspace is not None: + self.workspace.destroy() + if backend == "trtllm": + cleanup_torch_distributed() + trtllm_mnnvl_ar.mpi_barrier() + + def test_buffer_sufficient_for_smaller_size(self, backend): + """Test that is_buffer_size_sufficient returns True for sizes within capacity.""" + # Use smaller size than max_token_num + result = self.workspace.is_buffer_size_sufficient( + tp_size=self.world_size, + 
num_tokens=self.max_token_num // 2, + hidden_dim=self.hidden_dim, + dtype=self.dtype, + ) + assert result is True, ( + f"[{backend}] Buffer should be sufficient for smaller token count" + ) + + def test_buffer_sufficient_for_exact_size(self, backend): + """Test that is_buffer_size_sufficient returns True for exact capacity.""" + result = self.workspace.is_buffer_size_sufficient( + tp_size=self.world_size, + num_tokens=self.max_token_num, + hidden_dim=self.hidden_dim, + dtype=self.dtype, + ) + assert result is True, ( + f"[{backend}] Buffer should be sufficient for exact max token count" + ) + + def test_buffer_insufficient_for_larger_size(self, backend): + """Test that is_buffer_size_sufficient returns False for sizes exceeding capacity.""" + # Calculate the actual buffer capacity and use a size that definitely exceeds it + elem_size = torch.tensor([], dtype=self.dtype).element_size() + + if backend == "mnnvl": + # For MNNVL two-shot: buffer_size >= 2 * ceil(num_tokens/tp_size) * tp_size * hidden_dim * elem_size + max_tokens_in_buffer = self.workspace.buffer_size_bytes // ( + 2 * self.hidden_dim * elem_size + ) + else: + # For TRTLLM: use metadata to determine max capacity + max_tokens_in_buffer = ( + self.workspace.metadata["max_token_num"] + * self.workspace.metadata["hidden_dim"] + ) // self.hidden_dim + + large_num_tokens = max_tokens_in_buffer * 10 # Use 10x the capacity + + result = self.workspace.is_buffer_size_sufficient( + tp_size=self.world_size, + num_tokens=large_num_tokens, + hidden_dim=self.hidden_dim, + dtype=self.dtype, + ) + assert result is False, ( + f"[{backend}] Buffer should be insufficient for {large_num_tokens} tokens " + f"(buffer can hold ~{max_tokens_in_buffer})" + ) diff --git a/tests/comm/test_allreduce_unified_api.py b/tests/comm/test_allreduce_unified_api.py new file mode 100644 index 0000000000..732a3ddb92 --- /dev/null +++ b/tests/comm/test_allreduce_unified_api.py @@ -0,0 +1,311 @@ +# Test for unified AllReduce API with multiple backends +# Run with: mpirun -np pytest tests/comm/test_allreduce_unified_api.py -vv -s +import traceback +from typing import Tuple + +import pytest +import torch +import torch.distributed as dist +from mpi4py import MPI + +import flashinfer.comm.trtllm_mnnvl_ar as trtllm_mnnvl_ar + +# Unified API imports +from flashinfer.comm import ( + create_allreduce_fusion_workspace, + allreduce_fusion, + AllReduceFusionPattern, + AllReduceFusionWorkspace, +) + +# Use flashinfer.norm.rmsnorm as reference implementation. 
+from flashinfer.norm import rmsnorm + +# Test helpers +from tests.test_helpers.comm import ( + init_torch_distributed_from_mpi, + cleanup_torch_distributed, +) + + +@torch.inference_mode() +def run_allreduce_fusion_test( + x: torch.Tensor, + residual: torch.Tensor, + norm_weight: torch.Tensor, + eps: float, + rank: int, + fusion: bool, + reference_output: tuple[torch.Tensor, ...], + workspace: AllReduceFusionWorkspace, +): + """Test function using the unified API (create_allreduce_fusion_workspace + allreduce_fusion).""" + MPI.COMM_WORLD.barrier() + + def func( + input, + residual, + norm_weight, + eps, + enable_fusion, + workspace, + ): + # For both fused and unfused cases: + shape = input.shape + input = input.view(-1, shape[-1]) + use_pdl = True + + if enable_fusion: + trtllm_mnnvl_ar.mpi_barrier() + + # Use unified API + norm_out = torch.empty_like(input) + residual_out = torch.empty_like(input) + + allreduce_fusion( + input=input, + workspace=workspace, + pattern=AllReduceFusionPattern.kARResidualRMSNorm, + launch_with_pdl=use_pdl, + residual_out=residual_out, + norm_out=norm_out, + residual_in=residual.view(-1, shape[-1]), + rms_gamma=norm_weight, + rms_eps=eps, + ) + + return norm_out.view(shape), residual_out.view(shape) + + else: + # Use unified API for AllReduce only + output = torch.empty_like(input) + + allreduce_fusion( + input=input, + workspace=workspace, + pattern=AllReduceFusionPattern.kAllReduce, + launch_with_pdl=use_pdl, + output=output, + ) + return (output.view(shape),) + + output = func(x.clone(), residual.clone(), norm_weight, eps, fusion, workspace) + + assert output[0].shape == reference_output[0].shape + + if rank == 0: + print("output[0] (first 10 values):", output[0].flatten()[:10]) + print( + "reference_output[0] (first 10 values):", + reference_output[0].flatten()[:10], + ) + + if fusion: + print("output[1] (first 10 values):", output[1].flatten()[:10]) + print( + "reference_output[1] (first 10 values):", + reference_output[1].flatten()[:10], + ) + + torch.testing.assert_close( + output[0], + reference_output[0], + rtol=0.05, + atol=0.15, + ) + + if fusion: + torch.testing.assert_close( + output[1], + reference_output[1], + rtol=0.05, + atol=0.15, + ) + + +def prepare_test_data(seq_len: int, hidden_size: int, dtype: torch.dtype, fusion: bool): + """Prepare test data distributed across MPI ranks.""" + comm = MPI.COMM_WORLD + rank = comm.Get_rank() + world_size = comm.Get_size() + if rank == 0: + x_full = torch.randn((world_size, seq_len, hidden_size), dtype=dtype) + residual = torch.randn((seq_len, hidden_size), dtype=dtype) + norm_weight = torch.randn((hidden_size,), dtype=dtype) + else: + x_full = None + residual = None + norm_weight = None + + # Use lowercase bcast() for Python object broadcasting + x_full = comm.bcast(x_full, root=0) + residual = comm.bcast(residual, root=0) + norm_weight = comm.bcast(norm_weight, root=0) + + x_full = x_full.cuda() + residual = residual.cuda() + norm_weight = norm_weight.cuda() + + x_local = x_full[rank, :, :] + reference_output: Tuple[torch.Tensor, ...] 
= None + if fusion: + # Fused case: AllReduce + Residual Add + RMS Norm + allreduce_result = torch.sum(x_full, dim=0) # AllReduce result + residual_out = allreduce_result + residual # Add residual + norm_out = rmsnorm( + residual_out, norm_weight, torch.finfo(dtype).eps, enable_pdl=False + ) + + reference_output = (norm_out, residual_out) + else: + # Non-fused case: Only AllReduce + allreduce_result = torch.sum(x_full, dim=0) # AllReduce result + reference_output = (allreduce_result,) + return (x_local, residual, norm_weight), reference_output + + +def run_allreduce_test( + monkeypatch, + seq_lens: list[int], + fusion: bool, + dtype: torch.dtype, + hidden_size: int, + backend: str, +): + """Core test logic for AllReduce operations using the unified API. + + Args: + monkeypatch: pytest monkeypatch fixture + seq_lens: List of sequence lengths to test + fusion: Whether to test fused allreduce+rmsnorm or just allreduce + dtype: Data type for tensors + hidden_size: Hidden dimension size + backend: Backend to use ("auto", "trtllm", "mnnvl") + """ + + comm = MPI.COMM_WORLD + rank = comm.Get_rank() + world_size = comm.Get_size() + gpus_per_node = torch.cuda.device_count() + + if gpus_per_node == 0: + pytest.skip("AllReduce test requires at least one CUDA device per node") + if world_size < 2: + pytest.skip(f"This test requires at least 2 MPI ranks, got {world_size}") + + # Set CUDA device based on rank + local_rank = rank % gpus_per_node + torch.cuda.set_device(local_rank) + + # Initialize torch.distributed for trtllm backend (needed for IPC workspace) + # TODO: check if it is ok to do this with auto backend + process_group = None + if backend in ("trtllm", "auto"): + init_torch_distributed_from_mpi() + process_group = dist.group.WORLD + + if local_rank == 0: + print(f"Running AllReduce test with {world_size} ranks, backend={backend}") + print(f"Rank {rank} using GPU {torch.cuda.current_device()}") + + eps = 1e-5 + torch.manual_seed(42 + rank) + + workspace = None + + try: + # Create workspace using unified API + workspace = create_allreduce_fusion_workspace( + backend=backend, + world_size=world_size, + rank=rank, + max_token_num=max(seq_lens), + hidden_dim=hidden_size, + dtype=dtype, + topology="single_node", + gpus_per_node=gpus_per_node, + process_group=process_group, + ) + + print(f"Rank {rank}: Created workspace with backend={workspace.backend}") + + # Prepare test data for all sequence lengths + test_data = [] + for seq_len in seq_lens: + (x_local, residual, norm_weight), reference_output = prepare_test_data( + seq_len, hidden_size, dtype, fusion + ) + test_data.append( + (seq_len, x_local, residual, norm_weight, reference_output) + ) + + # Test each sequence length with the same workspace + for seq_len, x, residual, norm_weight, reference_output in test_data: + if rank == 0: + print( + f"Testing seq_len={seq_len}, hidden_size={hidden_size}, fusion={fusion}, dtype={dtype}" + ) + + run_allreduce_fusion_test( + x, + residual, + norm_weight, + eps, + rank, + fusion, + reference_output, + workspace, + ) + + # Synchronize before next test + trtllm_mnnvl_ar.mpi_barrier() + + print( + f"PASSED[rank={rank}]: seq_len={seq_len}, fusion={fusion}, dtype={dtype}, backend={backend}" + ) + + except Exception as e: + failure_message = f"FAILED[rank={rank}]: seq_lens={seq_lens}, fusion={fusion}, dtype={dtype}, backend={backend} failed: {e}" + print(failure_message) + print(traceback.format_exc()) + + # Gather failure status from all ranks for logging + all_failures = MPI.COMM_WORLD.allgather(True) + 
failed_ranks = [i for i, failed in enumerate(all_failures) if failed] + if rank == 0: + print(f"Test failed on ranks: {failed_ranks}") + + raise + + finally: + if workspace is not None: + workspace.destroy() + # Cleanup torch.distributed if we initialized it + if backend in ("trtllm", "auto"): + cleanup_torch_distributed() + + # Final synchronization + trtllm_mnnvl_ar.mpi_barrier() + + +@pytest.mark.parametrize( + "seq_lens", + [[1], [4], [15], [27, 11, 24, 256], [127], [998, 2048]], +) +@pytest.mark.parametrize("fusion", [False, True]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("hidden_size", [2880, 7168]) +@pytest.mark.parametrize("backend", ["auto", "trtllm", "mnnvl"]) +def test_allreduce_unified( + monkeypatch, + seq_lens: list[int], + fusion: bool, + dtype: torch.dtype, + hidden_size: int, + backend: str, +): + """Test AllReduce with unified API across different backends. + + Run with: mpirun -np pytest tests/comm/test_allreduce_unified_api.py -vv -s + """ + run_allreduce_test(monkeypatch, seq_lens, fusion, dtype, hidden_size, backend) diff --git a/tests/comm/test_trtllm_allreduce_fusion.py b/tests/comm/test_trtllm_allreduce_fusion.py index c3aa8c8252..dab4877fb9 100644 --- a/tests/comm/test_trtllm_allreduce_fusion.py +++ b/tests/comm/test_trtllm_allreduce_fusion.py @@ -22,7 +22,9 @@ SCALE_FACTOR_RANGE = (-1, 1) -def _run_correctness_worker(world_size, rank, dtype, hidden_dim, distributed_init_port): +def _run_correctness_worker( + world_size, rank, dtype, hidden_dim, distributed_init_port, legacy_api=True +): device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) distributed_init_method = f"tcp://localhost:{distributed_init_port}" @@ -57,18 +59,33 @@ def _run_correctness_worker(world_size, rank, dtype, hidden_dim, distributed_ini lamport_use_fp32 = dtype == torch.float32 - # create workspace for allreduce fusion with metadata - ipc_handles, workspace_tensor, workspace_metadata = ( - comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( - rank, - world_size, - MAX_TOKEN_NUM, - hidden_dim, - group=group, - use_fp32_lamport=lamport_use_fp32, - create_metadata=True, # Get metadata for validation + # Create workspace - choose between legacy and new API + if legacy_api: + # Legacy API: create workspace for allreduce fusion with metadata + ipc_handles, workspace_tensor, workspace_metadata = ( + comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( + rank, + world_size, + MAX_TOKEN_NUM, + hidden_dim, + group=group, + use_fp32_lamport=lamport_use_fp32, + create_metadata=True, # Get metadata for validation + ) + ) + else: + workspace = None + # New unified API: create workspace + workspace = comm.create_allreduce_fusion_workspace( + backend="trtllm", + world_size=world_size, + rank=rank, + max_token_num=MAX_TOKEN_NUM, + hidden_dim=hidden_dim, + dtype=dtype, + topology="single_node", + process_group=group, ) - ) test_loop = 5 @@ -163,60 +180,128 @@ def _run_correctness_worker(world_size, rank, dtype, hidden_dim, distributed_ini s.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(s): for _ in range(test_loop): - comm.trtllm_allreduce_fusion( - allreduce_in=allreduce_in, - world_size=world_size, - world_rank=rank, - token_num=token_num, - hidden_dim=hidden_dim, - workspace_ptrs=workspace_tensor, - launch_with_pdl=launch_with_pdl, - use_oneshot=use_oneshot, - trigger_completion_at_end=trigger_completion_at_end, - fp32_acc=fp32_acc, - pattern_code=pattern_code, - allreduce_out=all_reduce_out, - 
residual_in=residual_in, - residual_out=residual_out, - norm_out=norm_out, - quant_out=quant_out, - scale_out=scale_out, - rms_gamma=rms_gamma, - rms_eps=rms_eps, - scale_factor=scale_factor, - layout_code=swizzled_layout_code, - metadata=workspace_metadata, - ) + if legacy_api: + # Legacy API - uses flattened tensors + comm.trtllm_allreduce_fusion( + allreduce_in=allreduce_in, + world_size=world_size, + world_rank=rank, + token_num=token_num, + hidden_dim=hidden_dim, + workspace_ptrs=workspace_tensor, + launch_with_pdl=launch_with_pdl, + use_oneshot=use_oneshot, + trigger_completion_at_end=trigger_completion_at_end, + fp32_acc=fp32_acc, + pattern_code=pattern_code, + allreduce_out=all_reduce_out, + residual_in=residual_in, + residual_out=residual_out, + norm_out=norm_out, + quant_out=quant_out, + scale_out=scale_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + scale_factor=scale_factor, + layout_code=swizzled_layout_code, + metadata=workspace_metadata, + ) + else: + # New unified API - expects 2D tensors [token_num, hidden_dim] + comm.allreduce_fusion( + input=allreduce_in.view( + token_num, hidden_dim + ), + workspace=workspace, + launch_with_pdl=launch_with_pdl, + output=all_reduce_out.view( + token_num, hidden_dim + ), + residual_in=residual_in.view( + token_num, hidden_dim + ), + residual_out=residual_out.view( + token_num, hidden_dim + ), + norm_out=norm_out.view( + token_num, hidden_dim + ), + quant_out=quant_out.view( + token_num, hidden_dim + ), + scale_out=scale_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + scale_factor=scale_factor, + layout_code=swizzled_layout_code, + pattern=pattern_code, + use_oneshot=use_oneshot, + fp32_acc=fp32_acc, + ) # NOTE: in real case, you dont have to set all optional params. You could set those required by fusion pattern. 
# capture g = torch.cuda.CUDAGraph() with torch.cuda.graph(g): for _ in range(test_loop): - comm.trtllm_allreduce_fusion( - allreduce_in=allreduce_in, - world_size=world_size, - world_rank=rank, - token_num=token_num, - hidden_dim=hidden_dim, - workspace_ptrs=workspace_tensor, - launch_with_pdl=launch_with_pdl, - use_oneshot=use_oneshot, - trigger_completion_at_end=trigger_completion_at_end, - fp32_acc=fp32_acc, - pattern_code=pattern_code, - allreduce_out=all_reduce_out, - residual_in=residual_in, - residual_out=residual_out, - norm_out=norm_out, - quant_out=quant_out, - scale_out=scale_out, - rms_gamma=rms_gamma, - rms_eps=rms_eps, - scale_factor=scale_factor, - layout_code=swizzled_layout_code, - metadata=workspace_metadata, - ) + if legacy_api: + # Legacy API - uses flattened tensors + comm.trtllm_allreduce_fusion( + allreduce_in=allreduce_in, + world_size=world_size, + world_rank=rank, + token_num=token_num, + hidden_dim=hidden_dim, + workspace_ptrs=workspace_tensor, + launch_with_pdl=launch_with_pdl, + use_oneshot=use_oneshot, + trigger_completion_at_end=trigger_completion_at_end, + fp32_acc=fp32_acc, + pattern_code=pattern_code, + allreduce_out=all_reduce_out, + residual_in=residual_in, + residual_out=residual_out, + norm_out=norm_out, + quant_out=quant_out, + scale_out=scale_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + scale_factor=scale_factor, + layout_code=swizzled_layout_code, + metadata=workspace_metadata, + ) + else: + # New unified API - expects 2D tensors [token_num, hidden_dim] + comm.allreduce_fusion( + input=allreduce_in.view( + token_num, hidden_dim + ), + workspace=workspace, + launch_with_pdl=launch_with_pdl, + output=all_reduce_out.view( + token_num, hidden_dim + ), + residual_in=residual_in.view( + token_num, hidden_dim + ), + residual_out=residual_out.view( + token_num, hidden_dim + ), + norm_out=norm_out.view( + token_num, hidden_dim + ), + quant_out=quant_out.view( + token_num, hidden_dim + ), + scale_out=scale_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + scale_factor=scale_factor, + layout_code=swizzled_layout_code, + pattern=pattern_code, + use_oneshot=use_oneshot, + fp32_acc=fp32_acc, + ) # replay g.replay() torch.cuda.synchronize() @@ -307,9 +392,14 @@ def _run_correctness_worker(world_size, rank, dtype, hidden_dim, distributed_ini finally: dist.barrier(group=group) - comm.trtllm_destroy_ipc_workspace_for_all_reduce_fusion( - ipc_handles, group=group - ) + # Destroy workspace - choose between legacy and new API + if legacy_api: + comm.trtllm_destroy_ipc_workspace_for_all_reduce_fusion( + ipc_handles, group=group + ) + elif workspace is not None: + # New unified API + workspace.destroy() dist.destroy_process_group(group=group) @@ -358,7 +448,8 @@ def multi_process_parallel( @pytest.mark.parametrize("world_size", [2, 4, 8]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("hidden_dim", [1024, 2048, 4096, 7168, 8192]) -def test_trtllm_allreduce_fusion(world_size, dtype, hidden_dim): +@pytest.mark.parametrize("legacy_api", [True, False]) +def test_trtllm_allreduce_fusion(world_size, dtype, hidden_dim, legacy_api): np.random.seed(42) torch.manual_seed(42) torch.cuda.manual_seed_all(42) @@ -367,17 +458,22 @@ def test_trtllm_allreduce_fusion(world_size, dtype, hidden_dim): pytest.skip( f"world_size {world_size} is greater than available_gpus {available_gpus}" ) - print(f"Running test for world_size={world_size}") + api_str = "legacy" if legacy_api else "unified" + print(f"Running test for world_size={world_size} 
with {api_str} API") multi_process_parallel( world_size, dtype, hidden_dim, _run_correctness_worker, - target_args=(), + target_args=(legacy_api,), ) - print(f"allreduce fusion tp = {world_size}: OK") + print(f"allreduce fusion tp = {world_size} ({api_str} API): OK") if __name__ == "__main__": - test_trtllm_allreduce_fusion(2, torch.float16, 1024) + # Test both legacy and unified APIs + print("Testing legacy API...") + test_trtllm_allreduce_fusion(2, torch.float16, 1024, legacy_api=True) + print("\nTesting unified API...") + test_trtllm_allreduce_fusion(2, torch.float16, 1024, legacy_api=False) diff --git a/tests/comm/test_trtllm_mnnvl_allreduce.py b/tests/comm/test_trtllm_mnnvl_allreduce.py index 78ce392b7a..ce7880e406 100644 --- a/tests/comm/test_trtllm_mnnvl_allreduce.py +++ b/tests/comm/test_trtllm_mnnvl_allreduce.py @@ -22,7 +22,7 @@ def row_linear_residual_norm_fusion_forward( mapping: Mapping, fusion: bool, reference_output: tuple[torch.Tensor, ...], - workspace: trtllm_mnnvl_ar.MNNVLAllreduceFusionWorkspace, + workspace: trtllm_mnnvl_ar.MNNVLAllReduceFusionWorkspace, ): tensor_parallel_rank = mapping.tp_rank MPI.COMM_WORLD.barrier() @@ -341,7 +341,7 @@ def run_mnnvl_ar_full( ) else: - workspace = trtllm_mnnvl_ar.MNNVLAllreduceFusionWorkspace( + workspace = trtllm_mnnvl_ar.MNNVLAllReduceFusionWorkspace( mapping, max_num_tokens=max(seq_lens), hidden_dim=hidden_size, diff --git a/tests/test_helpers/comm.py b/tests/test_helpers/comm.py new file mode 100644 index 0000000000..fcd4e0a23b --- /dev/null +++ b/tests/test_helpers/comm.py @@ -0,0 +1,67 @@ +# Helper functions for communication tests +import os + +import pytest +import torch +import torch.distributed as dist +from mpi4py import MPI + + +def setup_mpi_and_cuda(): + """Setup MPI and CUDA device for tests. + + Returns: + tuple: (rank, world_size, gpus_per_node) + + Raises: + pytest.skip: If no CUDA devices or fewer than 2 MPI ranks + """ + comm = MPI.COMM_WORLD + rank = comm.Get_rank() + world_size = comm.Get_size() + gpus_per_node = torch.cuda.device_count() + + if gpus_per_node == 0: + pytest.skip("Tests require at least one CUDA device per node") + if world_size < 2: + pytest.skip(f"Tests require at least 2 MPI ranks, got {world_size}") + + local_rank = rank % gpus_per_node + torch.cuda.set_device(local_rank) + + return rank, world_size, gpus_per_node + + +def init_torch_distributed_from_mpi(): + """Initialize torch.distributed using MPI rank info. + + This allows running torch.distributed operations within an MPI context. + Safe to call multiple times - will skip if already initialized. + """ + comm = MPI.COMM_WORLD + rank = comm.Get_rank() + world_size = comm.Get_size() + + if dist.is_initialized(): + return + + # Set environment variables for torch.distributed + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "29500" + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + + dist.init_process_group( + backend="nccl", + rank=rank, + world_size=world_size, + ) + + +def cleanup_torch_distributed(): + """Cleanup torch.distributed if initialized. + + Safe to call even if torch.distributed was not initialized. + """ + if dist.is_initialized(): + dist.destroy_process_group()