@@ -233,7 +233,8 @@ def partition_graph(self) -> torch.fx.GraphModule:
233233 subgraphs = self .remove_small_acc_subgraphs (subgraphs )
234234
235235 subgraphs = self .break_subgraphs (
236- subgraphs , subgraph_size_budget = self .calculate_size_budget ()
236+ subgraphs ,
237+ subgraph_size_budget = 500 * 1024 * 1024 , # self.calculate_size_budget()
237238 )
238239
239240 # Set the number of TRT engines to be generated
@@ -309,6 +310,11 @@ def break_subgraphs(
309310 """
310311 This function breaks the subgraphs into smaller subgraphs to save CPU memory.
311312 """
313+ from torch_tensorrt .dynamo .partitioning .fusion_patterns import (
314+ get_node_in_fusion_pattern ,
315+ )
316+
317+ self .fusion_patterns = get_node_in_fusion_pattern (self .module .graph )
312318 new_subgraphs = []
313319 # We throw an error if the remaining memory is almost empty compared to the model size.
314320 # e.g. if the remaining memory is 4G (budget is 1G) and the model size is greater than 40G, we stop the compilation.
@@ -328,9 +334,26 @@ def break_subgraphs(
328334 new_subgraphs .append (broken_subgraphs [0 ])
329335 subgraph = broken_subgraphs [1 ]
330336 new_subgraphs .append (subgraph )
331-
337+ self . _varify_all_fusion_nodes_in_same_subgraph ( new_subgraphs )
332338 return new_subgraphs
333339
def _varify_all_fusion_nodes_in_same_subgraph(
    self, subgraphs: List[Subgraph]
) -> None:
    """
    Sanity-check that no fusion pattern has been split across subgraphs.

    Every group of nodes stored as a value of ``self.fusion_patterns`` must map
    to a single entry of ``subgraphs``.

    Args:
        subgraphs: The subgraphs to check. Each one exposes its nodes through
            ``.nodes``.

    Raises:
        AssertionError: If the nodes of any fusion pattern are spread over more
            than one subgraph.
    """
    # Map every node to the index of the subgraph that contains it.
    node_to_subgraph = {
        node: idx
        for idx, subgraph in enumerate(subgraphs)
        for node in subgraph.nodes
    }

    for fusion_nodes in self.fusion_patterns.values():
        # Nodes that are in none of the given subgraphs cannot have been split
        # by the breaking step, so they are skipped instead of failing the
        # lookup with an unrelated KeyError.
        # NOTE(review): confirm that fusion nodes outside every subgraph are
        # expected here and should not be reported as an error.
        owners = {
            node_to_subgraph[node]
            for node in fusion_nodes
            if node in node_to_subgraph
        }
        if len(owners) > 1:
            # Raised explicitly (rather than via ``assert``) so the check is
            # not stripped when Python runs with ``-O``.
            raise AssertionError("All fusion nodes must be in the same subgraph")

    logger.info("All fusion nodes are in the same subgraph.")
356+
334357 def break_subgraph_by_size (
335358 self , subgraph : Subgraph , size_to_break : int
336359 ) -> Tuple [List [Subgraph ], int , int ]:
@@ -376,9 +399,13 @@ def step_and_validate(
376399 while True :
377400 new_subgraphs = self .validate_and_correct_subgraphs (new_subgraphs )
378401 nodes_in_first_subgraph = set (new_subgraphs [0 ].nodes )
402+ nodes_in_second_subgraph = set (new_subgraphs [1 ].nodes )
379403 leaf_node = self .get_leaf_node (nodes_in_first_subgraph )
380404 broken_fusion = self .step_if_break_fusion (
381- new_subgraphs , leaf_node , nodes_in_first_subgraph
405+ new_subgraphs ,
406+ leaf_node ,
407+ nodes_in_first_subgraph ,
408+ nodes_in_second_subgraph ,
382409 )
383410 if not broken_fusion or len (new_subgraphs [1 ].nodes ) == 0 :
384411 break
@@ -390,57 +417,37 @@ def step_if_break_fusion(
390417 subgraphs : List [Subgraph ],
391418 leaf_nodes : set [torch .fx .Node ],
392419 nodes_in_first_subgraph : set [torch .fx .Node ],
420+ nodes_in_second_subgraph : set [torch .fx .Node ],
393421 ) -> bool :
394422
395423 def add_nodes (node : torch .fx .Node ) -> None :
396424 """
397425 This function adds a node and all its previous nodes to the first subgraph and removes it from the second subgraph in post order.
398426 """
399- if node .op in CALLABLE_NODE_OPS and node not in nodes_in_first_subgraph :
427+ if (
428+ node .op in CALLABLE_NODE_OPS
429+ and node not in nodes_in_first_subgraph
430+ and node in nodes_in_second_subgraph
431+ ):
432+ # Exclude all nodes already in the first subgraph
400433 nodes_in_first_subgraph .add (node )
434+ nodes_in_second_subgraph .remove (node )
401435 for input_node in node ._input_nodes :
402436 add_nodes (input_node )
403437 subgraphs [0 ].nodes .append (node )
404438 subgraphs [1 ].nodes .remove (node )
405439
406- def match_subgraph_and_step (node : torch .fx .Node ) -> bool :
407- added_nodes = False
408- for op_list in NON_BREAKABLE_OP_LISTS :
409- for i , op in enumerate (op_list ):
410- if i != len (op_list ) - 1 and op in str (node .target ):
411- # Search following ops forward using BFS. We skip search previous ops because
412- # even if it's just a subset of fusion graph, we still want it to be fused.
413-
414- users = node .users .keys ()
415- matching_nodes : set [torch .fx .Node ] = set ()
416- for following_op_idx in range (i + 1 , len (op_list )):
417- matching_nodes = set ()
418- for user in users :
419- if op_list [following_op_idx ] in str (user .target ):
420- matching_nodes .add (user )
421- if not matching_nodes :
422- break
423- users = set ()
424- for matching_node in matching_nodes :
425- for next_user in matching_node .users :
426- users .add (next_user )
427-
428- for matching_node in matching_nodes :
429- added_nodes = True
430- add_nodes (matching_node )
431-
432- if added_nodes :
433- # Early terminate the search if we have found a match because preceeding matches can cover following matches
434- break
435-
436- return True if added_nodes else False
437-
438- found_match = False
440+ fusion_broken = False
439441 for leaf in leaf_nodes :
440- if match_subgraph_and_step (leaf ):
441- found_match = True
442+ for node in self .fusion_patterns .get (leaf , []):
443+ if (
444+ node not in nodes_in_first_subgraph
445+ and node in nodes_in_second_subgraph
446+ ):
447+ fusion_broken = True
448+ add_nodes (node )
442449
443- return found_match
450+ return fusion_broken
444451
445452 def get_leaf_node (
446453 self , nodes_in_first_subgraph : set [torch .fx .Node ]
0 commit comments