Skip to content

Commit 10554e5

Browse files
committed
Removed backend decorator as it is not applicable to workspace creation
1 parent 686db76 commit 10554e5

File tree

2 files changed

+73
-46
lines changed

2 files changed

+73
-46
lines changed

flashinfer/comm/allreduce.py

Lines changed: 68 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@
5353

5454
import torch
5555

56-
from ..utils import backend_requirement, supported_compute_capability
5756
from .trtllm_ar import trtllm_allreduce_fusion
5857
from .trtllm_ar import trtllm_create_ipc_workspace_for_all_reduce_fusion
5958
from .trtllm_ar import trtllm_destroy_ipc_workspace_for_all_reduce_fusion
@@ -161,19 +160,18 @@ def is_buffer_size_sufficient(
161160

162161
def destroy(self) -> None:
163162
"""Destroy workspace and free resources."""
164-
if self._destroyed:
163+
if self._destroyed is True:
165164
return # Already destroyed, nothing to do
166165

167166
trtllm_destroy_ipc_workspace_for_all_reduce_fusion(self.ipc_handles)
168167
self._destroyed = True
169168

170169

171170
# ============================================================================
172-
# BACKEND CHECKS - Hard requirements for decorator
171+
# BACKEND CHECKS - Hard requirements for backend selection
173172
# ============================================================================
174173

175174

176-
@supported_compute_capability([80, 86, 89, 90, 100])
177175
def _trtllm_workspace_check(
178176
backend: str,
179177
world_size: int,
@@ -188,9 +186,8 @@ def _trtllm_workspace_check(
188186
Check if trtllm backend CAN be used for workspace creation.
189187
190188
Hard requirements:
191-
- SM80+ compute capability (checked by decorator)
192-
- Single-node topology
193-
- Module availability
189+
- Single-node topology (multi-node not supported)
190+
194191
"""
195192
# trtllm is optimized for single-node
196193
if topology == "multi_node":
@@ -199,7 +196,6 @@ def _trtllm_workspace_check(
199196
return True
200197

201198

202-
@supported_compute_capability([90, 100])
203199
def _mnnvl_workspace_check(
204200
backend: str,
205201
world_size: int,
@@ -213,20 +209,13 @@ def _mnnvl_workspace_check(
213209
"""
214210
Check if mnnvl backend CAN be used for workspace creation.
215211
216-
Hard requirements:
217-
- SM90+ compute capability (checked by decorator)
218-
- Multi-node topology
219-
- Module availability
220212
"""
221-
# MNNVL is designed for multi-node
222-
if topology == "single_node":
223-
return False
224213

225214
return True
226215

227216

228217
# ============================================================================
229-
# HEURISTIC - Performance-based selection for decorator
218+
# HEURISTIC - Performance-based backend selection
230219
# ============================================================================
231220

232221

@@ -239,6 +228,7 @@ def _workspace_creation_heuristic(
239228
hidden_dim: int,
240229
dtype: torch.dtype,
241230
topology: str,
231+
# TODO(nvmbreughe): Remove this
242232
**kwargs,
243233
) -> list[str]:
244234
"""
@@ -276,39 +266,33 @@ def _workspace_creation_heuristic(
276266
return ["mnnvl"]
277267

278268
# Single-node scenarios
279-
problem_size = max_token_num * hidden_dim
269+
return ["mnnvl"]
270+
# problem_size = max_token_num * hidden_dim
280271

281-
# Large problems (>4M elements): trtllm optimized for throughput
282-
if problem_size > 4 * 1024 * 1024:
283-
if "trtllm" in suitable_backends:
284-
return ["trtllm"]
272+
# # Large problems (>4M elements): trtllm optimized for throughput
273+
# if problem_size > 4 * 1024 * 1024:
274+
# if "trtllm" in suitable_backends:
275+
# return ["trtllm"]
285276

286-
# Small token counts (<128): trtllm one-shot has better latency
287-
if max_token_num < 128:
288-
if "trtllm" in suitable_backends:
289-
return ["trtllm"]
277+
# # Small token counts (<128): trtllm one-shot has better latency
278+
# if max_token_num < 128:
279+
# if "trtllm" in suitable_backends:
280+
# return ["trtllm"]
290281

291-
# Small world sizes (<=4): trtllm one-shot efficient
292-
if world_size <= 4:
293-
if "trtllm" in suitable_backends:
294-
return ["trtllm"]
282+
# # Small world sizes (<=4): trtllm one-shot efficient
283+
# if world_size <= 4:
284+
# if "trtllm" in suitable_backends:
285+
# return ["trtllm"]
295286

296-
# Default: return first available
297-
return [suitable_backends[0]]
287+
# # Default: return first available
288+
# return [suitable_backends[0]]
298289

299290

300291
# ============================================================================
301-
# WORKSPACE CREATION - Uses decorator for all validation
292+
# WORKSPACE CREATION
302293
# ============================================================================
303294

304295

305-
@backend_requirement(
306-
backend_checks={
307-
"trtllm": _trtllm_workspace_check,
308-
"mnnvl": _mnnvl_workspace_check,
309-
},
310-
heuristic_func=_workspace_creation_heuristic,
311-
)
312296
def create_allreduce_fusion_workspace(
313297
backend: Literal["trtllm", "mnnvl", "auto"] = "auto",
314298
world_size: int = None,
@@ -324,7 +308,7 @@ def create_allreduce_fusion_workspace(
324308
"""
325309
Create workspace for AllReduce fusion operations.
326310
327-
Backend selection (checks + heuristics) handled by @backend_requirement decorator.
311+
Backend selection uses topology-based checks and heuristics.
328312
329313
**Important: Workspace Reusability**
330314
The workspace is allocated based on the total size (max_token_num * hidden_dim * dtype_size).
@@ -393,13 +377,51 @@ def create_allreduce_fusion_workspace(
393377
... )
394378
>>> print(workspace.backend) # "mnnvl"
395379
"""
396-
# Decorator has validated backend - now create workspace
397-
# If backend="auto", decorator has selected the best one and stored it
398-
399-
# Get actual backend (decorator resolved "auto" to concrete backend)
380+
if gpus_per_node is None:
381+
gpus_per_node = min(torch.cuda.device_count(), world_size)
382+
# Determine the actual backend to use
400383
if backend == "auto":
401-
# Decorator stored the selected backend in suitable_auto_backends
402-
actual_backend = create_allreduce_fusion_workspace.suitable_auto_backends[0]
384+
# Find suitable backends based on topology (any CC check must happen at kernel runtime, since no tensors are available at this point)
385+
suitable_backends = []
386+
if _trtllm_workspace_check(
387+
backend=backend,
388+
world_size=world_size,
389+
rank=rank,
390+
max_token_num=max_token_num,
391+
hidden_dim=hidden_dim,
392+
dtype=dtype,
393+
topology=topology,
394+
):
395+
suitable_backends.append("trtllm")
396+
if _mnnvl_workspace_check(
397+
backend=backend,
398+
world_size=world_size,
399+
rank=rank,
400+
max_token_num=max_token_num,
401+
hidden_dim=hidden_dim,
402+
dtype=dtype,
403+
topology=topology,
404+
):
405+
suitable_backends.append("mnnvl")
406+
407+
if not suitable_backends:
408+
raise ValueError(
409+
f"No suitable backend found for topology={topology}. "
410+
f"trtllm requires single_node topology, mnnvl works with both."
411+
)
412+
413+
# Apply heuristic to select best backend
414+
selected = _workspace_creation_heuristic(
415+
suitable_backends=suitable_backends,
416+
backend=backend,
417+
world_size=world_size,
418+
rank=rank,
419+
max_token_num=max_token_num,
420+
hidden_dim=hidden_dim,
421+
dtype=dtype,
422+
topology=topology,
423+
)
424+
actual_backend = selected[0] if selected else suitable_backends[0]
403425
else:
404426
actual_backend = backend
405427

flashinfer/comm/workspace_base.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,11 @@
2323
class AllReduceFusionWorkspace(ABC):
2424
"""Base class for AllReduce fusion workspaces."""
2525

26+
# Explicit type annotations for mypy (needed due to __getattr__ in subclasses)
27+
world_size: int
28+
rank: int
29+
_destroyed: bool
30+
2631
def __init__(self, world_size: int, rank: int):
2732
self.world_size = world_size
2833
self.rank = rank

0 commit comments

Comments
 (0)