
Commit 9ee7e67

Added new subgraph definition paradigm and revised matching logic
1 parent 51197bc commit 9ee7e67

File tree

3 files changed: +230 -40 lines changed

py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py

Lines changed: 47 additions & 40 deletions
@@ -233,7 +233,8 @@ def partition_graph(self) -> torch.fx.GraphModule:
         subgraphs = self.remove_small_acc_subgraphs(subgraphs)
 
         subgraphs = self.break_subgraphs(
-            subgraphs, subgraph_size_budget=self.calculate_size_budget()
+            subgraphs,
+            subgraph_size_budget=500 * 1024 * 1024,  # self.calculate_size_budget()
         )
 
         # Set the number of TRT engines to be generated
@@ -309,6 +310,11 @@ def break_subgraphs(
         """
         This function breaks the subgraphs into smaller subgraphs to save CPU memory.
         """
+        from torch_tensorrt.dynamo.partitioning.fusion_patterns import (
+            get_node_in_fusion_pattern,
+        )
+
+        self.fusion_patterns = get_node_in_fusion_pattern(self.module.graph)
         new_subgraphs = []
         # We throw an error if the remaining memory is almost empty compared to the model size.
         # i.e. if the remaining memory is 4G (budget is 1G) the model size is greater than 40G, we stop the compilation.
@@ -328,9 +334,26 @@ def break_subgraphs(
                 new_subgraphs.append(broken_subgraphs[0])
                 subgraph = broken_subgraphs[1]
             new_subgraphs.append(subgraph)
-
+        self._varify_all_fusion_nodes_in_same_subgraph(new_subgraphs)
         return new_subgraphs
 
+    def _varify_all_fusion_nodes_in_same_subgraph(
+        self, subgraphs: List[Subgraph]
+    ) -> None:
+        node_to_subgraph = {}
+        for i, s in enumerate(subgraphs):
+            for n in s.nodes:
+                node_to_subgraph[n] = i
+
+        fusion_nodes_map_list = [
+            len({node_to_subgraph[n] for n in ns}) == 1
+            for ns in self.fusion_patterns.values()
+        ]
+        assert all(
+            fusion_nodes_map_list
+        ), "All fusion nodes must be in the same subgraph"
+        logger.info("All fusion nodes are in the same subgraph.")
+
     def break_subgraph_by_size(
         self, subgraph: Subgraph, size_to_break: int
     ) -> Tuple[List[Subgraph], int, int]:
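
Note: the new _varify_all_fusion_nodes_in_same_subgraph helper asserts that every fusion group found by get_node_in_fusion_pattern lands entirely inside one subgraph after breaking. A minimal sketch of that membership check on toy data (node names and subgraph contents below are hypothetical, not taken from the commit):

    # Hypothetical stand-ins for Subgraph.nodes and self.fusion_patterns.
    subgraph_nodes = [["conv", "bn", "relu"], ["matmul", "add"]]
    fusion_patterns = {"conv": {"conv", "bn", "relu"}, "bn": {"conv", "bn", "relu"}}

    # Map every node to the index of the subgraph that owns it.
    node_to_subgraph = {n: i for i, nodes in enumerate(subgraph_nodes) for n in nodes}

    # Each fusion group must resolve to exactly one subgraph index.
    assert all(
        len({node_to_subgraph[n] for n in group}) == 1
        for group in fusion_patterns.values()
    ), "All fusion nodes must be in the same subgraph"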
@@ -376,9 +399,13 @@ def step_and_validate(
         while True:
             new_subgraphs = self.validate_and_correct_subgraphs(new_subgraphs)
             nodes_in_first_subgraph = set(new_subgraphs[0].nodes)
+            nodes_in_second_subgraph = set(new_subgraphs[1].nodes)
             leaf_node = self.get_leaf_node(nodes_in_first_subgraph)
             broken_fusion = self.step_if_break_fusion(
-                new_subgraphs, leaf_node, nodes_in_first_subgraph
+                new_subgraphs,
+                leaf_node,
+                nodes_in_first_subgraph,
+                nodes_in_second_subgraph,
             )
             if not broken_fusion or len(new_subgraphs[1].nodes) == 0:
                 break
@@ -390,57 +417,37 @@ def step_if_break_fusion(
         subgraphs: List[Subgraph],
         leaf_nodes: set[torch.fx.Node],
         nodes_in_first_subgraph: set[torch.fx.Node],
+        nodes_in_second_subgraph: set[torch.fx.Node],
     ) -> bool:
 
         def add_nodes(node: torch.fx.Node) -> None:
             """
             This function adds a node and all its previous nodes to the first subgraph and removes it from the second subgraph in post order.
             """
-            if node.op in CALLABLE_NODE_OPS and node not in nodes_in_first_subgraph:
+            if (
+                node.op in CALLABLE_NODE_OPS
+                and node not in nodes_in_first_subgraph
+                and node in nodes_in_second_subgraph
+            ):
+                # Exclude all nodes already in the first subgraph
                 nodes_in_first_subgraph.add(node)
+                nodes_in_second_subgraph.remove(node)
                 for input_node in node._input_nodes:
                     add_nodes(input_node)
                 subgraphs[0].nodes.append(node)
                 subgraphs[1].nodes.remove(node)
 
-        def match_subgraph_and_step(node: torch.fx.Node) -> bool:
-            added_nodes = False
-            for op_list in NON_BREAKABLE_OP_LISTS:
-                for i, op in enumerate(op_list):
-                    if i != len(op_list) - 1 and op in str(node.target):
-                        # Search following ops forward using BFS. We skip search previous ops because
-                        # even if it's just a subset of fusion graph, we still want it to be fused.
-
-                        users = node.users.keys()
-                        matching_nodes: set[torch.fx.Node] = set()
-                        for following_op_idx in range(i + 1, len(op_list)):
-                            matching_nodes = set()
-                            for user in users:
-                                if op_list[following_op_idx] in str(user.target):
-                                    matching_nodes.add(user)
-                            if not matching_nodes:
-                                break
-                            users = set()
-                            for matching_node in matching_nodes:
-                                for next_user in matching_node.users:
-                                    users.add(next_user)
-
-                        for matching_node in matching_nodes:
-                            added_nodes = True
-                            add_nodes(matching_node)
-
-                if added_nodes:
-                    # Early terminate the search if we have found a match because preceeding matches can cover following matches
-                    break
-
-            return True if added_nodes else False
-
-        found_match = False
+        fusion_broken = False
         for leaf in leaf_nodes:
-            if match_subgraph_and_step(leaf):
-                found_match = True
+            for node in self.fusion_patterns.get(leaf, []):
+                if (
+                    node not in nodes_in_first_subgraph
+                    and node in nodes_in_second_subgraph
+                ):
+                    fusion_broken = True
+                    add_nodes(node)
 
-        return found_match
+        return fusion_broken
 
     def get_leaf_node(
         self, nodes_in_first_subgraph: set[torch.fx.Node]
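
Note: the rewritten step_if_break_fusion no longer walks NON_BREAKABLE_OP_LISTS with a BFS over node.users; it looks each leaf up in the precomputed fusion_patterns map and pulls any partner nodes still stranded in the second subgraph across the boundary. A rough sketch of that step on plain Python sets (toy node names, not the fx.Node objects the partitioner actually uses):

    # Hypothetical fusion map and subgraph memberships.
    fusion_patterns = {"conv": {"conv", "relu"}, "relu": {"conv", "relu"}}
    first, second = {"conv"}, {"relu", "matmul"}

    fusion_broken = False
    for leaf in list(first):
        # Partner nodes of this leaf that were left behind in the second subgraph.
        for node in fusion_patterns.get(leaf, []):
            if node not in first and node in second:
                fusion_broken = True
                first.add(node)
                second.remove(node)

    print(fusion_broken, sorted(first), sorted(second))
    # True ['conv', 'relu'] ['matmul']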

py/torch_tensorrt/dynamo/partitioning/fusion_patterns.py

Lines changed: 183 additions & 0 deletions
@@ -0,0 +1,183 @@
+from typing import Dict, List, Set
+
+import torch
+from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
+from torch.ops import aten
+
+
+class ConvBNReLU(torch.nn.Module):  # type: ignore[misc]
+    def __init__(self) -> None:
+        super().__init__()
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        weight: torch.Tensor,
+        bias: torch.Tensor,
+        stride: List[int],
+        padding: List[int],
+        dilation: List[int],
+        transposed: bool,
+        output_padding: List[int],
+        groups: int,
+        bn_weight: torch.Tensor,
+        bn_bias: torch.Tensor,
+        running_mean: torch.Tensor,
+        running_var: torch.Tensor,
+        momentum: float,
+        eps: float,
+    ) -> torch.Tensor:
+        x = aten.convolution.default(
+            x,
+            weight,
+            bias,
+            stride,
+            padding,
+            dilation,
+            transposed,
+            output_padding,
+            groups,
+        )
+        x = aten._native_batch_norm_legit_no_training.default(
+            x, bn_weight, bn_bias, running_mean, running_var, momentum, eps
+        )[0]
+        x = aten.relu.default(x)
+        return x
+
+
+class ConvReLU(torch.nn.Module):  # type: ignore[misc]
+    def __init__(self) -> None:
+        super().__init__()
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        weight: torch.Tensor,
+        bias: torch.Tensor,
+        stride: List[int],
+        padding: List[int],
+        dilation: List[int],
+        transposed: bool,
+        output_padding: List[int],
+        groups: int,
+    ) -> torch.Tensor:
+        x = aten.convolution.default(
+            x,
+            weight,
+            bias,
+            stride,
+            padding,
+            dilation,
+            transposed,
+            output_padding,
+            groups,
+        )
+        x = aten.relu.default(x)
+        return x
+
+
+class ConvGelu(torch.nn.Module):  # type: ignore[misc]
+    def __init__(self) -> None:
+        super().__init__()
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        weight: torch.Tensor,
+        bias: torch.Tensor,
+        stride: List[int],
+        padding: List[int],
+        dilation: List[int],
+        transposed: bool,
+        output_padding: List[int],
+        groups: int,
+    ) -> torch.Tensor:
+        x = aten.convolution.default(
+            x,
+            weight,
+            bias,
+            stride,
+            padding,
+            dilation,
+            transposed,
+            output_padding,
+            groups,
+        )
+        x = aten.gelu.default(x)
+        return x
+
+
+class ConvSilu(torch.nn.Module):  # type: ignore[misc]
+    def __init__(self) -> None:
+        super().__init__()
+
+    def forward(
+        self, x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
+    ) -> torch.Tensor:
+        x = aten.convolution.default(
+            x, weight, bias, [1, 1], [1, 1], [1, 1], False, [0, 0], 1
+        )
+        x = aten.silu.default(x)
+        return x
+
+
+class MulAdd(torch.nn.Module):  # type: ignore[misc]
+    def __init__(self) -> None:
+        super().__init__()
+
+    def forward(
+        self, x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
+    ) -> torch.Tensor:
+        x = aten.mul.Tensor(x, weight)
+        x = aten.add.Tensor(x, bias)
+        return x
+
+
+class MulMul(torch.nn.Module):  # type: ignore[misc]
+    def __init__(self) -> None:
+        super().__init__()
+
+    def forward(
+        self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor
+    ) -> torch.Tensor:
+        x = aten.mul.Tensor(x, y)
+        x = aten.mul.Tensor(x, z)
+        return x
+
+
+All_FUSION_PATTERNS = [
+    ConvBNReLU,
+    ConvReLU,
+    ConvGelu,
+    ConvSilu,
+    MulAdd,
+    MulMul,
+]
+
+
+def get_node_in_fusion_pattern(
+    graph: torch.fx.Graph,
+) -> Dict[torch.fx.Node, Set[torch.fx.Node]]:
+    """
+    This function gets the nodes map of the fusion pattern from the graph.
+    Key: node that appears in the fusion pattern
+    Value: the list of nodes that should be fused together
+    """
+    fusion_nodes = {}
+    for pattern in All_FUSION_PATTERNS:
+        pattern_graph = torch.fx.symbolic_trace(pattern())
+        subgraph_matcher = SubgraphMatcher(pattern_graph.graph)
+        match_result = subgraph_matcher.match(graph)
+        for match in match_result:
+            fusion_group = {
+                node
+                for node in match.nodes_map.values()
+                if node
+                and type(node) == torch.fx.Node
+                and node.op == "call_function"
+                and node not in match.placeholder_nodes
+            }
+            for node in fusion_group:
+                fusion_nodes[node] = fusion_group
+
+    return fusion_nodes
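
Note: a hedged usage sketch of the new helper, not part of the commit. It assumes the fusion_patterns module added here is importable and traces a toy module that calls the same aten ops as the ConvReLU pattern, so SubgraphMatcher should report one match.

    import torch
    from torch.ops import aten
    from torch_tensorrt.dynamo.partitioning.fusion_patterns import (
        get_node_in_fusion_pattern,  # module added by this commit
    )


    class ToyConvReLU(torch.nn.Module):
        # Mirrors the aten-level ConvReLU pattern, with literal conv arguments.
        def forward(
            self, x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
        ) -> torch.Tensor:
            x = aten.convolution.default(
                x, weight, bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1
            )
            return aten.relu.default(x)


    gm = torch.fx.symbolic_trace(ToyConvReLU())
    fusion_map = get_node_in_fusion_pattern(gm.graph)
    for node, group in fusion_map.items():
        print(node.name, "->", sorted(n.name for n in group))
    # Expected: the convolution node and the relu node each map to the same
    # two-node fusion group.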

py/torch_tensorrt/dynamo/partitioning/fusion_subgraphs.py

Whitespace-only changes.
