diff --git a/backends/vulkan/_passes/tag_memory_meta_pass.py b/backends/vulkan/_passes/tag_memory_meta_pass.py
index 8ed71aa1dae..15449b98f6f 100644
--- a/backends/vulkan/_passes/tag_memory_meta_pass.py
+++ b/backends/vulkan/_passes/tag_memory_meta_pass.py
@@ -6,22 +6,16 @@
 import logging
 import operator
-
 from typing import Any
 
 import executorch.backends.vulkan.utils as utils
-
 import torch
-
 from executorch.backends.vulkan.op_registry import get_op_features, has_impl, OpFeatures
-
 from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
     VkMemoryLayout,
     VkStorageType,
 )
-
 from executorch.exir.dialects._ops import ops as exir_ops
-
 from executorch.exir.pass_base import ExportPass, PassResult
 from executorch.exir.tensor import TensorSpec
@@ -130,15 +124,18 @@ def __init__(
         texture_limits: utils.ImageExtents,
         default_storage_type: VkStorageType = VkStorageType.TEXTURE_3D,
         default_memory_layout: VkMemoryLayout = VkMemoryLayout.TENSOR_WIDTH_PACKED,
+        force_fp16: bool = False,
     ):
         super().__init__()
         self.default_storage: VkStorageType = default_storage_type
         self.default_layout: VkMemoryLayout = default_memory_layout
         self.texture_limits = texture_limits
+        self.force_fp16 = force_fp16
 
-        # Magic number to limit "lookahead" when tracing through users of an operator
-        # to constrain the representation of its arguments/outputs.
-        self.max_trace_search_depth = 20
+        # Optional limit on "lookahead" when tracing through users of an operator
+        # to constrain the representation of its arguments/outputs. None means the
+        # search depth is unbounded.
+        self.max_trace_search_depth = None
 
     def is_valid_op_node(self, node: Any) -> bool:
         """
@@ -361,6 +357,12 @@ def constrain_op_arg_repset(self, arg_i: int, op_repsets: utils.OpRepSets) -> No
         2. Then, try to trace through the users of the argument to find a
            representation that can be used for as long as possible without needing
            a transition.
         """
+        # If forcing fp16, then try to use texture storage whenever possible. This
+        # is a stopgap until all buffer implementations properly account for
+        # potential overflow of the fp16 representation range when doing fp16 math.
+        if self.force_fp16:
+            op_repsets.try_constrain_with_arg_repset(arg_i, utils.ANY_TEXTURE)
+
         arg_source_repset = self.get_arg_tensor_source_repset(op_repsets.op_node, arg_i)
         op_repsets.try_constrain_with_arg_repset(arg_i, arg_source_repset)
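Note on the tag_memory_meta_pass.py hunk above: the force_fp16 branch is a soft preference, not a hard requirement. A minimal Python sketch of the assumed "narrow if compatible, otherwise keep" semantics behind try_constrain_with_arg_repset (the set representation and names are illustrative, not the actual OpRepSets API):

```python
def try_constrain(candidates: set, preferred: set) -> set:
    """Narrow the candidate repsets toward `preferred` if they intersect."""
    narrowed = candidates & preferred
    # If the preferred repset is incompatible, keep the original candidates so
    # that a hint (e.g. "prefer textures under force_fp16") can never make an
    # op unrepresentable.
    return narrowed if narrowed else candidates

reps = {"buffer", "texture3d"}
reps = try_constrain(reps, {"texture3d"})  # force_fp16 hint -> {"texture3d"}
reps = try_constrain(reps, {"buffer"})     # incompatible hint -> unchanged
assert reps == {"texture3d"}
```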
diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py
index e487491dfbb..ef41060272c 100644
--- a/backends/vulkan/op_registry.py
+++ b/backends/vulkan/op_registry.py
@@ -687,7 +687,7 @@ def register_sdpa_ops():
 
 @update_features(exir_ops.edge.et_vk.apply_rotary_emb.default)
 def register_rotary_emb_op():
     return OpFeatures(
-        inputs_storage=utils.WIDTH_PACKED_TEXTURE,
+        inputs_storage=utils.CONTIGUOUS_ANY,
         supports_resize=True,
     )
diff --git a/backends/vulkan/runtime/graph/ops/glsl/embedding_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/embedding_texture.glsl
index b064d8a3295..9a6295a8094 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/embedding_texture.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/embedding_texture.glsl
@@ -38,8 +38,8 @@ int load_embedding_idx(const TensorIndex4D out_tidx) {
   indices_tidx.data.xyz = out_tidx.data.yzw;
   indices_tidx.data.w = 0;
 
-  TextureElementIndex elem_pos = tensor_idx_to_texture_element_idx_simple(
-      indices_tidx, indices);
+  TextureElementIndex elem_pos = tensor4d_idx_to_texture_element_idx_simple(
+      indices, indices_tidx);
 
   const ivec4 in_texel = texelFetch(t_indices, elem_pos.pos, 0);
   return in_texel[elem_pos.comp];
@@ -61,7 +61,7 @@ void main() {
     return;
   }
 
-  TensorIndex4D out_tidx = texture_pos_to_tensor_idx_simple(out_pos, outp);
+  TensorIndex4D out_tidx = texture_pos_to_tensor4d_idx_simple(outp, out_pos);
 
   const int embedding_idx = load_embedding_idx(out_tidx);
   const VEC4_T weight_texel = load_weight_texel(embedding_idx, out_tidx.data.x);
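The call-site changes above follow the helper renames in indexing.glslh below, which also move the metadata parameter to the first position. For reference, a Python mirror of what tensor4d_idx_to_texture_element_idx_simple computes for a packed texture, ignoring axis mapping and batches as the GLSL comment notes:

```python
def texture_element_index(tidx4, packed_dim):
    """Split a 4D element index into a texel position and a component."""
    # pos starts as the xyz part of the 4D tensor index (the w/batch dim is
    # ignored by the "simple" helpers).
    pos = list(tidx4[:3])
    pos[packed_dim] = tidx4[packed_dim] // 4  # texel coordinate (div_4)
    comp = tidx4[packed_dim] % 4              # component within the texel
    return pos, comp

# Width-packed (packed_dim == 0): element x == 10 lands in texel x == 2,
# component 2, since four consecutive elements share one texel.
assert texture_element_index((10, 5, 3, 0), packed_dim=0) == ([2, 5, 3], 2)
```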
diff --git a/backends/vulkan/runtime/graph/ops/glsl/indexing.glslh b/backends/vulkan/runtime/graph/ops/glsl/indexing.glslh
index 0e30faa5d66..38016547d19 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/indexing.glslh
+++ b/backends/vulkan/runtime/graph/ops/glsl/indexing.glslh
@@ -147,6 +147,20 @@ struct TensorIndex4D {
   ivec4 data;
 };
 
+TensorIndex4D zero_tensor4d_idx() {
+  TensorIndex4D tidx;
+  tidx.data = ivec4(0);
+  return tidx;
+}
+
+bool out_of_bounds(const TensorIndex4D tidx, const BufferMetadata meta) {
+  return any(greaterThanEqual(tidx.data, meta.sizes[0]));
+}
+
+bool out_of_bounds(const TensorIndex4D tidx, const TextureMetadata meta) {
+  return any(greaterThanEqual(tidx.data, meta.sizes));
+}
+
 //
 // TextureElementIndex
 //
@@ -245,15 +259,9 @@ void clamp_tensor_idx(const BufferMetadata meta, inout TensorIndex tidx) {
   tidx.data[1] = min(tidx.data[1], meta.sizes[1] - 1);
 }
 
-TensorIndex4D zero_tensor4d_idx() {
-  TensorIndex4D tidx;
-  tidx.data = ivec4(0);
-  return tidx;
-}
-
 // Does not account for axis mapping or batches
-TensorIndex4D texture_pos_to_tensor_idx_simple(
-    const ivec3 pos, const TextureMetadata meta) {
+TensorIndex4D texture_pos_to_tensor4d_idx_simple(
+    const TextureMetadata meta, const ivec3 pos) {
   TensorIndex4D tidx;
   tidx.data.xyz = pos;
   tidx.data.w = 0;
@@ -262,8 +270,20 @@ TensorIndex4D texture_pos_to_tensor_idx_simple(
 }
 
 // Does not account for axis mapping or batches
-TextureElementIndex tensor_idx_to_texture_element_idx_simple(
-    const TensorIndex4D tidx, const TextureMetadata meta) {
+ivec3 tensor4d_idx_to_texel_pos_simple(
+    const TextureMetadata meta, const TensorIndex4D tidx) {
+  ivec3 texel_pos;
+
+  const int packed_dim_idx = tidx.data[meta.packed_dim];
+
+  texel_pos = tidx.data.xyz;
+  texel_pos[meta.packed_dim] = div_4(packed_dim_idx);
+  return texel_pos;
+}
+
+// Does not account for axis mapping or batches
+TextureElementIndex tensor4d_idx_to_texture_element_idx_simple(
+    const TextureMetadata meta, const TensorIndex4D tidx) {
   const int packed_dim_idx = tidx.data[meta.packed_dim];
   TextureElementIndex tex_idx;
   tex_idx.pos = tidx.data.xyz;
@@ -272,6 +292,16 @@ TextureElementIndex tensor_idx_to_texture_element_idx_simple(
   return tex_idx;
 }
 
+uint tensor4d_idx_to_linear_idx(
+    const BufferMetadata meta,
+    const TensorIndex4D tidx) {
+  uint lin_idx = 0;
+  for (int d = 0; d < 4; ++d) {
+    lin_idx += meta.strides[0][d] * tidx.data[d];
+  }
+  return lin_idx;
+}
+
 //
 // Debug utilities
 //
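The new tensor4d_idx_to_linear_idx helper is a strided dot product over the four dims; the buffer paths in rotary_embedding.glsl below then apply div_4 because the SSBOs are bound as arrays of 4-element texels (is_scalar_array=False). A small Python check of that arithmetic, with illustrative strides:

```python
def tensor4d_idx_to_linear_idx(tidx4, strides4):
    # Strided dot product, mirroring the GLSL loop over the four dims.
    return sum(i * s for i, s in zip(tidx4, strides4))

# Illustrative strides; the real values come from the BufferMetadata UBO.
strides = (1, 64, 512, 0)
elem_idx = tensor4d_idx_to_linear_idx((10, 5, 3, 0), strides)  # 10 + 320 + 1536
assert elem_idx == 1866

texel_idx = elem_idx // 4  # div_4: index into the vec4-typed buffer binding
assert texel_idx == 466
```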
diff --git a/backends/vulkan/runtime/graph/ops/glsl/rotary_embedding.glsl b/backends/vulkan/runtime/graph/ops/glsl/rotary_embedding.glsl
index 30375728921..155eda467c4 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/rotary_embedding.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/rotary_embedding.glsl
@@ -13,23 +13,29 @@
 #define VEC4_T ${texel_load_type(DTYPE, STORAGE)}
 
 ${define_required_extensions(DTYPE)}
+${define_active_storage_type(STORAGE)}
 
 layout(std430) buffer;
 
-${layout_declare_tensor(B, "w", "xqout", DTYPE, STORAGE)}
-${layout_declare_tensor(B, "w", "xkout", DTYPE, STORAGE)}
-${layout_declare_tensor(B, "r", "xq", DTYPE, STORAGE)}
-${layout_declare_tensor(B, "r", "xk", DTYPE, STORAGE)}
-${layout_declare_tensor(B, "r", "freqs_cos", DTYPE, STORAGE)}
-${layout_declare_tensor(B, "r", "freqs_sin", DTYPE, STORAGE)}
-${layout_declare_ubo(B, "ivec3", "xqout_limits")}
-${layout_declare_ubo(B, "ivec3", "xkout_limits")}
+#include "indexing.glslh"
 
-layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+${layout_declare_tensor(B, "w", "t_xqout", DTYPE, STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "w", "t_xkout", DTYPE, STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_xq", DTYPE, STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_xk", DTYPE, STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_freqs_cos", DTYPE, STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_freqs_sin", DTYPE, STORAGE, is_scalar_array=False)}
 
-layout(constant_id = 3) const int packed_dim = 0;
+$if STORAGE == "buffer":
+  ${layout_declare_ubo(B, "BufferMetadata", "xqout")}
+  ${layout_declare_ubo(B, "BufferMetadata", "xkout")}
+  ${layout_declare_ubo(B, "BufferMetadata", "freqs_cos")}
+$else:
+  ${layout_declare_ubo(B, "TextureMetadata", "xqout")}
+  ${layout_declare_ubo(B, "TextureMetadata", "xkout")}
+  ${layout_declare_ubo(B, "TextureMetadata", "freqs_cos")}
 
-#include "indexing_utils.h"
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 
 /*
  * This shader computes rotary positional embeddings which are used in the Llama
@@ -39,7 +45,7 @@ layout(constant_id = 3) const int packed_dim = 0;
  * 1. xq (batch_size, sequence_len, num_heads, head_dim)
  * 2. xk (batch_size, sequence_len, num_kv_heads, head_dim)
  * 3. freqs_cos (sequence_len, head_dim / 2)
- * 4. freqs_cos (sequence_len, head_dim / 2)
+ * 4. freqs_sin (sequence_len, head_dim / 2)
  *
  * Two output tensors are produced, with the same shapes as xq and xk
  * respectively.
@@ -66,23 +72,43 @@ void main() {
   // Each thread will write to two output locations to maximize data re-use.
   // One texel loaded from the freqs_cos/freqs_sin tensors can be used to
   // calculate two output texels.
-  const ivec3 x_pos_1 = ivec3(
-      gl_GlobalInvocationID.x * 2, gl_GlobalInvocationID.yz);
-  const ivec3 x_pos_2 = ivec3(x_pos_1.x + 1, x_pos_1.yz);
+  TensorIndex4D out_tidx_1 = zero_tensor4d_idx();
+  out_tidx_1.data.x = int(gl_GlobalInvocationID.x) * 8;
+  out_tidx_1.data.yz = ivec2(gl_GlobalInvocationID.yz);
+
+  TensorIndex4D out_tidx_2 = out_tidx_1;
+  out_tidx_2.data.x += 4;
 
-  if (any(greaterThanEqual(x_pos_2, xqout_limits))) {
+  if (out_of_bounds(out_tidx_2, xqout)) {
     return;
   }
 
-  const ivec3 freqs_pos = ivec3(gl_GlobalInvocationID.xz, 0);
+  TensorIndex4D freqs_tidx = zero_tensor4d_idx();
+  freqs_tidx.data.x = int(gl_GlobalInvocationID.x) * 4;
+  freqs_tidx.data.y = out_tidx_1.data.z;
 
-  VEC4_T cos_tex = load_texel(freqs_cos, freqs_pos);
-  VEC4_T sin_tex = load_texel(freqs_sin, freqs_pos);
+#ifdef USING_BUFFER
+  const uint freqs_texel_bufi = div_4(tensor4d_idx_to_linear_idx(freqs_cos, freqs_tidx));
+  VEC4_T cos_tex = t_freqs_cos[freqs_texel_bufi];
+  VEC4_T sin_tex = t_freqs_sin[freqs_texel_bufi];
 
-  // Compute xqout
+  uint x_texel_bufi_1 = div_4(tensor4d_idx_to_linear_idx(xqout, out_tidx_1));
+  uint x_texel_bufi_2 = div_4(tensor4d_idx_to_linear_idx(xqout, out_tidx_2));
+  VEC4_T x_tex_1 = t_xq[x_texel_bufi_1];
+  VEC4_T x_tex_2 = t_xq[x_texel_bufi_2];
+
+#else // USING_TEXTURE
+  const ivec3 freqs_pos = tensor4d_idx_to_texel_pos_simple(freqs_cos, freqs_tidx);
+  VEC4_T cos_tex = texelFetch(t_freqs_cos, freqs_pos, 0);
+  VEC4_T sin_tex = texelFetch(t_freqs_sin, freqs_pos, 0);
 
-  VEC4_T x_tex_1 = load_texel(xq, x_pos_1);
-  VEC4_T x_tex_2 = load_texel(xq, x_pos_2);
+  const ivec3 x_pos_1 = tensor4d_idx_to_texel_pos_simple(xqout, out_tidx_1);
+  const ivec3 x_pos_2 = tensor4d_idx_to_texel_pos_simple(xqout, out_tidx_2);
+  VEC4_T x_tex_1 = texelFetch(t_xq, x_pos_1, 0);
+  VEC4_T x_tex_2 = texelFetch(t_xq, x_pos_2, 0);
+#endif
+
+  // Compute xqout
 
   // Separate into even and odd elements
   VEC4_T x_r = VEC4_T(x_tex_1.xz, x_tex_2.xz);
@@ -94,20 +120,34 @@ void main() {
   VEC4_T xout_tex_1 = VEC4_T(xout_r.x, xout_i.x, xout_r.y, xout_i.y);
   VEC4_T xout_tex_2 = VEC4_T(xout_r.z, xout_i.z, xout_r.w, xout_i.w);
 
-  write_texel(xqout, x_pos_1, xout_tex_1);
-  write_texel(xqout, x_pos_2, xout_tex_2);
+#ifdef USING_BUFFER
+  t_xqout[x_texel_bufi_1] = xout_tex_1;
+  t_xqout[x_texel_bufi_2] = xout_tex_2;
+#else // USING_TEXTURE
+  imageStore(t_xqout, x_pos_1, xout_tex_1);
+  imageStore(t_xqout, x_pos_2, xout_tex_2);
+#endif
 
   // n_heads will be greater than or equal to n_kv_heads, therefore xq and xqout
   // may have a larger height dim than xk and xkout. Only compute xkout if this
   // invocation is still within bounds.
-  if (any(greaterThanEqual(x_pos_2, xkout_limits))) {
+  if (out_of_bounds(out_tidx_2, xkout)) {
     return;
   }
 
   // Compute xkout
 
-  x_tex_1 = load_texel(xk, x_pos_1);
-  x_tex_2 = load_texel(xk, x_pos_2);
+#ifdef USING_BUFFER
+  x_texel_bufi_1 = div_4(tensor4d_idx_to_linear_idx(xkout, out_tidx_1));
+  x_texel_bufi_2 = div_4(tensor4d_idx_to_linear_idx(xkout, out_tidx_2));
+
+  x_tex_1 = t_xk[x_texel_bufi_1];
+  x_tex_2 = t_xk[x_texel_bufi_2];
+
+#else // USING_TEXTURE
+  x_tex_1 = texelFetch(t_xk, x_pos_1, 0);
+  x_tex_2 = texelFetch(t_xk, x_pos_2, 0);
+#endif
 
   x_r = VEC4_T(x_tex_1.xz, x_tex_2.xz);
   x_i = VEC4_T(x_tex_1.yw, x_tex_2.yw);
@@ -118,6 +158,11 @@ void main() {
   xout_tex_1 = VEC4_T(xout_r.x, xout_i.x, xout_r.y, xout_i.y);
   xout_tex_2 = VEC4_T(xout_r.z, xout_i.z, xout_r.w, xout_i.w);
 
-  write_texel(xkout, x_pos_1, xout_tex_1);
-  write_texel(xkout, x_pos_2, xout_tex_2);
+#ifdef USING_BUFFER
+  t_xkout[x_texel_bufi_1] = xout_tex_1;
+  t_xkout[x_texel_bufi_2] = xout_tex_2;
+#else // USING_TEXTURE
+  imageStore(t_xkout, x_pos_1, xout_tex_1);
+  imageStore(t_xkout, x_pos_2, xout_tex_2);
+#endif
 }
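In the reworked shader above, indices are expressed in elements rather than texels: out_tidx_1.data.x = gl_GlobalInvocationID.x * 8 covers two 4-wide texels along head_dim, and the freqs index advances by 4 because one freqs texel (four cos/sin values) feeds eight output elements. The lines computing xout_r/xout_i are unchanged context outside this diff, so the sketch below uses the standard Llama RoPE rotation per (even, odd) pair rather than quoting the shader:

```python
import math

def rope_pair(x_even, x_odd, cos_t, sin_t):
    # Treat (even, odd) as the real/imaginary parts of a complex number and
    # rotate it by the position-dependent angle; this is the standard RoPE
    # formula the shader vectorizes across x_r/x_i.
    return (x_even * cos_t - x_odd * sin_t,
            x_even * sin_t + x_odd * cos_t)

# Rotating (1, 0) by 90 degrees yields (0, 1).
out = rope_pair(1.0, 0.0, math.cos(math.pi / 2), math.sin(math.pi / 2))
assert abs(out[0]) < 1e-9 and abs(out[1] - 1.0) < 1e-9
```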
diff --git a/backends/vulkan/runtime/graph/ops/glsl/rotary_embedding.yaml b/backends/vulkan/runtime/graph/ops/glsl/rotary_embedding.yaml
index a81fd564d10..ba8aa400958 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/rotary_embedding.yaml
+++ b/backends/vulkan/runtime/graph/ops/glsl/rotary_embedding.yaml
@@ -3,6 +3,9 @@ rotary_embedding:
     DTYPE: float
     STORAGE: texture3d
   generate_variant_forall:
+    STORAGE:
+      - VALUE: texture3d
+      - VALUE: buffer
     DTYPE:
       - VALUE: half
       - VALUE: float
diff --git a/backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp b/backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp
index fcc8fe4b265..e1914f350b7 100644
--- a/backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp
@@ -43,10 +43,17 @@ utils::uvec3 rotary_embedding_global_wg_size(
   const ValueRef xq_out = args.at(0).refs.at(0);
 
-  utils::uvec3 global_wg_size = graph->logical_limits_of(xq_out);
-  global_wg_size[0] /= 2;
+  // Number of texels along the head dim
+  const uint32_t D4 = utils::div_up_4(graph->size_at(-1, xq_out));
+  // Divide by 2 since each invocation computes 2 output texels
+  const uint32_t D8 = utils::div_up(D4, uint32_t(2));
 
-  return global_wg_size;
+  // Number of query heads
+  const uint32_t QH = graph->size_at(-2, xq_out);
+  // Input token sequence length
+  const uint32_t S = graph->size_at(-3, xq_out);
+
+  return {D8, QH, S};
 }
 
 void add_rotary_embedding_node(
@@ -73,8 +80,14 @@ void add_rotary_embedding_node(
   VK_CHECK_COND(graph.has_standard_axis_map(freqs_sin));
 
   std::string kernel_name = "rotary_embedding";
+  add_storage_type_suffix(kernel_name, graph.storage_type_of(xq_out));
   add_dtype_suffix(kernel_name, graph.dtype_of(xq_out));
 
+  vkapi::ParamsBindList param_ubos = {
+      graph.meta_ubo(xq_out),
+      graph.meta_ubo(xk_out),
+      graph.meta_ubo(freqs_cos)};
+
   graph.execute_nodes().emplace_back(new DynamicDispatchNode(
       graph,
       VK_KERNEL_FROM_STR(kernel_name),
@@ -84,7 +97,7 @@ void add_rotary_embedding_node(
       {{{xq_out, xk_out}, vkapi::kWrite},
        {{xq, xk, freqs_cos, freqs_sin}, vkapi::kRead}},
       // Parameter buffers
-      {graph.logical_limits_ubo(xq_out), graph.logical_limits_ubo(xk_out)},
+      param_ubos,
       // Push Constants
       {},
       // Specialization Constants
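The new rotary_embedding_global_wg_size above matches the shader's addressing: x counts pairs of head-dim texels, y the query heads, z the sequence length. A quick Python check of the arithmetic, with div_up/div_up_4 behavior assumed from their names:

```python
def div_up(a, b):
    # Ceiling division, matching the assumed semantics of utils::div_up.
    return (a + b - 1) // b

def rotary_wg_size(head_dim, num_heads, seq_len):
    d4 = div_up(head_dim, 4)  # texels along head_dim
    d8 = div_up(d4, 2)        # each invocation computes 2 output texels
    return (d8, num_heads, seq_len)

# e.g. head_dim=128 -> 32 texels -> 16 invocations along x per (head, token)
assert rotary_wg_size(head_dim=128, num_heads=32, seq_len=7) == (16, 32, 7)
```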
diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py
index 57863703498..81ee67a596c 100644
--- a/backends/vulkan/vulkan_preprocess.py
+++ b/backends/vulkan/vulkan_preprocess.py
@@ -7,11 +7,9 @@
 # pyre-strict
 
 from functools import partial
-
 from typing import Any, Callable, Dict, final, List
 
 import executorch.backends.vulkan.utils as utils
-
 from executorch.backends.transforms.addmm_mm_to_linear import AddmmToLinearTransform
 from executorch.backends.transforms.fuse_conv_with_clamp import FuseClampPass
 from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform
@@ -29,7 +27,6 @@
 )
 from executorch.backends.vulkan._passes.fuse_patterns import FusePatternsPass
 from executorch.backends.vulkan._passes.remove_asserts import RemoveAssertsTransform
-
 from executorch.backends.vulkan.serialization.vulkan_graph_builder import VkGraphBuilder
 from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
     VkMemoryLayout,
@@ -39,7 +36,6 @@
     serialize_vulkan_graph,
 )
 from executorch.backends.xnnpack._passes import FuseBatchNormPass
-
 from executorch.exir.backend.backend_details import (
     BackendDetails,
     CompileSpec,
@@ -47,18 +43,12 @@
     PreprocessResult,
 )
 from executorch.exir.backend.utils import DelegateMappingBuilder
-
 from executorch.exir.memory_planning import greedy, MemoryPlanningAlgorithmSuite
 from executorch.exir.pass_base import ExportPass, PassBase
-
 from executorch.exir.passes import MemoryPlanningPass, SpecPropPass
-
 from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
-
 from executorch.exir.program._program import _transform
-
 from torch._export.verifier import Verifier
-
 from torch.export._remove_auto_functionalized_pass import (
     unsafe_remove_auto_functionalized_pass,
 )
@@ -209,6 +199,7 @@ def preprocess(  # noqa: C901
                 texture_limits,
                 default_storage_type=default_storage_type,
                 default_memory_layout=default_memory_layout,
+                force_fp16=force_fp16,
             ),
         ],
     )
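The preprocess() hunk above assumes a force_fp16 flag already parsed from the compile specs in unchanged code. A hypothetical sketch of that plumbing, with a stand-in CompileSpec and an illustrative key/value encoding (the real convention lives outside this diff):

```python
from typing import List, NamedTuple

class CompileSpec(NamedTuple):
    """Stand-in for exir's CompileSpec(key: str, value: bytes)."""
    key: str
    value: bytes

def parse_force_fp16(specs: List[CompileSpec]) -> bool:
    # Illustrative encoding; the actual key name and value format used by
    # vulkan_preprocess.py are not shown in this diff.
    return any(s.key == "force_fp16" and s.value == b"1" for s in specs)

assert parse_force_fp16([CompileSpec("force_fp16", b"1")])
assert not parse_force_fp16([CompileSpec("texture_limits", b"...")])
```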