Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions backends/vulkan/_passes/tag_memory_meta_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,16 @@

import logging
import operator

from typing import Any

import executorch.backends.vulkan.utils as utils

import torch

from executorch.backends.vulkan.op_registry import get_op_features, has_impl, OpFeatures

from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
VkMemoryLayout,
VkStorageType,
)

from executorch.exir.dialects._ops import ops as exir_ops

from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.tensor import TensorSpec

Expand Down Expand Up @@ -130,15 +124,17 @@ def __init__(
texture_limits: utils.ImageExtents,
default_storage_type: VkStorageType = VkStorageType.TEXTURE_3D,
default_memory_layout: VkMemoryLayout = VkMemoryLayout.TENSOR_WIDTH_PACKED,
force_fp16: bool = False,
):
super().__init__()
self.default_storage: VkStorageType = default_storage_type
self.default_layout: VkMemoryLayout = default_memory_layout
self.texture_limits = texture_limits
self.force_fp16 = force_fp16

# Magic number to limit "lookahead" when tracing through users of an operator
# to constrain the representation of its arguments/outputs.
self.max_trace_search_depth = 20
self.max_trace_search_depth = None

def is_valid_op_node(self, node: Any) -> bool:
"""
Expand Down Expand Up @@ -361,6 +357,12 @@ def constrain_op_arg_repset(self, arg_i: int, op_repsets: utils.OpRepSets) -> No
2. Then, try to trace through the users of the argument to find a representation
that can be used for as long as possible without needing a transition.
"""
# If forcing fp16, then try to use texture storage whenever possible. This is
# a temporary stopgap measure until all buffer implementations properly account
# for potential overflow of fp16 representation range when doing math in fp16.
if self.force_fp16:
op_repsets.try_constrain_with_arg_repset(arg_i, utils.ANY_TEXTURE)

arg_source_repset = self.get_arg_tensor_source_repset(op_repsets.op_node, arg_i)
op_repsets.try_constrain_with_arg_repset(arg_i, arg_source_repset)

Expand Down
2 changes: 1 addition & 1 deletion backends/vulkan/op_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,7 +687,7 @@ def register_sdpa_ops():
@update_features(exir_ops.edge.et_vk.apply_rotary_emb.default)
def register_rotary_emb_op():
return OpFeatures(
inputs_storage=utils.WIDTH_PACKED_TEXTURE,
inputs_storage=utils.CONTIGUOUS_ANY,
supports_resize=True,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ int load_embedding_idx(const TensorIndex4D out_tidx) {
indices_tidx.data.xyz = out_tidx.data.yzw;
indices_tidx.data.w = 0;

TextureElementIndex elem_pos = tensor_idx_to_texture_element_idx_simple(
indices_tidx, indices);
TextureElementIndex elem_pos = tensor4d_idx_to_texture_element_idx_simple(
indices, indices_tidx);

const ivec4 in_texel = texelFetch(t_indices, elem_pos.pos, 0);
return in_texel[elem_pos.comp];
Expand All @@ -61,7 +61,7 @@ void main() {
return;
}

TensorIndex4D out_tidx = texture_pos_to_tensor_idx_simple(out_pos, outp);
TensorIndex4D out_tidx = texture_pos_to_tensor4d_idx_simple(outp, out_pos);
const int embedding_idx = load_embedding_idx(out_tidx);

const VEC4_T weight_texel = load_weight_texel(embedding_idx, out_tidx.data.x);
Expand Down
50 changes: 40 additions & 10 deletions backends/vulkan/runtime/graph/ops/glsl/indexing.glslh
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,20 @@ struct TensorIndex4D {
ivec4 data;
};

TensorIndex4D zero_tensor4d_idx() {
TensorIndex4D tidx;
tidx.data = ivec4(0);
return tidx;
}

bool out_of_bounds(const TensorIndex4D tidx, const BufferMetadata meta) {
return any(greaterThanEqual(tidx.data, meta.sizes[0]));
}

bool out_of_bounds(const TensorIndex4D tidx, const TextureMetadata meta) {
return any(greaterThanEqual(tidx.data, meta.sizes));
}

//
// TextureElementIndex
//
Expand Down Expand Up @@ -245,15 +259,9 @@ void clamp_tensor_idx(const BufferMetadata meta, inout TensorIndex tidx) {
tidx.data[1] = min(tidx.data[1], meta.sizes[1] - 1);
}

TensorIndex4D zero_tensor4d_idx() {
TensorIndex4D tidx;
tidx.data = ivec4(0);
return tidx;
}

// Does not account for axis mapping or batches
TensorIndex4D texture_pos_to_tensor_idx_simple(
const ivec3 pos, const TextureMetadata meta) {
TensorIndex4D texture_pos_to_tensor4d_idx_simple(
const TextureMetadata meta, const ivec3 pos) {
TensorIndex4D tidx;
tidx.data.xyz = pos;
tidx.data.w = 0;
Expand All @@ -262,8 +270,20 @@ TensorIndex4D texture_pos_to_tensor_idx_simple(
}

// Does not account for axis mapping or batches
TextureElementIndex tensor_idx_to_texture_element_idx_simple(
const TensorIndex4D tidx, const TextureMetadata meta) {
ivec3 tensor4d_idx_to_texel_pos_simple(
const TextureMetadata meta, const TensorIndex4D tidx) {
ivec3 texel_pos;

const int packed_dim_idx = tidx.data[meta.packed_dim];

texel_pos = tidx.data.xyz;
texel_pos[meta.packed_dim] = div_4(packed_dim_idx);
return texel_pos;
}

// Does not account for axis mapping or batches
TextureElementIndex tensor4d_idx_to_texture_element_idx_simple(
const TextureMetadata meta, const TensorIndex4D tidx) {
const int packed_dim_idx = tidx.data[meta.packed_dim];
TextureElementIndex tex_idx;
tex_idx.pos = tidx.data.xyz;
Expand All @@ -272,6 +292,16 @@ TextureElementIndex tensor_idx_to_texture_element_idx_simple(
return tex_idx;
}

uint tensor4d_idx_to_linear_idx(
const BufferMetadata meta,
const TensorIndex4D tidx) {
uint lin_idx = 0;
for (int d = 0; d < 4; ++d) {
lin_idx += meta.strides[0][d] * tidx.data[d];
}
return lin_idx;
}

//
// Debug utilities
//
Expand Down
103 changes: 74 additions & 29 deletions backends/vulkan/runtime/graph/ops/glsl/rotary_embedding.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,29 @@
#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}

${define_required_extensions(DTYPE)}
${define_active_storage_type(STORAGE)}

layout(std430) buffer;

${layout_declare_tensor(B, "w", "xqout", DTYPE, STORAGE)}
${layout_declare_tensor(B, "w", "xkout", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "xq", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "xk", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "freqs_cos", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "freqs_sin", DTYPE, STORAGE)}
${layout_declare_ubo(B, "ivec3", "xqout_limits")}
${layout_declare_ubo(B, "ivec3", "xkout_limits")}
#include "indexing.glslh"

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
${layout_declare_tensor(B, "w", "t_xqout", DTYPE, STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "w", "t_xkout", DTYPE, STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_xq", DTYPE, STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_xk", DTYPE, STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_freqs_cos", DTYPE, STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_freqs_sin", DTYPE, STORAGE, is_scalar_array=False)}

layout(constant_id = 3) const int packed_dim = 0;
$if STORAGE == "buffer":
${layout_declare_ubo(B, "BufferMetadata", "xqout")}
${layout_declare_ubo(B, "BufferMetadata", "xkout")}
${layout_declare_ubo(B, "BufferMetadata", "freqs_cos")}
$else:
${layout_declare_ubo(B, "TextureMetadata", "xqout")}
${layout_declare_ubo(B, "TextureMetadata", "xkout")}
${layout_declare_ubo(B, "TextureMetadata", "freqs_cos")}

#include "indexing_utils.h"
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

/*
* This shader computes rotary positional embeddings which are used in the Llama
Expand All @@ -39,7 +45,7 @@ layout(constant_id = 3) const int packed_dim = 0;
* 1. xq (batch_size, sequence_len, num_heads, head_dim)
* 2. xk (batch_size, sequence_len, num_kv_heads, head_dim)
* 3. freqs_cos (sequence_len, head_dim / 2)
* 4. freqs_cos (sequence_len, head_dim / 2)
* 4. freqs_sin (sequence_len, head_dim / 2)
*
* Two output tensors are produced, with the same shapes as xq and xk
* respectively.
Expand All @@ -66,23 +72,43 @@ void main() {
// Each thread will write to two output locations to maximize data re-use.
// One texel loaded from the freqs_cos/freqs_sin tensors can be used to
// calculate two output texels.
const ivec3 x_pos_1 = ivec3(
gl_GlobalInvocationID.x * 2, gl_GlobalInvocationID.yz);
const ivec3 x_pos_2 = ivec3(x_pos_1.x + 1, x_pos_1.yz);
TensorIndex4D out_tidx_1 = zero_tensor4d_idx();
out_tidx_1.data.x = int(gl_GlobalInvocationID.x) * 8;
out_tidx_1.data.yz = ivec2(gl_GlobalInvocationID.yz);

TensorIndex4D out_tidx_2 = out_tidx_1;
out_tidx_2.data.x += 4;

if (any(greaterThanEqual(x_pos_2, xqout_limits))) {
if (out_of_bounds(out_tidx_2, xqout)) {
return;
}

const ivec3 freqs_pos = ivec3(gl_GlobalInvocationID.xz, 0);
TensorIndex4D freqs_tidx = zero_tensor4d_idx();
freqs_tidx.data.x = int(gl_GlobalInvocationID.x) * 4;
freqs_tidx.data.y = out_tidx_1.data.z;

VEC4_T cos_tex = load_texel(freqs_cos, freqs_pos);
VEC4_T sin_tex = load_texel(freqs_sin, freqs_pos);
#ifdef USING_BUFFER
const uint freqs_texel_bufi = div_4(tensor4d_idx_to_linear_idx(freqs_cos, freqs_tidx));
VEC4_T cos_tex = t_freqs_cos[freqs_texel_bufi];
VEC4_T sin_tex = t_freqs_sin[freqs_texel_bufi];

// Compute xqout
uint x_texel_bufi_1 = div_4(tensor4d_idx_to_linear_idx(xqout, out_tidx_1));
uint x_texel_bufi_2 = div_4(tensor4d_idx_to_linear_idx(xqout, out_tidx_2));
VEC4_T x_tex_1 = t_xq[x_texel_bufi_1];
VEC4_T x_tex_2 = t_xq[x_texel_bufi_2];

#else // USING_TEXTURE
const ivec3 freqs_pos = tensor4d_idx_to_texel_pos_simple(freqs_cos, freqs_tidx);
VEC4_T cos_tex = texelFetch(t_freqs_cos, freqs_pos, 0);
VEC4_T sin_tex = texelFetch(t_freqs_sin, freqs_pos, 0);

VEC4_T x_tex_1 = load_texel(xq, x_pos_1);
VEC4_T x_tex_2 = load_texel(xq, x_pos_2);
const ivec3 x_pos_1 = tensor4d_idx_to_texel_pos_simple(xqout, out_tidx_1);
const ivec3 x_pos_2 = tensor4d_idx_to_texel_pos_simple(xqout, out_tidx_2);
VEC4_T x_tex_1 = texelFetch(t_xq, x_pos_1, 0);
VEC4_T x_tex_2 = texelFetch(t_xq, x_pos_2, 0);
#endif

// Compute xqout

// Separate into even and odd elements
VEC4_T x_r = VEC4_T(x_tex_1.xz, x_tex_2.xz);
Expand All @@ -94,20 +120,34 @@ void main() {
VEC4_T xout_tex_1 = VEC4_T(xout_r.x, xout_i.x, xout_r.y, xout_i.y);
VEC4_T xout_tex_2 = VEC4_T(xout_r.z, xout_i.z, xout_r.w, xout_i.w);

write_texel(xqout, x_pos_1, xout_tex_1);
write_texel(xqout, x_pos_2, xout_tex_2);
#ifdef USING_BUFFER
t_xqout[x_texel_bufi_1] = xout_tex_1;
t_xqout[x_texel_bufi_2] = xout_tex_2;
#else // USING_TEXTURE
imageStore(t_xqout, x_pos_1, xout_tex_1);
imageStore(t_xqout, x_pos_2, xout_tex_2);
#endif

// n_heads will be greater than or equal to n_kv_heads, therefore xq and xqout
// may have a larger height dim than xk and xkout. Only compute xkout if this
// invocation is still within bounds.
if (any(greaterThanEqual(x_pos_2, xkout_limits))) {
if (out_of_bounds(out_tidx_2, xkout)) {
return;
}

// Compute xkout

x_tex_1 = load_texel(xk, x_pos_1);
x_tex_2 = load_texel(xk, x_pos_2);
#ifdef USING_BUFFER
x_texel_bufi_1 = div_4(tensor4d_idx_to_linear_idx(xkout, out_tidx_1));
x_texel_bufi_2 = div_4(tensor4d_idx_to_linear_idx(xkout, out_tidx_2));

x_tex_1 = t_xk[x_texel_bufi_1];
x_tex_2 = t_xk[x_texel_bufi_2];

#else // USING_TEXTURE
x_tex_1 = texelFetch(t_xk, x_pos_1, 0);
x_tex_2 = texelFetch(t_xk, x_pos_2, 0);
#endif

x_r = VEC4_T(x_tex_1.xz, x_tex_2.xz);
x_i = VEC4_T(x_tex_1.yw, x_tex_2.yw);
Expand All @@ -118,6 +158,11 @@ void main() {
xout_tex_1 = VEC4_T(xout_r.x, xout_i.x, xout_r.y, xout_i.y);
xout_tex_2 = VEC4_T(xout_r.z, xout_i.z, xout_r.w, xout_i.w);

write_texel(xkout, x_pos_1, xout_tex_1);
write_texel(xkout, x_pos_2, xout_tex_2);
#ifdef USING_BUFFER
t_xkout[x_texel_bufi_1] = xout_tex_1;
t_xkout[x_texel_bufi_2] = xout_tex_2;
#else // USING_TEXTURE
imageStore(t_xkout, x_pos_1, xout_tex_1);
imageStore(t_xkout, x_pos_2, xout_tex_2);
#endif
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ rotary_embedding:
DTYPE: float
STORAGE: texture3d
generate_variant_forall:
STORAGE:
- VALUE: texture3d
- VALUE: buffer
DTYPE:
- VALUE: half
- VALUE: float
Expand Down
21 changes: 17 additions & 4 deletions backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,17 @@ utils::uvec3 rotary_embedding_global_wg_size(

const ValueRef xq_out = args.at(0).refs.at(0);

utils::uvec3 global_wg_size = graph->logical_limits_of(xq_out);
global_wg_size[0] /= 2;
// Head dim texel size
const uint32_t D4 = utils::div_up_4(graph->size_at<uint32_t>(-1, xq_out));
// Divide by 2 since each invocation computes 2 output locations
const uint32_t D8 = utils::div_up(D4, uint32_t(2));

return global_wg_size;
// Number of query heads
const uint32_t QH = graph->size_at<uint32_t>(-2, xq_out);
// Input tokens sequence length
const uint32_t S = graph->size_at<uint32_t>(-3, xq_out);

return {D8, QH, S};
}

void add_rotary_embedding_node(
Expand All @@ -73,8 +80,14 @@ void add_rotary_embedding_node(
VK_CHECK_COND(graph.has_standard_axis_map(freqs_sin));

std::string kernel_name = "rotary_embedding";
add_storage_type_suffix(kernel_name, graph.storage_type_of(xq_out));
add_dtype_suffix(kernel_name, graph.dtype_of(xq_out));

vkapi::ParamsBindList param_ubos = {
graph.meta_ubo(xq_out),
graph.meta_ubo(xk_out),
graph.meta_ubo(freqs_cos)};

graph.execute_nodes().emplace_back(new DynamicDispatchNode(
graph,
VK_KERNEL_FROM_STR(kernel_name),
Expand All @@ -84,7 +97,7 @@ void add_rotary_embedding_node(
{{{xq_out, xk_out}, vkapi::kWrite},
{{xq, xk, freqs_cos, freqs_sin}, vkapi::kRead}},
// Parameter buffers
{graph.logical_limits_ubo(xq_out), graph.logical_limits_ubo(xk_out)},
param_ubos,
// Push Constants
{},
// Specialization Constants
Expand Down
Loading
Loading