Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backends/vulkan/_passes/tag_memory_meta_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def __init__(

# Magic number to limit "lookahead" when tracing through users of an operator
# to constrain the representation of its arguments/outputs.
self.max_trace_search_depth = 20
self.max_trace_search_depth = None

def is_valid_op_node(self, node: Any) -> bool:
"""
Expand Down
2 changes: 1 addition & 1 deletion backends/vulkan/op_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,7 +692,7 @@ def register_sdpa_ops():
@update_features(exir_ops.edge.et_vk.apply_rotary_emb.default)
def register_rotary_emb_op():
return OpFeatures(
inputs_storage=utils.WIDTH_PACKED_TEXTURE,
inputs_storage=utils.CONTIGUOUS_ANY,
supports_resize=True,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ int load_embedding_idx(const TensorIndex4D out_tidx) {
indices_tidx.data.xyz = out_tidx.data.yzw;
indices_tidx.data.w = 0;

TextureElementIndex elem_pos = tensor_idx_to_texture_element_idx_simple(
indices_tidx, indices);
TextureElementIndex elem_pos = tensor4d_idx_to_texture_element_idx_simple(
indices, indices_tidx);

const ivec4 in_texel = texelFetch(t_indices, elem_pos.pos, 0);
return in_texel[elem_pos.comp];
Expand All @@ -61,7 +61,7 @@ void main() {
return;
}

TensorIndex4D out_tidx = texture_pos_to_tensor_idx_simple(out_pos, outp);
TensorIndex4D out_tidx = texture_pos_to_tensor4d_idx_simple(outp, out_pos);
const int embedding_idx = load_embedding_idx(out_tidx);

const VEC4_T weight_texel = load_weight_texel(embedding_idx, out_tidx.data.x);
Expand Down
50 changes: 40 additions & 10 deletions backends/vulkan/runtime/graph/ops/glsl/indexing.glslh
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,20 @@ struct TensorIndex4D {
ivec4 data;
};

// Builds a TensorIndex4D whose four index components are all zero.
TensorIndex4D zero_tensor4d_idx() {
  // Struct constructor: the single `data` member is set to ivec4(0).
  return TensorIndex4D(ivec4(0));
}

// Returns true when at least one component of `tidx.data` is greater than or
// equal to the matching component of `meta.sizes[0]`; returns false otherwise.
// Only the first entry of `meta.sizes` is consulted.
bool out_of_bounds(const TensorIndex4D tidx, const BufferMetadata meta) {
  if (any(greaterThanEqual(tidx.data, meta.sizes[0]))) {
    return true;
  }
  return false;
}

// Returns true when at least one component of `tidx.data` is greater than or
// equal to the matching component of `meta.sizes`; returns false otherwise.
bool out_of_bounds(const TensorIndex4D tidx, const TextureMetadata meta) {
  if (any(greaterThanEqual(tidx.data, meta.sizes))) {
    return true;
  }
  return false;
}

//
// TextureElementIndex
//
Expand Down Expand Up @@ -245,15 +259,9 @@ void clamp_tensor_idx(const BufferMetadata meta, inout TensorIndex tidx) {
tidx.data[1] = min(tidx.data[1], meta.sizes[1] - 1);
}

// Builds a TensorIndex4D whose four index components are all zero.
TensorIndex4D zero_tensor4d_idx() {
  // Struct constructor: the single `data` member is set to ivec4(0).
  return TensorIndex4D(ivec4(0));
}

// Does not account for axis mapping or batches
TensorIndex4D texture_pos_to_tensor_idx_simple(
const ivec3 pos, const TextureMetadata meta) {
TensorIndex4D texture_pos_to_tensor4d_idx_simple(
const TextureMetadata meta, const ivec3 pos) {
TensorIndex4D tidx;
tidx.data.xyz = pos;
tidx.data.w = 0;
Expand All @@ -262,8 +270,20 @@ TensorIndex4D texture_pos_to_tensor_idx_simple(
}

// Does not account for axis mapping or batches
TextureElementIndex tensor_idx_to_texture_element_idx_simple(
const TensorIndex4D tidx, const TextureMetadata meta) {
// Converts a 4D tensor index into a 3D texel position.
//
// The x/y/z components of `tidx` are copied as-is (the w component is not
// used), then the component selected by `meta.packed_dim` is replaced with
// `div_4` of the tensor index along that dim.
// NOTE(review): `div_4` is defined elsewhere; presumably it maps an element
// index to the index of the 4-wide texel containing it - confirm.
ivec3 tensor4d_idx_to_texel_pos_simple(
    const TextureMetadata meta, const TensorIndex4D tidx) {
  ivec3 pos = tidx.data.xyz;
  pos[meta.packed_dim] = div_4(tidx.data[meta.packed_dim]);
  return pos;
}

// Does not account for axis mapping or batches
TextureElementIndex tensor4d_idx_to_texture_element_idx_simple(
const TextureMetadata meta, const TensorIndex4D tidx) {
const int packed_dim_idx = tidx.data[meta.packed_dim];
TextureElementIndex tex_idx;
tex_idx.pos = tidx.data.xyz;
Expand All @@ -272,6 +292,16 @@ TextureElementIndex tensor_idx_to_texture_element_idx_simple(
return tex_idx;
}

// Computes a linear index for `tidx` by summing, over the four index
// components, `meta.strides[0][d] * tidx.data[d]`. Only the first entry of
// `meta.strides` is consulted.
uint tensor4d_idx_to_linear_idx(
    const BufferMetadata meta,
    const TensorIndex4D tidx) {
  uint offset = 0;
  // Unrolled accumulation over the four index components.
  offset += meta.strides[0][0] * tidx.data[0];
  offset += meta.strides[0][1] * tidx.data[1];
  offset += meta.strides[0][2] * tidx.data[2];
  offset += meta.strides[0][3] * tidx.data[3];
  return offset;
}

//
// Debug utilities
//
Expand Down
103 changes: 74 additions & 29 deletions backends/vulkan/runtime/graph/ops/glsl/rotary_embedding.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,29 @@
#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}

${define_required_extensions(DTYPE)}
${define_active_storage_type(STORAGE)}

layout(std430) buffer;

${layout_declare_tensor(B, "w", "xqout", DTYPE, STORAGE)}
${layout_declare_tensor(B, "w", "xkout", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "xq", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "xk", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "freqs_cos", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "freqs_sin", DTYPE, STORAGE)}
${layout_declare_ubo(B, "ivec3", "xqout_limits")}
${layout_declare_ubo(B, "ivec3", "xkout_limits")}
#include "indexing.glslh"

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
${layout_declare_tensor(B, "w", "t_xqout", DTYPE, STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "w", "t_xkout", DTYPE, STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_xq", DTYPE, STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_xk", DTYPE, STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_freqs_cos", DTYPE, STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_freqs_sin", DTYPE, STORAGE, is_scalar_array=False)}

layout(constant_id = 3) const int packed_dim = 0;
$if STORAGE == "buffer":
${layout_declare_ubo(B, "BufferMetadata", "xqout")}
${layout_declare_ubo(B, "BufferMetadata", "xkout")}
${layout_declare_ubo(B, "BufferMetadata", "freqs_cos")}
$else:
${layout_declare_ubo(B, "TextureMetadata", "xqout")}
${layout_declare_ubo(B, "TextureMetadata", "xkout")}
${layout_declare_ubo(B, "TextureMetadata", "freqs_cos")}

#include "indexing_utils.h"
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

/*
* This shader computes rotary positional embeddings which are used in the Llama
Expand All @@ -39,7 +45,7 @@ layout(constant_id = 3) const int packed_dim = 0;
* 1. xq (batch_size, sequence_len, num_heads, head_dim)
* 2. xk (batch_size, sequence_len, num_kv_heads, head_dim)
* 3. freqs_cos (sequence_len, head_dim / 2)
* 4. freqs_cos (sequence_len, head_dim / 2)
* 4. freqs_sin (sequence_len, head_dim / 2)
*
* Two output tensors are produced, with the same shapes as xq and xk
* respectively.
Expand All @@ -66,23 +72,43 @@ void main() {
// Each thread will write to two output locations to maximize data re-use.
// One texel loaded from the freqs_cos/freqs_sin tensors can be used to
// calculate two output texels.
const ivec3 x_pos_1 = ivec3(
gl_GlobalInvocationID.x * 2, gl_GlobalInvocationID.yz);
const ivec3 x_pos_2 = ivec3(x_pos_1.x + 1, x_pos_1.yz);
TensorIndex4D out_tidx_1 = zero_tensor4d_idx();
out_tidx_1.data.x = int(gl_GlobalInvocationID.x) * 8;
out_tidx_1.data.yz = ivec2(gl_GlobalInvocationID.yz);

TensorIndex4D out_tidx_2 = out_tidx_1;
out_tidx_2.data.x += 4;

if (any(greaterThanEqual(x_pos_2, xqout_limits))) {
if (out_of_bounds(out_tidx_2, xqout)) {
return;
}

const ivec3 freqs_pos = ivec3(gl_GlobalInvocationID.xz, 0);
TensorIndex4D freqs_tidx = zero_tensor4d_idx();
freqs_tidx.data.x = int(gl_GlobalInvocationID.x) * 4;
freqs_tidx.data.y = out_tidx_1.data.z;

VEC4_T cos_tex = load_texel(freqs_cos, freqs_pos);
VEC4_T sin_tex = load_texel(freqs_sin, freqs_pos);
#ifdef USING_BUFFER
const uint freqs_texel_bufi = div_4(tensor4d_idx_to_linear_idx(freqs_cos, freqs_tidx));
VEC4_T cos_tex = t_freqs_cos[freqs_texel_bufi];
VEC4_T sin_tex = t_freqs_sin[freqs_texel_bufi];

// Compute xqout
uint x_texel_bufi_1 = div_4(tensor4d_idx_to_linear_idx(xqout, out_tidx_1));
uint x_texel_bufi_2 = div_4(tensor4d_idx_to_linear_idx(xqout, out_tidx_2));
VEC4_T x_tex_1 = t_xq[x_texel_bufi_1];
VEC4_T x_tex_2 = t_xq[x_texel_bufi_2];

#else // USING_TEXTURE
const ivec3 freqs_pos = tensor4d_idx_to_texel_pos_simple(freqs_cos, freqs_tidx);
VEC4_T cos_tex = texelFetch(t_freqs_cos, freqs_pos, 0);
VEC4_T sin_tex = texelFetch(t_freqs_sin, freqs_pos, 0);

VEC4_T x_tex_1 = load_texel(xq, x_pos_1);
VEC4_T x_tex_2 = load_texel(xq, x_pos_2);
const ivec3 x_pos_1 = tensor4d_idx_to_texel_pos_simple(xqout, out_tidx_1);
const ivec3 x_pos_2 = tensor4d_idx_to_texel_pos_simple(xqout, out_tidx_2);
VEC4_T x_tex_1 = texelFetch(t_xq, x_pos_1, 0);
VEC4_T x_tex_2 = texelFetch(t_xq, x_pos_2, 0);
#endif

// Compute xqout

// Separate into even and odd elements
VEC4_T x_r = VEC4_T(x_tex_1.xz, x_tex_2.xz);
Expand All @@ -94,20 +120,34 @@ void main() {
VEC4_T xout_tex_1 = VEC4_T(xout_r.x, xout_i.x, xout_r.y, xout_i.y);
VEC4_T xout_tex_2 = VEC4_T(xout_r.z, xout_i.z, xout_r.w, xout_i.w);

write_texel(xqout, x_pos_1, xout_tex_1);
write_texel(xqout, x_pos_2, xout_tex_2);
#ifdef USING_BUFFER
t_xqout[x_texel_bufi_1] = xout_tex_1;
t_xqout[x_texel_bufi_2] = xout_tex_2;
#else // USING_TEXTURE
imageStore(t_xqout, x_pos_1, xout_tex_1);
imageStore(t_xqout, x_pos_2, xout_tex_2);
#endif

// n_heads will be greater than or equal to n_kv_heads, therefore xq and xqout
// may have a larger height dim than xk and xkout. Only compute xkout if this
// invocation is still within bounds.
if (any(greaterThanEqual(x_pos_2, xkout_limits))) {
if (out_of_bounds(out_tidx_2, xkout)) {
return;
}

// Compute xkout

x_tex_1 = load_texel(xk, x_pos_1);
x_tex_2 = load_texel(xk, x_pos_2);
#ifdef USING_BUFFER
x_texel_bufi_1 = div_4(tensor4d_idx_to_linear_idx(xkout, out_tidx_1));
x_texel_bufi_2 = div_4(tensor4d_idx_to_linear_idx(xkout, out_tidx_2));

x_tex_1 = t_xk[x_texel_bufi_1];
x_tex_2 = t_xk[x_texel_bufi_2];

#else // USING_TEXTURE
x_tex_1 = texelFetch(t_xk, x_pos_1, 0);
x_tex_2 = texelFetch(t_xk, x_pos_2, 0);
#endif

x_r = VEC4_T(x_tex_1.xz, x_tex_2.xz);
x_i = VEC4_T(x_tex_1.yw, x_tex_2.yw);
Expand All @@ -118,6 +158,11 @@ void main() {
xout_tex_1 = VEC4_T(xout_r.x, xout_i.x, xout_r.y, xout_i.y);
xout_tex_2 = VEC4_T(xout_r.z, xout_i.z, xout_r.w, xout_i.w);

write_texel(xkout, x_pos_1, xout_tex_1);
write_texel(xkout, x_pos_2, xout_tex_2);
#ifdef USING_BUFFER
t_xkout[x_texel_bufi_1] = xout_tex_1;
t_xkout[x_texel_bufi_2] = xout_tex_2;
#else // USING_TEXTURE
imageStore(t_xkout, x_pos_1, xout_tex_1);
imageStore(t_xkout, x_pos_2, xout_tex_2);
#endif
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ rotary_embedding:
DTYPE: float
STORAGE: texture3d
generate_variant_forall:
STORAGE:
- VALUE: texture3d
- VALUE: buffer
DTYPE:
- VALUE: half
- VALUE: float
Expand Down
21 changes: 17 additions & 4 deletions backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,17 @@ utils::uvec3 rotary_embedding_global_wg_size(

const ValueRef xq_out = args.at(0).refs.at(0);

utils::uvec3 global_wg_size = graph->logical_limits_of(xq_out);
global_wg_size[0] /= 2;
// Head dim texel size
const uint32_t D4 = utils::div_up_4(graph->size_at<uint32_t>(-1, xq_out));
// Divide by 2 since each invocation computes 2 output locations
const uint32_t D8 = utils::div_up(D4, uint32_t(2));

return global_wg_size;
// Number of query heads
const uint32_t QH = graph->size_at<uint32_t>(-2, xq_out);
// Input tokens sequence length
const uint32_t S = graph->size_at<uint32_t>(-3, xq_out);

return {D8, QH, S};
}

void add_rotary_embedding_node(
Expand All @@ -73,8 +80,14 @@ void add_rotary_embedding_node(
VK_CHECK_COND(graph.has_standard_axis_map(freqs_sin));

std::string kernel_name = "rotary_embedding";
add_storage_type_suffix(kernel_name, graph.storage_type_of(xq_out));
add_dtype_suffix(kernel_name, graph.dtype_of(xq_out));

vkapi::ParamsBindList param_ubos = {
graph.meta_ubo(xq_out),
graph.meta_ubo(xk_out),
graph.meta_ubo(freqs_cos)};

graph.execute_nodes().emplace_back(new DynamicDispatchNode(
graph,
VK_KERNEL_FROM_STR(kernel_name),
Expand All @@ -84,7 +97,7 @@ void add_rotary_embedding_node(
{{{xq_out, xk_out}, vkapi::kWrite},
{{xq, xk, freqs_cos, freqs_sin}, vkapi::kRead}},
// Parameter buffers
{graph.logical_limits_ubo(xq_out), graph.logical_limits_ubo(xk_out)},
param_ubos,
// Push Constants
{},
// Specialization Constants
Expand Down
Loading