Skip to content

Commit 0726751

Browse files
authored
[ET-VK] Implement select_at_dim_as_symint
Differential Revision: D86340340 Pull Request resolved: #15617
1 parent dd46504 commit 0726751

File tree

18 files changed

+224
-178
lines changed

18 files changed

+224
-178
lines changed

backends/vulkan/_passes/TARGETS

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -63,19 +63,6 @@ runtime.python_library(
6363
],
6464
)
6565

66-
runtime.python_library(
67-
name = "remove_local_scalar_dense",
68-
srcs = ["remove_local_scalar_dense_ops.py"],
69-
visibility = [
70-
"//executorch/backends/...",
71-
],
72-
deps = [
73-
"//caffe2:torch",
74-
"//executorch/exir:pass_base",
75-
"//executorch/exir/dialects:lib",
76-
],
77-
)
78-
7966
runtime.python_library(
8067
name = "remove_redundant_ops",
8168
srcs = ["remove_redundant_ops.py"],
@@ -161,7 +148,6 @@ runtime.python_library(
161148
":fuse_quantized_ops",
162149
":insert_prepack_nodes",
163150
":remove_asserts",
164-
":remove_local_scalar_dense",
165151
":remove_redundant_ops",
166152
":replace_qdq",
167153
":squeeze_unsqueeze_inputs",

backends/vulkan/_passes/__init__.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,6 @@
1616
remove_asserts,
1717
RemoveAssertsTransform,
1818
)
19-
from executorch.backends.vulkan._passes.remove_local_scalar_dense_ops import (
20-
RemoveLocalScalarDenseOpsTransform,
21-
)
2219
from executorch.backends.vulkan._passes.remove_redundant_ops import (
2320
RemoveRedundantOpsTransform,
2421
)
@@ -35,7 +32,6 @@
3532
"insert_prepack_nodes",
3633
"remove_asserts",
3734
"RemoveAssertsTransform",
38-
"RemoveLocalScalarDenseOpsTransform",
3935
"RemoveRedundantOpsTransform",
4036
"ReplaceQDQPass",
4137
"SqueezeUnsqueezeInputs",

backends/vulkan/_passes/remove_local_scalar_dense_ops.py

Lines changed: 0 additions & 110 deletions
This file was deleted.

backends/vulkan/custom_ops_lib.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
import executorch.backends.vulkan.patterns as vk_patterns
1010
import torch.library
1111

12+
from torch._subclasses.fake_tensor import FakeTensor
13+
1214
namespace = "et_vk"
1315
lib = torch.library.Library(namespace, "DEF")
1416

@@ -614,3 +616,18 @@ def add_q8ta_q8ta_q8to_impl(
614616
)
615617
lib.impl(name, add_q8ta_q8ta_q8to_impl, "CompositeExplicitAutograd")
616618
add_q8ta_q8ta_q8to_op = getattr(getattr(torch.ops, namespace), name)
619+
620+
#############################
621+
## select_as_symint ##
622+
#############################
623+
624+
625+
def select_as_symint_impl(x: torch.Tensor, dim: int, index: int):
    """Meta kernel for the ``et_vk.select_as_symint`` custom op.

    Produces a fresh *unbacked* SymInt standing in for the data-dependent
    scalar value selected from ``x``. ``dim`` and ``index`` are unused by
    this fake implementation; they only matter to the real runtime kernel.
    """
    # Meta kernels only ever see fake tensors; the fake mode's shape env is
    # what can mint new unbacked symbols.
    assert isinstance(x, FakeTensor)
    shape_env = x.fake_mode.shape_env
    return shape_env.create_unbacked_symint()
628+
629+
630+
# Register the op schema and its Meta (fake tensor) kernel so tracing yields a
# SymInt-valued node for et_vk.select_as_symint.
name = "select_as_symint"
lib.define(f"{name}(Tensor x, int dim, int index) -> SymInt")
# Only a Meta implementation is registered here; the concrete value is
# presumably produced by the Vulkan delegate at runtime — confirm against the
# runtime-side changes.
lib.impl(name, select_as_symint_impl, "Meta")
# Convenience handle to the registered op.
select_as_symint_op = getattr(getattr(torch.ops, namespace), name)

backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 0 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -184,36 +184,6 @@ def is_linear_permute(self, node: torch.fx.Node) -> Tuple[bool, bool]:
184184

185185
return False, False
186186

187-
def is_in_local_scalar_dense_chain(self, node: torch.fx.Node) -> Tuple[bool, bool]:
188-
"""
189-
Scalar tensors are usually converted to scalar values in the graph via
190-
`scalar_tensor[0].item()` in Python, which translates to a chain of
191-
`local_scalar_dense(torch.select.int(scalar_tensor, 0, 0))` in the graph.
192-
This function marks the entire chain as supported by the Vulkan delegate.
193-
194-
Later, within vulkan_preprocess there will be a graph transform which replaces
195-
the chain with passing in the scalar tensor directly.
196-
197-
Similar to the `is_linear_permute` function, this function has 2 return values.
198-
"""
199-
if node.target == exir_ops.edge.aten.select_copy.int:
200-
if len(node.users) != 1:
201-
return False, False
202-
# pyre-ignore
203-
if node.args[0].meta["val"].numel() != 1:
204-
return False, False
205-
206-
local_scalar_dense = list(node.users.keys())[0]
207-
if local_scalar_dense.target != torch.ops.aten._local_scalar_dense.default:
208-
return False, False
209-
210-
return self.is_in_local_scalar_dense_chain(local_scalar_dense)
211-
212-
if node.target == torch.ops.aten._local_scalar_dense.default:
213-
return True, all(self.node_is_compatible(user)[0] for user in node.users)
214-
215-
return False, False
216-
217187
def log_skip(self, node: torch.fx.Node, reason: str) -> None:
218188
if node.op == "call_function":
219189
logger.info(
@@ -261,16 +231,6 @@ def _is_node_supported(self, node: torch.fx.Node) -> bool: # noqa: C901
261231
self.log_skip(node, "permute node of non compatible linear node")
262232
return False
263233

264-
(
265-
is_in_local_scalar_dense_chain,
266-
dst_node_is_compatible,
267-
) = self.is_in_local_scalar_dense_chain(node)
268-
if is_in_local_scalar_dense_chain and dst_node_is_compatible:
269-
return True
270-
elif is_in_local_scalar_dense_chain:
271-
self.log_skip(node, "local scalar dense of incompatible op node")
272-
return False
273-
274234
features = None
275235
if target not in vulkan_supported_ops:
276236
# For some ops, i.e. custom ops the name is registered instead of the

backends/vulkan/patterns/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ runtime.python_library(
1212
"quantized_linear.py",
1313
"quantized_convolution.py",
1414
"quantized_binary.py",
15+
"select_as_symint.py",
1516
],
1617
visibility = [
1718
"//executorch/backends/...",

backends/vulkan/patterns/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
import executorch.backends.vulkan.patterns.rope # noqa
1616

17+
import executorch.backends.vulkan.patterns.select_as_symint # noqa
18+
1719
import torch
1820

1921
from executorch.backends.vulkan.patterns.pattern_registry import (
backends/vulkan/patterns/select_as_symint.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import Optional
8+
9+
import torch
10+
11+
from executorch.backends.vulkan.patterns.pattern_registry import (
12+
PatternMatch,
13+
register_pattern_detector,
14+
register_pattern_replacement,
15+
)
16+
17+
from executorch.exir import ExportedProgram
18+
from executorch.exir.dialects._ops import ops as exir_ops
19+
20+
21+
class SelectAsSymIntMatch(PatternMatch):
    """Match the ``_local_scalar_dense(select_copy(tensor, dim, index))`` chain.

    The match is anchored at the ``_local_scalar_dense`` node; ``match_found``
    is True only when its input is a ``select_copy.int`` call with the full
    ``(self, dim, index)`` argument list.
    """

    def __init__(self, local_scalar_dense_node: torch.fx.Node) -> None:
        self.anchor_node = local_scalar_dense_node
        self.match_found = False

        # The scalar's producer must be present and must itself be a node.
        if not local_scalar_dense_node.args:
            return
        producer = local_scalar_dense_node.args[0]
        if not isinstance(producer, torch.fx.Node):
            return

        is_select_copy = (
            producer.op == "call_function"
            and producer.target == exir_ops.edge.aten.select_copy.int
        )
        if not is_select_copy:
            return

        # select_copy.int signature: select_copy(Tensor self, int dim, int index)
        if len(producer.args) < 3:
            return

        self.select_node = producer
        self.tensor_node, self.dim_node, self.index_node = producer.args[:3]

        self.all_nodes = [
            self.anchor_node,
            self.select_node,
            self.tensor_node,
            self.dim_node,
            self.index_node,
        ]

        self.match_found = True
59+
60+
61+
@register_pattern_detector("select_as_symint")
def find_select_as_symint_patterns(
    node: torch.fx.Node,
) -> Optional[SelectAsSymIntMatch]:
    """Return a match anchored at *node* if it heads a select-as-symint chain.

    Only ``aten._local_scalar_dense`` nodes can anchor the pattern; everything
    else short-circuits to ``None``.
    """
    if node.target != torch.ops.aten._local_scalar_dense.default:
        return None

    candidate = SelectAsSymIntMatch(node)
    return candidate if candidate.match_found else None
73+
74+
75+
##
76+
## Pattern Replacement
77+
##
78+
79+
80+
@register_pattern_replacement("select_as_symint")
def replace_select_local_scalar_dense_with_select_as_symint(
    ep: ExportedProgram,
    graph_module: torch.fx.GraphModule,
    match: SelectAsSymIntMatch,
) -> None:
    """Replace a matched ``_local_scalar_dense(select_copy(...))`` chain with a
    single ``et_vk.select_as_symint`` call.

    The new node is inserted before the anchor (``_local_scalar_dense``) node
    and takes over all of its uses. The now-dead anchor/select nodes are NOT
    erased here; they are left for dead-code elimination to clean up, which
    also keeps ``select_copy`` intact if it has other users.
    """
    with graph_module.graph.inserting_before(match.anchor_node):
        new_node = graph_module.graph.create_node(
            "call_function",
            exir_ops.edge.et_vk.select_as_symint.default,
            args=(
                match.tensor_node,
                match.dim_node,
                match.index_node,
            ),
        )

    # Carry over the SymInt fake value so downstream shape propagation works.
    new_node.meta["val"] = match.anchor_node.meta["val"]
    match.anchor_node.replace_all_uses_with(new_node)

backends/vulkan/runtime/VulkanBackend.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -649,7 +649,7 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
649649
}
650650
}
651651

652-
if (should_propagate_resize) {
652+
if (should_propagate_resize || compute_graph->has_data_dependent_shapes()) {
653653
compute_graph->propagate_resize();
654654
}
655655

backends/vulkan/runtime/api/containers/StagingBuffer.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,20 @@ class StagingBuffer final {
112112
inline void set_staging_zeros() {
113113
memset(data(), 0, nbytes());
114114
}
115+
116+
  /*
   * Read a single element of the staging buffer, interpreting the raw bytes as
   * an array of T laid out contiguously for a tensor with the given `sizes`,
   * and returning the first element of the slice `select(dim, index)`.
   *
   * NOTE(review): the stride computation (product of all sizes after `dim`)
   * assumes a contiguous row-major layout, that `dim` is already canonical
   * (0 <= dim < sizes.size() — a negative `dim` would silently yield offset
   * index * numel-ish values), and that `index * stride` is in bounds.
   * Confirm callers normalize/validate before calling.
   */
  template <typename T>
  T select_element_at_dim(
      const std::vector<int64_t>& sizes,
      const int64_t dim,
      const int64_t index) {
    // Stride of `dim` in row-major order: product of trailing dimension sizes.
    int64_t stride = 1;
    for (size_t i = dim + 1; i < sizes.size(); ++i) {
      stride *= sizes[i];
    }
    const int64_t offset = index * stride;
    const T* typed_data = reinterpret_cast<const T*>(data());
    return typed_data[offset];
  }
115129
};
116130

117131
} // namespace api

0 commit comments

Comments
 (0)