Conversation

@yucai-intel
Contributor

@yucai-intel yucai-intel commented Dec 2, 2025

Fixes #2182: the Q (query) tensor returned by torch.transform_bias_rescale_qkv did not match the expected reference size in test cases involving nested tensors where the sequence length (T) was not a multiple of 8 after implicit padding.

Resolution: In the nested tensor case, the C++ function transform_bias_rescale_qkv_xpu now explicitly uses the calculated sequence length T to size the output q, k, and v tensors, so their final shapes match the shape derived by the Python reference implementation.

@CuiYifeng CuiYifeng requested a review from Copilot December 9, 2025 01:16
Contributor

Copilot AI left a comment

Pull request overview

This PR fixes a tensor size mismatch issue in the transform_bias_rescale_qkv_xpu function when processing NestedTensors. The problem occurred when the sequence length (T) wasn't a multiple of 8 after implicit padding, causing the Q tensor output size to differ from the expected reference size.

Key Changes:

  • Added explicit padding logic to round up the sequence length T to the next multiple of 8 for NestedTensor cases
  • This ensures output tensor dimensions align with Tensor core requirements and match the Python reference implementation
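The rounding step described above can be sketched in isolation. This is only an illustration of the arithmetic in the changed line (`T = T + (8 - (T % 8)) % 8`); the helper name `round_up_to_multiple_of_8` is hypothetical and does not appear in the PR.

```cpp
#include <cstdint>

// Hypothetical helper (not part of the PR): rounds t up to the next
// multiple of 8 using the same expression as the changed line.
// The outer "% 8" keeps values that are already multiples of 8 unchanged.
constexpr int64_t round_up_to_multiple_of_8(int64_t t) {
  return t + (8 - (t % 8)) % 8;
}

// Worked values, checked at compile time.
static_assert(round_up_to_multiple_of_8(1) == 8, "1 rounds up to 8");
static_assert(round_up_to_multiple_of_8(8) == 8, "8 stays 8");
static_assert(round_up_to_multiple_of_8(9) == 16, "9 rounds up to 16");
static_assert(round_up_to_multiple_of_8(13) == 16, "13 rounds up to 16");
```

For non-negative T this is equivalent to the more common `(T + 7) / 8 * 8` form.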

Contributor

@CuiYifeng CuiYifeng left a comment

LGTM.

@CuiYifeng CuiYifeng force-pushed the yucai/mha/nested/fix branch from 590ebe6 to 980872b Compare December 9, 2025 02:59
@CuiYifeng CuiYifeng changed the title Fix: Incorrect Tensor Size for NestedTensor QKV Transform Fix incorrect Tensor Size for NestedTensor QKV Transform Dec 9, 2025
@CuiYifeng CuiYifeng requested a review from liangan1 December 9, 2025 03:01
// cores. Otherwise, sometimes with padding, *no* row will have the maximum
// sequence length and so we'll have a non-divisible-by-8 dimension even if
// the model author chose a multiple of 8.
T = T + (8 - (T % 8)) % 8;
DPAS should not have this limitation; M can be anywhere from 1 to 8. Why do we need to change it here?

4 participants