diff --git a/.lintrunner.toml b/.lintrunner.toml
index b366c141799..d0c9c6aef6a 100644
--- a/.lintrunner.toml
+++ b/.lintrunner.toml
@@ -449,3 +449,24 @@ command = [
     "--",
     "@{{PATHSFILE}}",
 ]
+
+[[linter]]
+code = 'ETVKNODEBUG'
+include_patterns = [
+    "backends/vulkan/**/*.glsl",
+]
+command = [
+    'python',
+    '-m',
+    'lintrunner_adapters',
+    'run',
+    'grep_linter',
+    '--pattern=((DEBUG_MODE)|(GL_EXT_debug_printf))',
+    '--linter-name=ETVKNODEBUG',
+    '--error-name=Using DEBUG_MODE or GL_EXT_debug_printf in Vulkan shader',
+    """--error-description=\
+        #define DEBUG_MODE or #extension GL_EXT_debug_printf should only be used during development!
+    """,
+    '--',
+    '@{{PATHSFILE}}',
+]
diff --git a/backends/vulkan/_passes/TARGETS b/backends/vulkan/_passes/TARGETS
index ae1a0b79654..903a4a92b8e 100644
--- a/backends/vulkan/_passes/TARGETS
+++ b/backends/vulkan/_passes/TARGETS
@@ -63,19 +63,6 @@ runtime.python_library(
     ],
 )
 
-runtime.python_library(
-    name = "remove_local_scalar_dense",
-    srcs = ["remove_local_scalar_dense_ops.py"],
-    visibility = [
-        "//executorch/backends/...",
-    ],
-    deps = [
-        "//caffe2:torch",
-        "//executorch/exir:pass_base",
-        "//executorch/exir/dialects:lib",
-    ],
-)
-
 runtime.python_library(
     name = "remove_redundant_ops",
     srcs = ["remove_redundant_ops.py"],
@@ -161,7 +148,6 @@ runtime.python_library(
         ":fuse_quantized_ops",
         ":insert_prepack_nodes",
         ":remove_asserts",
-        ":remove_local_scalar_dense",
        ":remove_redundant_ops",
        ":replace_qdq",
        ":squeeze_unsqueeze_inputs",
diff --git a/backends/vulkan/_passes/__init__.py b/backends/vulkan/_passes/__init__.py
index 169bd60543c..8d305ababe4 100644
--- a/backends/vulkan/_passes/__init__.py
+++ b/backends/vulkan/_passes/__init__.py
@@ -16,9 +16,6 @@
     remove_asserts,
     RemoveAssertsTransform,
 )
-from executorch.backends.vulkan._passes.remove_local_scalar_dense_ops import (
-    RemoveLocalScalarDenseOpsTransform,
-)
 from executorch.backends.vulkan._passes.remove_redundant_ops import (
     RemoveRedundantOpsTransform,
 )
@@ -35,7 +32,6 @@
     "insert_prepack_nodes",
     "remove_asserts",
     "RemoveAssertsTransform",
-    "RemoveLocalScalarDenseOpsTransform",
     "RemoveRedundantOpsTransform",
     "ReplaceQDQPass",
     "SqueezeUnsqueezeInputs",
diff --git a/backends/vulkan/_passes/remove_local_scalar_dense_ops.py b/backends/vulkan/_passes/remove_local_scalar_dense_ops.py
deleted file mode 100644
index 6ce3572ec0c..00000000000
--- a/backends/vulkan/_passes/remove_local_scalar_dense_ops.py
+++ /dev/null
@@ -1,110 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-# pyre-strict
-
-import torch
-from executorch.exir.dialects._ops import ops as exir_ops
-from executorch.exir.pass_base import ExportPass, PassResult
-
-from torch._subclasses.fake_tensor import FakeTensor
-
-
-def node_is_local_scalar_dense_chain(node: torch.fx.Node) -> bool:
-    """
-    Converting a tensor to a scalar via tensor[0].item() creates a index_select +
-    local_scalar_dense pattern in the graph. Check if a node is the start of this pattern.
- """ - if ( - node.op == "call_function" - and node.target == exir_ops.edge.aten.select_copy.int - and len(node.users) == 1 - ): - user = list(node.users.keys())[0] - return user.target == torch.ops.aten._local_scalar_dense.default - - return False - - -def tag_node_if_scalar_tensor(node: torch.fx.Node) -> None: - """ - A scalar tensor in the Vulkan backend is a tensor that can be represented as a scalar - value instead of a Tensor object. The criteria for identifying a tensor as a scalar - tensor are as follows: - - 1. The tensor has only 1 element - 2. One of the node's uses is converting it to a scalar via `tensor[0].item()`, which - creates a index_select + local_scalar_dense pattern in the graph - - If any of these criteria are fulfilled, then tag the node for the tensor to mark it - so that it is added as a scalar value during serialization. - """ - tensor_val = node.meta["val"] - if not isinstance(tensor_val, FakeTensor): - return - - # Scalar tensors must have only one element - if tensor_val.numel() != 1: - return - - for user in node.users: - if node_is_local_scalar_dense_chain(user): - node.meta["etvk_is_scalar_tensor"] = True - - -def remove_local_scalar_dense_chain(graph: torch.fx.Graph, node: torch.fx.Node) -> None: - """ - Remove the index_select + local_scalar_dense pattern in the graph in favor of passing - the original scalar tensor directly. - """ - replace_node = node.args[0] - assert isinstance(replace_node, torch.fx.Node) - # If the argument to the local_scalar_dense op is a select op with only - # one user, and the argument to the select op is a tensor with only one - # element (i.e. a scalar tensor), then replace the entire pattern with the - # scalar tensor. - if ( - replace_node.op == "call_function" - and replace_node.target == exir_ops.edge.aten.select_copy.int - ): - # pyre-ignore - if replace_node.args[0].meta["val"].numel() == 1: - replace_node = replace_node.args[0] - assert isinstance(replace_node, torch.fx.Node) - assert replace_node.meta.get("etvk_is_scalar_tensor", True) - - with graph.inserting_after(node): - node.replace_all_uses_with(replace_node) - - -def remove_local_scalar_dense_ops(graph: torch.fx.Graph) -> torch.fx.Graph: - """ - The purpose of this pass is twofold: - 1. Tag scalar tensors (see `tag_node_if_scalar_tensor()` for the criteria) - 2. Remove the index_select + local_scalar_dense pattern in the graph in favor of - passing the original scalar tensor directly (see `remove_local_scalar_dense_chain()`) - - This makes it easier to deal with scalar tensors in the Vulkan backend. In particular, - it allows serializing scalar tensors as SymInt objects instead of Tensor objects. - Because scalar tensors are often used to inform tensor shapes, their values need to - be easily accessed by the CPU during resizing logic, while also being able to reflect - updates to their value in any GPU shaders that reference them. 
- """ - target_op = torch.ops.aten._local_scalar_dense.default - for node in graph.nodes: - tag_node_if_scalar_tensor(node) - - if node.op == "call_function" and node.target == target_op: - remove_local_scalar_dense_chain(graph, node) - - graph.eliminate_dead_code() - return graph - - -class RemoveLocalScalarDenseOpsTransform(ExportPass): - def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - graph_module.graph = remove_local_scalar_dense_ops(graph_module.graph) - return PassResult(graph_module, True) diff --git a/backends/vulkan/custom_ops_lib.py b/backends/vulkan/custom_ops_lib.py index 6e5aa926d37..682087585ef 100644 --- a/backends/vulkan/custom_ops_lib.py +++ b/backends/vulkan/custom_ops_lib.py @@ -9,6 +9,8 @@ import executorch.backends.vulkan.patterns as vk_patterns import torch.library +from torch._subclasses.fake_tensor import FakeTensor + namespace = "et_vk" lib = torch.library.Library(namespace, "DEF") @@ -614,3 +616,18 @@ def add_q8ta_q8ta_q8to_impl( ) lib.impl(name, add_q8ta_q8ta_q8to_impl, "CompositeExplicitAutograd") add_q8ta_q8ta_q8to_op = getattr(getattr(torch.ops, namespace), name) + +############################# +## select_as_symint ## +############################# + + +def select_as_symint_impl(x: torch.Tensor, dim: int, index: int): + assert isinstance(x, FakeTensor) + return x.fake_mode.shape_env.create_unbacked_symint() + + +name = "select_as_symint" +lib.define(f"{name}(Tensor x, int dim, int index) -> SymInt") +lib.impl(name, select_as_symint_impl, "Meta") +select_as_symint_op = getattr(getattr(torch.ops, namespace), name) diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py index 059b3a07be0..bc3bf14bf14 100644 --- a/backends/vulkan/partitioner/vulkan_partitioner.py +++ b/backends/vulkan/partitioner/vulkan_partitioner.py @@ -184,36 +184,6 @@ def is_linear_permute(self, node: torch.fx.Node) -> Tuple[bool, bool]: return False, False - def is_in_local_scalar_dense_chain(self, node: torch.fx.Node) -> Tuple[bool, bool]: - """ - Scalar tensors are usually converted to scalar values in the graph via` - scalar_tensor[0].item()` in Python, which translates to a chain of - `local_scalar_dense(torch.select.int(scalar_tensor, 0, 0))` in the graph. - This function marks the entire chain as supported by the Vulkan delegate. - - Later, within vulkan_preprocess there will be a graph transform which replaces - the chain with passing in the scalar tensor directly. - - Similar to the `is_linear_permute` function, this function has 2 return values. 
- """ - if node.target == exir_ops.edge.aten.select_copy.int: - if len(node.users) != 1: - return False, False - # pyre-ignore - if node.args[0].meta["val"].numel() != 1: - return False, False - - local_scalar_dense = list(node.users.keys())[0] - if local_scalar_dense.target != torch.ops.aten._local_scalar_dense.default: - return False, False - - return self.is_in_local_scalar_dense_chain(local_scalar_dense) - - if node.target == torch.ops.aten._local_scalar_dense.default: - return True, all(self.node_is_compatible(user)[0] for user in node.users) - - return False, False - def log_skip(self, node: torch.fx.Node, reason: str) -> None: if node.op == "call_function": logger.info( @@ -261,16 +231,6 @@ def _is_node_supported(self, node: torch.fx.Node) -> bool: # noqa: C901 self.log_skip(node, "permute node of non compatible linear node") return False - ( - is_in_local_scalar_dense_chain, - dst_node_is_compatible, - ) = self.is_in_local_scalar_dense_chain(node) - if is_in_local_scalar_dense_chain and dst_node_is_compatible: - return True - elif is_in_local_scalar_dense_chain: - self.log_skip(node, "local scalar dense of incompatible op node") - return False - features = None if target not in vulkan_supported_ops: # For some ops, i.e. custom ops the name is registered instead of the diff --git a/backends/vulkan/patterns/TARGETS b/backends/vulkan/patterns/TARGETS index 285efe2b933..ddc9cd77c04 100644 --- a/backends/vulkan/patterns/TARGETS +++ b/backends/vulkan/patterns/TARGETS @@ -12,6 +12,7 @@ runtime.python_library( "quantized_linear.py", "quantized_convolution.py", "quantized_binary.py", + "select_as_symint.py", ], visibility = [ "//executorch/backends/...", diff --git a/backends/vulkan/patterns/__init__.py b/backends/vulkan/patterns/__init__.py index e23dfc7629c..9239416dc2d 100644 --- a/backends/vulkan/patterns/__init__.py +++ b/backends/vulkan/patterns/__init__.py @@ -14,6 +14,8 @@ import executorch.backends.vulkan.patterns.rope # noqa +import executorch.backends.vulkan.patterns.select_as_symint # noqa + import torch from executorch.backends.vulkan.patterns.pattern_registry import ( diff --git a/backends/vulkan/patterns/select_as_symint.py b/backends/vulkan/patterns/select_as_symint.py new file mode 100644 index 00000000000..e7226b08188 --- /dev/null +++ b/backends/vulkan/patterns/select_as_symint.py @@ -0,0 +1,104 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+
+from typing import Optional
+
+import torch
+
+from executorch.backends.vulkan.patterns.pattern_registry import (
+    PatternMatch,
+    register_pattern_detector,
+    register_pattern_replacement,
+)
+
+from executorch.exir import ExportedProgram
+from executorch.exir.dialects._ops import ops as exir_ops
+
+
+class SelectAsSymIntMatch(PatternMatch):
+    def __init__(self, local_scalar_dense_node: torch.fx.Node) -> None:
+        self.anchor_node = local_scalar_dense_node
+        self.match_found = False
+
+        # Check if the input to local_scalar_dense is a select_copy node
+        if len(local_scalar_dense_node.args) < 1:
+            return
+
+        select_node = local_scalar_dense_node.args[0]
+        if not isinstance(select_node, torch.fx.Node):
+            return
+
+        if (
+            select_node.op != "call_function"
+            or select_node.target != exir_ops.edge.aten.select_copy.int
+        ):
+            return
+
+        # select_copy.int has signature: select_copy(Tensor self, int dim, int index)
+        if len(select_node.args) < 3:
+            return
+
+        self.select_node = select_node
+
+        self.tensor_node = select_node.args[0]
+        self.dim_node = select_node.args[1]
+        self.index_node = select_node.args[2]
+
+        self.all_nodes = [
+            self.anchor_node,
+            self.select_node,
+            self.tensor_node,
+            self.dim_node,
+            self.index_node,
+        ]
+
+        self.match_found = True
+
+
+@register_pattern_detector("select_as_symint")
+def find_select_as_symint_patterns(
+    node: torch.fx.Node,
+) -> Optional[SelectAsSymIntMatch]:
+    if node.target != torch.ops.aten._local_scalar_dense.default:
+        return None
+
+    matched_pattern = SelectAsSymIntMatch(node)
+    if matched_pattern.match_found:
+        return matched_pattern
+
+    return None
+
+
+##
+## Pattern Replacement
+##
+
+
+@register_pattern_replacement("select_as_symint")
+def replace_select_local_scalar_dense_with_select_as_symint(
+    ep: ExportedProgram,
+    graph_module: torch.fx.GraphModule,
+    match: SelectAsSymIntMatch,
+):
+    with graph_module.graph.inserting_before(match.anchor_node):
+        new_node = graph_module.graph.create_node(
+            "call_function",
+            exir_ops.edge.et_vk.select_as_symint.default,
+            args=(
+                match.tensor_node,
+                match.dim_node,
+                match.index_node,
+            ),
+        )
+
+        new_node.meta["val"] = match.anchor_node.meta["val"]
+        match.anchor_node.replace_all_uses_with(new_node)
+
+    # # Remove both the local_scalar_dense and select_copy nodes
+    # graph_module.graph.erase_node(match.anchor_node)
+    # # Only erase select_node if it has no other users
+    # if len(match.select_node.users) == 0:
+    #     graph_module.graph.erase_node(match.select_node)
diff --git a/backends/vulkan/runtime/VulkanBackend.cpp b/backends/vulkan/runtime/VulkanBackend.cpp
index fe8cc83c481..cfa1242fbbf 100644
--- a/backends/vulkan/runtime/VulkanBackend.cpp
+++ b/backends/vulkan/runtime/VulkanBackend.cpp
@@ -649,7 +649,7 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
       }
     }
 
-    if (should_propagate_resize) {
+    if (should_propagate_resize || compute_graph->has_data_dependent_shapes()) {
       compute_graph->propagate_resize();
     }
 
diff --git a/backends/vulkan/runtime/api/containers/StagingBuffer.h b/backends/vulkan/runtime/api/containers/StagingBuffer.h
index 6d0e5a4a457..09788e66b0f 100644
--- a/backends/vulkan/runtime/api/containers/StagingBuffer.h
+++ b/backends/vulkan/runtime/api/containers/StagingBuffer.h
@@ -112,6 +112,20 @@ class StagingBuffer final {
   inline void set_staging_zeros() {
     memset(data(), 0, nbytes());
   }
+
+  template <typename T>
+  T select_element_at_dim(
+      const std::vector<int64_t>& sizes,
+      const int64_t dim,
+      const int64_t index) {
+    int64_t stride = 1;
+    for (size_t i = dim + 1; i < sizes.size(); ++i) {
+      stride *= sizes[i];
+    }
+    const int64_t offset = index * stride;
+    const T* typed_data = reinterpret_cast<const T*>(data());
+    return typed_data[offset];
+  }
 };
 
 } // namespace api
diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp
index 2ec63a89df8..d7f98d3244f 100644
--- a/backends/vulkan/runtime/graph/ComputeGraph.cpp
+++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp
@@ -683,6 +683,17 @@ int32_t ComputeGraph::read_symint(const ValueRef idx) {
   return get_symint(idx)->get();
 }
 
+ValueRef ComputeGraph::staging_of(const ValueRef idx) {
+  for (size_t i = 0; i < inputs_.size(); ++i) {
+    if (inputs_[i].value == idx) {
+      if (is_valid(inputs_[i].staging)) {
+        return inputs_[i].staging;
+      }
+    }
+  }
+  VK_THROW("Could not find staging buffer for value at index ", idx);
+}
+
 SharedObject& ComputeGraph::get_shared_object(const int64_t idx) {
   if (idx >= shared_objects_.size()) {
     shared_objects_.resize(static_cast<size_t>(idx + 1));
diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h
index dbd5536279c..f7de7e183de 100644
--- a/backends/vulkan/runtime/graph/ComputeGraph.h
+++ b/backends/vulkan/runtime/graph/ComputeGraph.h
@@ -824,6 +824,8 @@ class ComputeGraph final {
     inputs_.push_back({idx, kDummyValueRef});
   }
 
+  ValueRef staging_of(const ValueRef idx);
+
   inline void set_val_as_output(const ValueRef idx) {
     outputs_.push_back({idx, kDummyValueRef});
   }
@@ -1081,6 +1083,14 @@ class ComputeGraph final {
     return can_use_int8_dot_product_;
   }
 
+  inline void set_has_data_dependent_shapes() {
+    config_.has_data_dependent_shapes = true;
+  }
+
+  inline bool has_data_dependent_shapes() const {
+    return config_.has_data_dependent_shapes;
+  }
+
   /*
    * Check whether the GPU supports 8 bit buffers.
    */
diff --git a/backends/vulkan/runtime/graph/GraphConfig.cpp b/backends/vulkan/runtime/graph/GraphConfig.cpp
index 20b8f6f7c00..9a919a42573 100644
--- a/backends/vulkan/runtime/graph/GraphConfig.cpp
+++ b/backends/vulkan/runtime/graph/GraphConfig.cpp
@@ -64,6 +64,7 @@ GraphConfig::GraphConfig() {
   enable_local_wg_size_override = false;
   local_wg_size_override = {};
 
+  has_data_dependent_shapes = false;
   expect_dynamic_shapes = false;
   force_resize = false;
 
diff --git a/backends/vulkan/runtime/graph/GraphConfig.h b/backends/vulkan/runtime/graph/GraphConfig.h
index 7533df3b685..9a753775650 100644
--- a/backends/vulkan/runtime/graph/GraphConfig.h
+++ b/backends/vulkan/runtime/graph/GraphConfig.h
@@ -33,8 +33,11 @@ struct GraphConfig final {
   bool enable_local_wg_size_override;
   utils::uvec3 local_wg_size_override;
 
+  // If true, then resize functions should always be called even if input shapes
+  // have not changed.
+  bool has_data_dependent_shapes = false;
   // Whether or not the ComputeGraph should expect input shapes to be dynamic
-  bool expect_dynamic_shapes;
+  bool expect_dynamic_shapes = false;
   // Used for testing/debugging only. Forces ExecuteNode to trigger the resize
   // function even if none of the inputs have been updated.
   bool force_resize = false;
diff --git a/backends/vulkan/runtime/graph/ops/ExecuteNode.cpp b/backends/vulkan/runtime/graph/ops/ExecuteNode.cpp
index aa46ee76336..40cc67517ea 100644
--- a/backends/vulkan/runtime/graph/ops/ExecuteNode.cpp
+++ b/backends/vulkan/runtime/graph/ops/ExecuteNode.cpp
@@ -14,15 +14,19 @@ ExecuteNode::ExecuteNode(
     const ResizeFunction& resize_fn,
     const std::vector<ValueRef>& resize_args,
     const std::vector<ArgGroup>& args,
-    const std::string& name)
+    const std::string& name,
+    const bool has_data_dependent_shape)
     : resize_fn_(resize_fn),
       resize_args_(resize_args),
       args_(args),
-      name_(name) {}
+      name_(name),
+      has_data_dependent_shape_(has_data_dependent_shape) {}
 
 bool ExecuteNode::trigger_resize(ComputeGraph* graph) {
   bool any_arg_updated = was_any_arg_updated(graph);
-  if (resize_fn_ && (any_arg_updated || graph->graphconfig().force_resize)) {
+  if (resize_fn_ &&
+      (any_arg_updated || graph->graphconfig().force_resize ||
+       has_data_dependent_shape_)) {
     resize_fn_(graph, args_, resize_args_);
     any_arg_updated = true;
   }
diff --git a/backends/vulkan/runtime/graph/ops/ExecuteNode.h b/backends/vulkan/runtime/graph/ops/ExecuteNode.h
index 323036cef90..4dbad882dea 100644
--- a/backends/vulkan/runtime/graph/ops/ExecuteNode.h
+++ b/backends/vulkan/runtime/graph/ops/ExecuteNode.h
@@ -57,7 +57,8 @@ class ExecuteNode {
       const ResizeFunction& resize_fn = nullptr,
       const std::vector<ValueRef>& resize_args = {},
       const std::vector<ArgGroup>& args = {},
-      const std::string& name = "Graph Node");
+      const std::string& name = "Graph Node",
+      const bool has_data_dependent_shape = false);
 
   virtual ~ExecuteNode() = default;
 
@@ -87,6 +88,7 @@ class ExecuteNode {
   const std::vector<ValueRef> resize_args_;
   const std::vector<ArgGroup> args_;
   const std::string name_;
+  bool has_data_dependent_shape_ = false;
 };
 
 } // namespace vkcompute
diff --git a/backends/vulkan/runtime/graph/ops/glsl/binary_q8ta_q8ta_q8to.glsl b/backends/vulkan/runtime/graph/ops/glsl/binary_q8ta_q8ta_q8to.glsl
index 8b69642d2e9..d0bd1809d11 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/binary_q8ta_q8ta_q8to.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/binary_q8ta_q8ta_q8to.glsl
@@ -25,8 +25,6 @@ ${define_required_extensions(DTYPE)}
 
 layout(std430) buffer;
 
-#extension GL_EXT_debug_printf : enable
-#define DEBUG_MODE
 #include "indexing.glslh"
 #include "common.glslh"
 
diff --git a/backends/vulkan/runtime/graph/ops/glsl/binary_scalar_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/binary_scalar_texture.glsl
index 971f66f93e5..4f51e9ff679 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/binary_scalar_texture.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/binary_scalar_texture.glsl
@@ -22,7 +22,6 @@ ${define_required_extensions(DTYPE)}
 
 layout(std430) buffer;
 
-#define DEBUG_MODE
 #include "indexing.glslh"
 
 ${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
diff --git a/backends/vulkan/runtime/graph/ops/glsl/common.glslh b/backends/vulkan/runtime/graph/ops/glsl/common.glslh
index 62c0922e3e3..8340a8b9b2f 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/common.glslh
+++ b/backends/vulkan/runtime/graph/ops/glsl/common.glslh
@@ -9,6 +9,10 @@
 #ifndef COMMON_GLSLH
 #define COMMON_GLSLH
 
+#ifdef DEBUG_MODE
+#extension GL_EXT_debug_printf : enable
+#endif
+
 #define mul_2(x) ((x) << 1)
 #define mul_4(x) ((x) << 2)
 #define mul_8(x) ((x) << 3)
diff --git a/backends/vulkan/runtime/graph/ops/glsl/embedding_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/embedding_buffer.glsl
index 8b519a67eb6..c1a21e44c60 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/embedding_buffer.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/embedding_buffer.glsl
@@ -19,7 +19,6 @@ ${define_required_extensions(DTYPE)}
 
 layout(std430) buffer;
 
-#define DEBUG_MODE
 #include "indexing.glslh"
 
 ${layout_declare_tensor(B, "w", "t_out", DTYPE, "buffer")}
diff --git a/backends/vulkan/runtime/graph/ops/glsl/embedding_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/embedding_texture.glsl
index ecfc10415a1..b064d8a3295 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/embedding_texture.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/embedding_texture.glsl
@@ -20,7 +20,6 @@ ${define_required_extensions(DTYPE)}
 
 layout(std430) buffer;
 
-#define DEBUG_MODE
 #include "common.glslh"
 #include "indexing.glslh"
 
diff --git a/backends/vulkan/runtime/graph/ops/glsl/indexing.glslh b/backends/vulkan/runtime/graph/ops/glsl/indexing.glslh
index c4feb17ef2e..0e30faa5d66 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/indexing.glslh
+++ b/backends/vulkan/runtime/graph/ops/glsl/indexing.glslh
@@ -278,8 +278,6 @@ TextureElementIndex tensor_idx_to_texture_element_idx_simple(
 
 #ifdef DEBUG_MODE
 
-#extension GL_EXT_debug_printf : enable
-
 void printTensorIndex(const TensorIndex tidx) {
   debugPrintfEXT(
       "TensorIndex: tidx=[%u %u %u %u %u %u %u %u]\\n",
diff --git a/backends/vulkan/runtime/graph/ops/impl/SymIntOps.cpp b/backends/vulkan/runtime/graph/ops/impl/SymIntOps.cpp
index f07522d2578..eb03639abf1 100644
--- a/backends/vulkan/runtime/graph/ops/impl/SymIntOps.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/SymIntOps.cpp
@@ -81,9 +81,58 @@ void sym_add(ComputeGraph& graph, const std::vector<ValueRef>& args) {
       new ExecuteNode(resize_sym_add_node, args));
 }
 
+void select_as_symint_impl(
+    ComputeGraph* graph,
+    const std::vector<ArgGroup>& unused,
+    const std::vector<ValueRef>& args) {
+  (void)unused; // Unused parameter
+
+  const ValueRef x = args.at(0);
+  const ValueRef dim = args.at(1);
+  const ValueRef index = args.at(2);
+  const ValueRef out = args.at(3);
+
+  const int64_t dim_val = graph->extract_scalar<int64_t>(dim);
+  int64_t index_val = graph->extract_scalar<int64_t>(index);
+
+  const std::vector<int64_t> x_sizes = graph->sizes_of(x);
+  const vkapi::ScalarType x_dtype = graph->dtype_of(x);
+
+  if (index_val < 0) {
+    index_val += x_sizes[dim_val];
+  }
+
+  const StagingPtr x_staging = graph->get_staging(graph->staging_of(x));
+
+  int32_t x_val;
+  switch (x_dtype) {
+    case vkapi::ScalarType::Int:
+      x_val = x_staging->select_element_at_dim<int32_t>(
+          x_sizes, dim_val, index_val);
+      break;
+    case vkapi::ScalarType::Long:
+      x_val = static_cast<int32_t>(x_staging->select_element_at_dim<int64_t>(
+          x_sizes, dim_val, index_val));
+      break;
+    default:
+      VK_THROW("Unsupported dtype for select_as_symint");
+  }
+
+  graph->set_symint(out, x_val);
+}
+
+void select_as_symint(ComputeGraph& graph, const std::vector<ValueRef>& args) {
+  select_as_symint_impl(&graph, {}, args);
+
+  graph.execute_nodes().emplace_back(new ExecuteNode(
+      select_as_symint_impl, args, {}, "select_as_symint", true));
+  graph.set_has_data_dependent_shapes();
+}
+
 REGISTER_OPERATORS {
   VK_REGISTER_OP(sym_size.int, sym_size_int);
   VK_REGISTER_OP(add, sym_add);
+  VK_REGISTER_OP(et_vk.select_as_symint.default, select_as_symint);
 }
 
 } // namespace vkcompute
diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py
index 876f7fa8900..57863703498 100644
--- a/backends/vulkan/vulkan_preprocess.py
+++ b/backends/vulkan/vulkan_preprocess.py
@@ -22,7 +22,6 @@
     FoldQDQPass,
     FuseQuantizedOpsTransform,
     insert_prepack_nodes,
-    RemoveLocalScalarDenseOpsTransform,
     RemoveRedundantOpsTransform,
     ReplaceQDQPass,
     SqueezeUnsqueezeInputs,
@@ -193,9 +192,6 @@ def preprocess(  # noqa: C901
             program,
             [
                 RemoveAssertsTransform(),
-                # Since this pass may replace a scalar argument with a tensor argument,
-                # this pass may result in a non ATen compliant graph structure.
-                RemoveLocalScalarDenseOpsTransform(),
                 insert_prepack_nodes,
             ],
         )