5353
5454import torch
5555
56- from ..utils import backend_requirement , supported_compute_capability
5756from .trtllm_ar import trtllm_allreduce_fusion
5857from .trtllm_ar import trtllm_create_ipc_workspace_for_all_reduce_fusion
5958from .trtllm_ar import trtllm_destroy_ipc_workspace_for_all_reduce_fusion
@@ -161,19 +160,18 @@ def is_buffer_size_sufficient(
161160
162161 def destroy (self ) -> None :
163162 """Destroy workspace and free resources."""
164- if self ._destroyed :
163+ if self ._destroyed is True :
165164 return # Already destroyed, nothing to do
166165
167166 trtllm_destroy_ipc_workspace_for_all_reduce_fusion (self .ipc_handles )
168167 self ._destroyed = True
169168
170169
171170# ============================================================================
172- # BACKEND CHECKS - Hard requirements for decorator
171+ # BACKEND CHECKS - Hard requirements for backend selection
173172# ============================================================================
174173
175174
176- @supported_compute_capability ([80 , 86 , 89 , 90 , 100 ])
177175def _trtllm_workspace_check (
178176 backend : str ,
179177 world_size : int ,
@@ -188,9 +186,8 @@ def _trtllm_workspace_check(
188186 Check if trtllm backend CAN be used for workspace creation.
189187
190188 Hard requirements:
191- - SM80+ compute capability (checked by decorator)
192- - Single-node topology
193- - Module availability
189+ - Single-node topology (multi-node not supported)
190+
194191 """
195192 # trtllm is optimized for single-node
196193 if topology == "multi_node" :
@@ -199,7 +196,6 @@ def _trtllm_workspace_check(
199196 return True
200197
201198
202- @supported_compute_capability ([90 , 100 ])
203199def _mnnvl_workspace_check (
204200 backend : str ,
205201 world_size : int ,
@@ -213,20 +209,13 @@ def _mnnvl_workspace_check(
213209 """
214210 Check if mnnvl backend CAN be used for workspace creation.
215211
216- Hard requirements:
217- - SM90+ compute capability (checked by decorator)
218- - Multi-node topology
219- - Module availability
220212 """
221- # MNNVL is designed for multi-node
222- if topology == "single_node" :
223- return False
224213
225214 return True
226215
227216
228217# ============================================================================
229- # HEURISTIC - Performance-based selection for decorator
218+ # HEURISTIC - Performance-based backend selection
230219# ============================================================================
231220
232221
@@ -239,6 +228,7 @@ def _workspace_creation_heuristic(
239228 hidden_dim : int ,
240229 dtype : torch .dtype ,
241230 topology : str ,
231+ # TODO(nvmbreughe): Remove this
242232 ** kwargs ,
243233) -> list [str ]:
244234 """
@@ -276,39 +266,33 @@ def _workspace_creation_heuristic(
276266 return ["mnnvl" ]
277267
278268 # Single-node scenarios
279- problem_size = max_token_num * hidden_dim
269+ return ["mnnvl" ]
270+ # problem_size = max_token_num * hidden_dim
280271
281- # Large problems (>4M elements): trtllm optimized for throughput
282- if problem_size > 4 * 1024 * 1024 :
283- if "trtllm" in suitable_backends :
284- return ["trtllm" ]
272+ # # Large problems (>4M elements): trtllm optimized for throughput
273+ # if problem_size > 4 * 1024 * 1024:
274+ # if "trtllm" in suitable_backends:
275+ # return ["trtllm"]
285276
286- # Small token counts (<128): trtllm one-shot has better latency
287- if max_token_num < 128 :
288- if "trtllm" in suitable_backends :
289- return ["trtllm" ]
277+ # # Small token counts (<128): trtllm one-shot has better latency
278+ # if max_token_num < 128:
279+ # if "trtllm" in suitable_backends:
280+ # return ["trtllm"]
290281
291- # Small world sizes (<=4): trtllm one-shot efficient
292- if world_size <= 4 :
293- if "trtllm" in suitable_backends :
294- return ["trtllm" ]
282+ # # Small world sizes (<=4): trtllm one-shot efficient
283+ # if world_size <= 4:
284+ # if "trtllm" in suitable_backends:
285+ # return ["trtllm"]
295286
296- # Default: return first available
297- return [suitable_backends [0 ]]
287+ # # Default: return first available
288+ # return [suitable_backends[0]]
298289
299290
300291# ============================================================================
301- # WORKSPACE CREATION - Uses decorator for all validation
292+ # WORKSPACE CREATION
302293# ============================================================================
303294
304295
305- @backend_requirement (
306- backend_checks = {
307- "trtllm" : _trtllm_workspace_check ,
308- "mnnvl" : _mnnvl_workspace_check ,
309- },
310- heuristic_func = _workspace_creation_heuristic ,
311- )
312296def create_allreduce_fusion_workspace (
313297 backend : Literal ["trtllm" , "mnnvl" , "auto" ] = "auto" ,
314298 world_size : int = None ,
@@ -324,7 +308,7 @@ def create_allreduce_fusion_workspace(
324308 """
325309 Create workspace for AllReduce fusion operations.
326310
327- Backend selection (checks + heuristics) handled by @backend_requirement decorator .
311+ Backend selection uses topology-based checks and heuristics .
328312
329313 **Important: Workspace Reusability**
330314 The workspace is allocated based on the total size (max_token_num * hidden_dim * dtype_size).
@@ -393,13 +377,51 @@ def create_allreduce_fusion_workspace(
393377 ... )
394378 >>> print(workspace.backend) # "mnnvl"
395379 """
396- # Decorator has validated backend - now create workspace
397- # If backend="auto", decorator has selected the best one and stored it
398-
399- # Get actual backend (decorator resolved "auto" to concrete backend)
380+ if gpus_per_node is None :
381+ gpus_per_node = min (torch .cuda .device_count (), world_size )
382+ # Determine the actual backend to use
400383 if backend == "auto" :
401- # Decorator stored the selected backend in suitable_auto_backends
402- actual_backend = create_allreduce_fusion_workspace .suitable_auto_backends [0 ]
384+ # Find suitable backends based on topology (any CC check needs to be done at kernel runtime, since there are no tensors available at this point)
385+ suitable_backends = []
386+ if _trtllm_workspace_check (
387+ backend = backend ,
388+ world_size = world_size ,
389+ rank = rank ,
390+ max_token_num = max_token_num ,
391+ hidden_dim = hidden_dim ,
392+ dtype = dtype ,
393+ topology = topology ,
394+ ):
395+ suitable_backends .append ("trtllm" )
396+ if _mnnvl_workspace_check (
397+ backend = backend ,
398+ world_size = world_size ,
399+ rank = rank ,
400+ max_token_num = max_token_num ,
401+ hidden_dim = hidden_dim ,
402+ dtype = dtype ,
403+ topology = topology ,
404+ ):
405+ suitable_backends .append ("mnnvl" )
406+
407+ if not suitable_backends :
408+ raise ValueError (
409+ f"No suitable backend found for topology={ topology } . "
410+ f"trtllm requires single_node topology, mnnvl works with both."
411+ )
412+
413+ # Apply heuristic to select best backend
414+ selected = _workspace_creation_heuristic (
415+ suitable_backends = suitable_backends ,
416+ backend = backend ,
417+ world_size = world_size ,
418+ rank = rank ,
419+ max_token_num = max_token_num ,
420+ hidden_dim = hidden_dim ,
421+ dtype = dtype ,
422+ topology = topology ,
423+ )
424+ actual_backend = selected [0 ] if selected else suitable_backends [0 ]
403425 else :
404426 actual_backend = backend
405427
0 commit comments