From 9bf631d271ca4fbe9c39099d456989624b31e828 Mon Sep 17 00:00:00 2001 From: ssjia Date: Wed, 5 Nov 2025 13:16:44 -0800 Subject: [PATCH] [ET-VK] Implement select_at_dim_as_symint ## Context The SDPA custom op accepts the `input_pos` (i.e. cache position) argument as a symbolic integer. The value of the symbolic integer is obtained by selecting the first element of a cache position input tensor and converting it to symint via local_scalar_dense. Currently, ET-VK handles this in a hacky manner. 1. the select + local_scalar_dense op pattern is removed, and the cache pos tensor is passed directly into the custom sdpa ops 2. Single element tensors that have users that are all select + local_scalar_dense will be interpreted as symints instead of tensors Unfortunately, this technique will not work for the huggingface implementation of transformer models, since the cache pos input tensor has not just a single element but is expected to be a vector of integer cache positions corresponding to all cache positions that will be updated. ## Changes Introduce a custom op to capture the select + local_scalar_dense op pattern, which is the proper way to handle the op pattern. Note that a custom op is needed because this op needs to access the staging buffer data of the input tensor, whereas `select` would typically be executed via a compute shader. The reason for this is because the `input_pos` value is needed to configure the sizes of attention weight tensors participating in the custom SDPA op, so the value must be set before any command buffers are dispatched. As a consequence of this change, the previous handling of select + local scalar dense can also be removed. Differential Revision: [D86340340](https://our.internmc.facebook.com/intern/diff/D86340340/) [ghstack-poisoned] --- backends/vulkan/_passes/TARGETS | 13 --- backends/vulkan/_passes/__init__.py | 4 - .../_passes/remove_local_scalar_dense_ops.py | 110 ------------------ backends/vulkan/custom_ops_lib.py | 17 +++ .../vulkan/partitioner/vulkan_partitioner.py | 40 ------- backends/vulkan/patterns/TARGETS | 1 + backends/vulkan/patterns/__init__.py | 2 + backends/vulkan/patterns/select_as_symint.py | 104 +++++++++++++++++ backends/vulkan/runtime/VulkanBackend.cpp | 2 +- .../runtime/api/containers/StagingBuffer.h | 14 +++ .../vulkan/runtime/graph/ComputeGraph.cpp | 11 ++ backends/vulkan/runtime/graph/ComputeGraph.h | 10 ++ backends/vulkan/runtime/graph/GraphConfig.cpp | 1 + backends/vulkan/runtime/graph/GraphConfig.h | 5 +- .../vulkan/runtime/graph/ops/ExecuteNode.cpp | 10 +- .../vulkan/runtime/graph/ops/ExecuteNode.h | 4 +- .../runtime/graph/ops/impl/SymIntOps.cpp | 49 ++++++++ backends/vulkan/vulkan_preprocess.py | 4 - 18 files changed, 224 insertions(+), 177 deletions(-) delete mode 100644 backends/vulkan/_passes/remove_local_scalar_dense_ops.py create mode 100644 backends/vulkan/patterns/select_as_symint.py diff --git a/backends/vulkan/_passes/TARGETS b/backends/vulkan/_passes/TARGETS index ae1a0b79654..168fc5a7d2b 100644 --- a/backends/vulkan/_passes/TARGETS +++ b/backends/vulkan/_passes/TARGETS @@ -63,19 +63,6 @@ runtime.python_library( ], ) -runtime.python_library( - name = "remove_local_scalar_dense", - srcs = ["remove_local_scalar_dense_ops.py"], - visibility = [ - "//executorch/backends/...", - ], - deps = [ - "//caffe2:torch", - "//executorch/exir:pass_base", - "//executorch/exir/dialects:lib", - ], -) - runtime.python_library( name = "remove_redundant_ops", srcs = ["remove_redundant_ops.py"], diff --git a/backends/vulkan/_passes/__init__.py b/backends/vulkan/_passes/__init__.py index 169bd60543c..8d305ababe4 100644 --- a/backends/vulkan/_passes/__init__.py +++ b/backends/vulkan/_passes/__init__.py @@ -16,9 +16,6 @@ remove_asserts, RemoveAssertsTransform, ) -from executorch.backends.vulkan._passes.remove_local_scalar_dense_ops import ( - RemoveLocalScalarDenseOpsTransform, -) from executorch.backends.vulkan._passes.remove_redundant_ops import ( RemoveRedundantOpsTransform, ) @@ -35,7 +32,6 @@ "insert_prepack_nodes", "remove_asserts", "RemoveAssertsTransform", - "RemoveLocalScalarDenseOpsTransform", "RemoveRedundantOpsTransform", "ReplaceQDQPass", "SqueezeUnsqueezeInputs", diff --git a/backends/vulkan/_passes/remove_local_scalar_dense_ops.py b/backends/vulkan/_passes/remove_local_scalar_dense_ops.py deleted file mode 100644 index 6ce3572ec0c..00000000000 --- a/backends/vulkan/_passes/remove_local_scalar_dense_ops.py +++ /dev/null @@ -1,110 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-strict - -import torch -from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import ExportPass, PassResult - -from torch._subclasses.fake_tensor import FakeTensor - - -def node_is_local_scalar_dense_chain(node: torch.fx.Node) -> bool: - """ - Converting a tensor to a scalar via tensor[0].item() creates a index_select + - local_scalar_dense pattern in the graph. Check if a node is the start of this pattern. - """ - if ( - node.op == "call_function" - and node.target == exir_ops.edge.aten.select_copy.int - and len(node.users) == 1 - ): - user = list(node.users.keys())[0] - return user.target == torch.ops.aten._local_scalar_dense.default - - return False - - -def tag_node_if_scalar_tensor(node: torch.fx.Node) -> None: - """ - A scalar tensor in the Vulkan backend is a tensor that can be represented as a scalar - value instead of a Tensor object. The criteria for identifying a tensor as a scalar - tensor are as follows: - - 1. The tensor has only 1 element - 2. One of the node's uses is converting it to a scalar via `tensor[0].item()`, which - creates a index_select + local_scalar_dense pattern in the graph - - If any of these criteria are fulfilled, then tag the node for the tensor to mark it - so that it is added as a scalar value during serialization. - """ - tensor_val = node.meta["val"] - if not isinstance(tensor_val, FakeTensor): - return - - # Scalar tensors must have only one element - if tensor_val.numel() != 1: - return - - for user in node.users: - if node_is_local_scalar_dense_chain(user): - node.meta["etvk_is_scalar_tensor"] = True - - -def remove_local_scalar_dense_chain(graph: torch.fx.Graph, node: torch.fx.Node) -> None: - """ - Remove the index_select + local_scalar_dense pattern in the graph in favor of passing - the original scalar tensor directly. - """ - replace_node = node.args[0] - assert isinstance(replace_node, torch.fx.Node) - # If the argument to the local_scalar_dense op is a select op with only - # one user, and the argument to the select op is a tensor with only one - # element (i.e. a scalar tensor), then replace the entire pattern with the - # scalar tensor. - if ( - replace_node.op == "call_function" - and replace_node.target == exir_ops.edge.aten.select_copy.int - ): - # pyre-ignore - if replace_node.args[0].meta["val"].numel() == 1: - replace_node = replace_node.args[0] - assert isinstance(replace_node, torch.fx.Node) - assert replace_node.meta.get("etvk_is_scalar_tensor", True) - - with graph.inserting_after(node): - node.replace_all_uses_with(replace_node) - - -def remove_local_scalar_dense_ops(graph: torch.fx.Graph) -> torch.fx.Graph: - """ - The purpose of this pass is twofold: - 1. Tag scalar tensors (see `tag_node_if_scalar_tensor()` for the criteria) - 2. Remove the index_select + local_scalar_dense pattern in the graph in favor of - passing the original scalar tensor directly (see `remove_local_scalar_dense_chain()`) - - This makes it easier to deal with scalar tensors in the Vulkan backend. In particular, - it allows serializing scalar tensors as SymInt objects instead of Tensor objects. - Because scalar tensors are often used to inform tensor shapes, their values need to - be easily accessed by the CPU during resizing logic, while also being able to reflect - updates to their value in any GPU shaders that reference them. - """ - target_op = torch.ops.aten._local_scalar_dense.default - for node in graph.nodes: - tag_node_if_scalar_tensor(node) - - if node.op == "call_function" and node.target == target_op: - remove_local_scalar_dense_chain(graph, node) - - graph.eliminate_dead_code() - return graph - - -class RemoveLocalScalarDenseOpsTransform(ExportPass): - def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - graph_module.graph = remove_local_scalar_dense_ops(graph_module.graph) - return PassResult(graph_module, True) diff --git a/backends/vulkan/custom_ops_lib.py b/backends/vulkan/custom_ops_lib.py index 6e5aa926d37..682087585ef 100644 --- a/backends/vulkan/custom_ops_lib.py +++ b/backends/vulkan/custom_ops_lib.py @@ -9,6 +9,8 @@ import executorch.backends.vulkan.patterns as vk_patterns import torch.library +from torch._subclasses.fake_tensor import FakeTensor + namespace = "et_vk" lib = torch.library.Library(namespace, "DEF") @@ -614,3 +616,18 @@ def add_q8ta_q8ta_q8to_impl( ) lib.impl(name, add_q8ta_q8ta_q8to_impl, "CompositeExplicitAutograd") add_q8ta_q8ta_q8to_op = getattr(getattr(torch.ops, namespace), name) + +############################# +## select_as_symint ## +############################# + + +def select_as_symint_impl(x: torch.Tensor, dim: int, index: int): + assert isinstance(x, FakeTensor) + return x.fake_mode.shape_env.create_unbacked_symint() + + +name = "select_as_symint" +lib.define(f"{name}(Tensor x, int dim, int index) -> SymInt") +lib.impl(name, select_as_symint_impl, "Meta") +select_as_symint_op = getattr(getattr(torch.ops, namespace), name) diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py index 059b3a07be0..bc3bf14bf14 100644 --- a/backends/vulkan/partitioner/vulkan_partitioner.py +++ b/backends/vulkan/partitioner/vulkan_partitioner.py @@ -184,36 +184,6 @@ def is_linear_permute(self, node: torch.fx.Node) -> Tuple[bool, bool]: return False, False - def is_in_local_scalar_dense_chain(self, node: torch.fx.Node) -> Tuple[bool, bool]: - """ - Scalar tensors are usually converted to scalar values in the graph via` - scalar_tensor[0].item()` in Python, which translates to a chain of - `local_scalar_dense(torch.select.int(scalar_tensor, 0, 0))` in the graph. - This function marks the entire chain as supported by the Vulkan delegate. - - Later, within vulkan_preprocess there will be a graph transform which replaces - the chain with passing in the scalar tensor directly. - - Similar to the `is_linear_permute` function, this function has 2 return values. - """ - if node.target == exir_ops.edge.aten.select_copy.int: - if len(node.users) != 1: - return False, False - # pyre-ignore - if node.args[0].meta["val"].numel() != 1: - return False, False - - local_scalar_dense = list(node.users.keys())[0] - if local_scalar_dense.target != torch.ops.aten._local_scalar_dense.default: - return False, False - - return self.is_in_local_scalar_dense_chain(local_scalar_dense) - - if node.target == torch.ops.aten._local_scalar_dense.default: - return True, all(self.node_is_compatible(user)[0] for user in node.users) - - return False, False - def log_skip(self, node: torch.fx.Node, reason: str) -> None: if node.op == "call_function": logger.info( @@ -261,16 +231,6 @@ def _is_node_supported(self, node: torch.fx.Node) -> bool: # noqa: C901 self.log_skip(node, "permute node of non compatible linear node") return False - ( - is_in_local_scalar_dense_chain, - dst_node_is_compatible, - ) = self.is_in_local_scalar_dense_chain(node) - if is_in_local_scalar_dense_chain and dst_node_is_compatible: - return True - elif is_in_local_scalar_dense_chain: - self.log_skip(node, "local scalar dense of incompatible op node") - return False - features = None if target not in vulkan_supported_ops: # For some ops, i.e. custom ops the name is registered instead of the diff --git a/backends/vulkan/patterns/TARGETS b/backends/vulkan/patterns/TARGETS index 285efe2b933..ddc9cd77c04 100644 --- a/backends/vulkan/patterns/TARGETS +++ b/backends/vulkan/patterns/TARGETS @@ -12,6 +12,7 @@ runtime.python_library( "quantized_linear.py", "quantized_convolution.py", "quantized_binary.py", + "select_as_symint.py", ], visibility = [ "//executorch/backends/...", diff --git a/backends/vulkan/patterns/__init__.py b/backends/vulkan/patterns/__init__.py index e23dfc7629c..9239416dc2d 100644 --- a/backends/vulkan/patterns/__init__.py +++ b/backends/vulkan/patterns/__init__.py @@ -14,6 +14,8 @@ import executorch.backends.vulkan.patterns.rope # noqa +import executorch.backends.vulkan.patterns.select_as_symint # noqa + import torch from executorch.backends.vulkan.patterns.pattern_registry import ( diff --git a/backends/vulkan/patterns/select_as_symint.py b/backends/vulkan/patterns/select_as_symint.py new file mode 100644 index 00000000000..e7226b08188 --- /dev/null +++ b/backends/vulkan/patterns/select_as_symint.py @@ -0,0 +1,104 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional + +import torch + +from executorch.backends.vulkan.patterns.pattern_registry import ( + PatternMatch, + register_pattern_detector, + register_pattern_replacement, +) + +from executorch.exir import ExportedProgram +from executorch.exir.dialects._ops import ops as exir_ops + + +class SelectAsSymIntMatch(PatternMatch): + def __init__(self, local_scalar_dense_node: torch.fx.Node) -> None: + self.anchor_node = local_scalar_dense_node + self.match_found = False + + # Check if the input to local_scalar_dense is a select_copy node + if len(local_scalar_dense_node.args) < 1: + return + + select_node = local_scalar_dense_node.args[0] + if not isinstance(select_node, torch.fx.Node): + return + + if ( + select_node.op != "call_function" + or select_node.target != exir_ops.edge.aten.select_copy.int + ): + return + + # select_copy.int has signature: select_copy(Tensor self, int dim, int index) + if len(select_node.args) < 3: + return + + self.select_node = select_node + + self.tensor_node = select_node.args[0] + self.dim_node = select_node.args[1] + self.index_node = select_node.args[2] + + self.all_nodes = [ + self.anchor_node, + self.select_node, + self.tensor_node, + self.dim_node, + self.index_node, + ] + + self.match_found = True + + +@register_pattern_detector("select_as_symint") +def find_select_as_symint_patterns( + node: torch.fx.Node, +) -> Optional[SelectAsSymIntMatch]: + if node.target != torch.ops.aten._local_scalar_dense.default: + return None + + matched_pattern = SelectAsSymIntMatch(node) + if matched_pattern.match_found: + return matched_pattern + + return None + + +## +## Pattern Replacement +## + + +@register_pattern_replacement("select_as_symint") +def replace_select_local_scalar_dense_with_select_as_symint( + ep: ExportedProgram, + graph_module: torch.fx.GraphModule, + match: SelectAsSymIntMatch, +): + with graph_module.graph.inserting_before(match.anchor_node): + new_node = graph_module.graph.create_node( + "call_function", + exir_ops.edge.et_vk.select_as_symint.default, + args=( + match.tensor_node, + match.dim_node, + match.index_node, + ), + ) + + new_node.meta["val"] = match.anchor_node.meta["val"] + match.anchor_node.replace_all_uses_with(new_node) + + # # Remove both the local_scalar_dense and select_copy nodes + # graph_module.graph.erase_node(match.anchor_node) + # # Only erase select_node if it has no other users + # if len(match.select_node.users) == 0: + # graph_module.graph.erase_node(match.select_node) diff --git a/backends/vulkan/runtime/VulkanBackend.cpp b/backends/vulkan/runtime/VulkanBackend.cpp index fe8cc83c481..cfa1242fbbf 100644 --- a/backends/vulkan/runtime/VulkanBackend.cpp +++ b/backends/vulkan/runtime/VulkanBackend.cpp @@ -649,7 +649,7 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface { } } - if (should_propagate_resize) { + if (should_propagate_resize || compute_graph->has_data_dependent_shapes()) { compute_graph->propagate_resize(); } diff --git a/backends/vulkan/runtime/api/containers/StagingBuffer.h b/backends/vulkan/runtime/api/containers/StagingBuffer.h index 6d0e5a4a457..09788e66b0f 100644 --- a/backends/vulkan/runtime/api/containers/StagingBuffer.h +++ b/backends/vulkan/runtime/api/containers/StagingBuffer.h @@ -112,6 +112,20 @@ class StagingBuffer final { inline void set_staging_zeros() { memset(data(), 0, nbytes()); } + + template + T select_element_at_dim( + const std::vector& sizes, + const int64_t dim, + const int64_t index) { + int64_t stride = 1; + for (size_t i = dim + 1; i < sizes.size(); ++i) { + stride *= sizes[i]; + } + const int64_t offset = index * stride; + const T* typed_data = reinterpret_cast(data()); + return typed_data[offset]; + } }; } // namespace api diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp index 2ec63a89df8..d7f98d3244f 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.cpp +++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp @@ -683,6 +683,17 @@ int32_t ComputeGraph::read_symint(const ValueRef idx) { return get_symint(idx)->get(); } +ValueRef ComputeGraph::staging_of(const ValueRef idx) { + for (size_t i = 0; i < inputs_.size(); ++i) { + if (inputs_[i].value == idx) { + if (is_valid(inputs_[i].staging)) { + return inputs_[i].staging; + } + } + } + VK_THROW("Could not find staging buffer for value at index ", idx); +} + SharedObject& ComputeGraph::get_shared_object(const int64_t idx) { if (idx >= shared_objects_.size()) { shared_objects_.resize(static_cast(idx + 1)); diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index dbd5536279c..f7de7e183de 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -824,6 +824,8 @@ class ComputeGraph final { inputs_.push_back({idx, kDummyValueRef}); } + ValueRef staging_of(const ValueRef idx); + inline void set_val_as_output(const ValueRef idx) { outputs_.push_back({idx, kDummyValueRef}); } @@ -1081,6 +1083,14 @@ class ComputeGraph final { return can_use_int8_dot_product_; } + inline void set_has_data_dependent_shapes() { + config_.has_data_dependent_shapes = true; + } + + inline bool has_data_dependent_shapes() const { + return config_.has_data_dependent_shapes; + } + /* * Check whether the GPU supports 8 bit buffers. */ diff --git a/backends/vulkan/runtime/graph/GraphConfig.cpp b/backends/vulkan/runtime/graph/GraphConfig.cpp index 20b8f6f7c00..9a919a42573 100644 --- a/backends/vulkan/runtime/graph/GraphConfig.cpp +++ b/backends/vulkan/runtime/graph/GraphConfig.cpp @@ -64,6 +64,7 @@ GraphConfig::GraphConfig() { enable_local_wg_size_override = false; local_wg_size_override = {}; + has_data_dependent_shapes = false; expect_dynamic_shapes = false; force_resize = false; diff --git a/backends/vulkan/runtime/graph/GraphConfig.h b/backends/vulkan/runtime/graph/GraphConfig.h index 7533df3b685..9a753775650 100644 --- a/backends/vulkan/runtime/graph/GraphConfig.h +++ b/backends/vulkan/runtime/graph/GraphConfig.h @@ -33,8 +33,11 @@ struct GraphConfig final { bool enable_local_wg_size_override; utils::uvec3 local_wg_size_override; + // If true, then resize functions should always be called even if input shapes + // have not changed. + bool has_data_dependent_shapes = false; // Whether or not the ComputeGraph should expect input shapes to be dynamic - bool expect_dynamic_shapes; + bool expect_dynamic_shapes = false; // Used for testing/debugging only. Forces ExecuteNode to trigger the resize // function even if none of the inputs have been updated. bool force_resize = false; diff --git a/backends/vulkan/runtime/graph/ops/ExecuteNode.cpp b/backends/vulkan/runtime/graph/ops/ExecuteNode.cpp index aa46ee76336..40cc67517ea 100644 --- a/backends/vulkan/runtime/graph/ops/ExecuteNode.cpp +++ b/backends/vulkan/runtime/graph/ops/ExecuteNode.cpp @@ -14,15 +14,19 @@ ExecuteNode::ExecuteNode( const ResizeFunction& resize_fn, const std::vector& resize_args, const std::vector& args, - const std::string& name) + const std::string& name, + const bool has_data_dependent_shape) : resize_fn_(resize_fn), resize_args_(resize_args), args_(args), - name_(name) {} + name_(name), + has_data_dependent_shape_(has_data_dependent_shape) {} bool ExecuteNode::trigger_resize(ComputeGraph* graph) { bool any_arg_updated = was_any_arg_updated(graph); - if (resize_fn_ && (any_arg_updated || graph->graphconfig().force_resize)) { + if (resize_fn_ && + (any_arg_updated || graph->graphconfig().force_resize || + has_data_dependent_shape_)) { resize_fn_(graph, args_, resize_args_); any_arg_updated = true; } diff --git a/backends/vulkan/runtime/graph/ops/ExecuteNode.h b/backends/vulkan/runtime/graph/ops/ExecuteNode.h index 323036cef90..4dbad882dea 100644 --- a/backends/vulkan/runtime/graph/ops/ExecuteNode.h +++ b/backends/vulkan/runtime/graph/ops/ExecuteNode.h @@ -57,7 +57,8 @@ class ExecuteNode { const ResizeFunction& resize_fn = nullptr, const std::vector& resize_args = {}, const std::vector& args = {}, - const std::string& name = "Graph Node"); + const std::string& name = "Graph Node", + const bool has_data_dependent_shape = false); virtual ~ExecuteNode() = default; @@ -87,6 +88,7 @@ class ExecuteNode { const std::vector resize_args_; const std::vector args_; const std::string name_; + bool has_data_dependent_shape_ = false; }; } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/SymIntOps.cpp b/backends/vulkan/runtime/graph/ops/impl/SymIntOps.cpp index f07522d2578..eb03639abf1 100644 --- a/backends/vulkan/runtime/graph/ops/impl/SymIntOps.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/SymIntOps.cpp @@ -81,9 +81,58 @@ void sym_add(ComputeGraph& graph, const std::vector& args) { new ExecuteNode(resize_sym_add_node, args)); } +void select_as_symint_impl( + ComputeGraph* graph, + const std::vector& unused, + const std::vector& args) { + (void)unused; // Unused parameter + + const ValueRef x = args.at(0); + const ValueRef dim = args.at(1); + const ValueRef index = args.at(2); + const ValueRef out = args.at(3); + + const int64_t dim_val = graph->extract_scalar(dim); + int64_t index_val = graph->extract_scalar(index); + + const std::vector x_sizes = graph->sizes_of(x); + const vkapi::ScalarType x_dtype = graph->dtype_of(x); + + if (index_val < 0) { + index_val += x_sizes[dim_val]; + } + + const StagingPtr x_staging = graph->get_staging(graph->staging_of(x)); + + int32_t x_val; + switch (x_dtype) { + case vkapi::ScalarType::Int: + x_val = x_staging->select_element_at_dim( + x_sizes, dim_val, index_val); + break; + case vkapi::ScalarType::Long: + x_val = static_cast(x_staging->select_element_at_dim( + x_sizes, dim_val, index_val)); + break; + default: + VK_THROW("Unsupported dtype for select_as_symint"); + } + + graph->set_symint(out, x_val); +} + +void select_as_symint(ComputeGraph& graph, const std::vector& args) { + select_as_symint_impl(&graph, {}, args); + + graph.execute_nodes().emplace_back(new ExecuteNode( + select_as_symint_impl, args, {}, "select_as_symint", true)); + graph.set_has_data_dependent_shapes(); +} + REGISTER_OPERATORS { VK_REGISTER_OP(sym_size.int, sym_size_int); VK_REGISTER_OP(add, sym_add); + VK_REGISTER_OP(et_vk.select_as_symint.default, select_as_symint); } } // namespace vkcompute diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py index 876f7fa8900..57863703498 100644 --- a/backends/vulkan/vulkan_preprocess.py +++ b/backends/vulkan/vulkan_preprocess.py @@ -22,7 +22,6 @@ FoldQDQPass, FuseQuantizedOpsTransform, insert_prepack_nodes, - RemoveLocalScalarDenseOpsTransform, RemoveRedundantOpsTransform, ReplaceQDQPass, SqueezeUnsqueezeInputs, @@ -193,9 +192,6 @@ def preprocess( # noqa: C901 program, [ RemoveAssertsTransform(), - # Since this pass may replace a scalar argument with a tensor argument, - # this pass may result in a non ATen compliant graph structure. - RemoveLocalScalarDenseOpsTransform(), insert_prepack_nodes, ], )