You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Update base for Update on "[ET-VK] Implement select_at_dim_as_symint"
## Context
The SDPA custom op accepts the `input_pos` (i.e. cache position) argument as a symbolic integer. The value of the symbolic integer is obtained by selecting the first element of a cache position input tensor and converting it to symint via local_scalar_dense.
Currently, ET-VK handles this in a hacky manner.
1. the select + local_scalar_dense op pattern is removed, and the cache pos tensor is passed directly into the custom sdpa ops
2. Single element tensors that have users that are all select + local_scalar_dense will be interpreted as symints instead of tensors
Unfortunately, this technique will not work for the huggingface implementation of transformer models, since the cache pos input tensor has not just a single element but is expected to be a vector of integer cache positions corresponding to all cache positions that will be updated.
## Changes
Introduce a custom op to capture the select + local_scalar_dense op pattern, which is the proper way to handle the op pattern.
Note that a custom op is needed because this op needs to access the staging buffer data of the input tensor, whereas `select` would typically be executed via a compute shader. The reason for this is because the `input_pos` value is needed to configure the sizes of attention weight tensors participating in the custom SDPA op, so the value must be set before any command buffers are dispatched.
As a consequence of this change, the previous handling of select + local scalar dense can also be removed.
Differential Revision: [D86340340](https://our.internmc.facebook.com/intern/diff/D86340340/)
[ghstack-poisoned]
0 commit comments