
Conversation

@pfultz2 pfultz2 (Collaborator) commented Dec 5, 2025

Motivation

Technical Details

Changelog Category

    • Added: New functionality.
    • Changed: Changes to existing functionality.
    • Removed: Functionality or support that has been removed. (Compared to a previous release)
    • Optimized: Component performance that has been optimized or improved.
    • Resolved Issues: Known issues from a previous version that have been resolved.
    • Not Applicable: This PR is not to be included in the changelog.

@pfultz2 pfultz2 requested a review from causten as a code owner December 5, 2025 01:01
codecov bot commented Dec 5, 2025

Codecov Report

❌ Patch coverage is 98.14815% with 1 line in your changes missing coverage. Please review.

Files with missing lines             Patch %   Lines
src/shape_transform_descriptor.cpp   97.78%    1 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #4482      +/-   ##
===========================================
+ Coverage    92.21%   92.22%   +0.01%     
===========================================
  Files          561      561              
  Lines        27228    27275      +47     
===========================================
+ Hits         25108    25154      +46     
- Misses        2120     2121       +1     
Files with missing lines                    Coverage Δ
src/fuse_reduce.cpp                         97.64% <100.00%> (+0.03%) ⬆️
src/include/migraphx/rewrite_reshapes.hpp   94.87% <100.00%> (+0.07%) ⬆️
src/shape_transform_descriptor.cpp          93.15% <97.78%> (+0.21%) ⬆️

... and 2 files with indirect coverage changes


if(desc.empty())
    return;

if(desc.elements() != elements(dims2))
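
A minimal sketch of the guard's intent, assuming a hypothetical free function elements() that just multiplies the lengths together (the real MIGraphX helpers may differ):

#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

// Product of the dimension lengths: the total element count of a shape.
// If desc.elements() != elements(dims2), no pure reshape can map one
// shape to the other, so the rewrite is skipped rather than generating
// an incorrect transformation.
std::size_t elements(const std::vector<std::size_t>& dims)
{
    return std::accumulate(
        dims.begin(), dims.end(), std::size_t{1}, std::multiplies<>{});
}
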
@CharlieL7 CharlieL7 (Collaborator) commented Dec 5, 2025


For the Llama3.2 issue, it looks like this line just prevents the rewrite. The gather before the fused_reduce is {1, 1, 2048}, {2048, 2048, 1}, while the output of the fused_reduce + reshapes is {1, 1, 64, 1, 32}, {1, 1, 0, 1, 0}. To move the reshape instructions from after the fused_reduce to before it, we would have to unsqueeze and broadcast the gather to something like {1, 1, 64, 2048, 32}, {0, 0, 0, 1, 0}.
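
For reference, a standalone sketch of that stride arithmetic in plain C++ (no MIGraphX dependency; it only assumes the usual convention that broadcast dimensions of length 1 get stride 0):

#include <cstddef>
#include <cstdio>
#include <vector>

int main()
{
    // gather output {1, 1, 2048} after an unsqueeze at axes {2, 4}:
    // lens {1, 1, 1, 2048, 1}, strides {2048, 2048, 2048, 1, 1}
    std::vector<std::size_t> in_lens{1, 1, 1, 2048, 1};
    std::vector<std::size_t> in_strides{2048, 2048, 2048, 1, 1};
    std::vector<std::size_t> out_lens{1, 1, 64, 2048, 32};

    // multibroadcast: length-1 input dimensions are broadcast to the
    // output length and get stride 0; other dimensions keep their stride
    std::vector<std::size_t> out_strides(out_lens.size());
    for(std::size_t i = 0; i < out_lens.size(); ++i)
        out_strides[i] = (in_lens[i] == 1) ? 0 : in_strides[i];

    for(auto s : out_strides)
        std::printf("%zu ", s); // prints: 0 0 0 1 0
    std::printf("\n");
}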

The shape_transform_descriptor after rebase with the bugged code is {[batch_size: 0], [1: 1], [64,:], [2048:2], [32:]}. You mentioned that this was incorrect, but it does look right to me, since the 64 and 32 dimensions are broadcasted dimensions.

@pfultz2 pfultz2 (Collaborator, Author) commented

> For the Llama3.2 issue, it looks like this line just prevents the rewrite.

Yes it will. Once we have the logger, I would like to log a message for these cases because it means we are missing some perf issues.

> The gather before the fused_reduce is {1, 1, 2048}, {2048, 2048, 1}, while the output of the fused_reduce + reshapes is {1, 1, 64, 1, 32}, {1, 1, 0, 1, 0}.

What are the reshape ops being used? You can print the ops vector. I think we might need that to reproduce this issue.

> You mentioned that this was incorrect, but it does look right? Since the 64 and 32 dimensions are broadcasted dimensions.

No, it's not right, because we aren't broadcasting on the input to the pointwise.

We start with {1, 1, 2048}; after the reduce it's {1, 1, 1}, then it's reshaped (or unsqueezed) to {1, 1, 1, 1, 1}, which is then broadcasted to {1, 1, 64, 1, 32}.

We start with {1, 1, 1} arriving at {1, 1, 64, 1, 32} with the shape transform descriptor, and then we rebase it with the {1, 1, 2048} so it starts with that instead of {1, 1, 1}, because we want to move the transformations before the reduce so that we can fuse them together.

So we want {1, 1, 2048} reshaped to {1, 1, 64, 1, 32} and the reduction to happen along the last 3 axes, but the descriptor is showing a reshape to {1, 1, 1, 2048, 1} and then a broadcast to {1, 1, 64, 2048, 32}, which is wrong.
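
As a quick sanity check on the element counts (trivial arithmetic, but it is exactly why the desired reshape is legal):

// 64 * 1 * 32 == 2048, so the reshape preserves the element count
static_assert(1 * 1 * 2048 == 1 * 1 * 64 * 1 * 32,
              "reshape {1, 1, 2048} -> {1, 1, 64, 1, 32} preserves elements");
// the reduction then runs over the trailing axes {2, 3, 4}, collapsing
// {1, 1, 64, 1, 32} down to {1, 1, 1, 1, 1}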

@CharlieL7 CharlieL7 (Collaborator) commented Dec 5, 2025

> So we want {1, 1, 2048} reshaped to {1, 1, 64, 1, 32} and the reduction to happen along the last 3 axes, but the descriptor is showing a reshape to {1, 1, 1, 2048, 1} and then a broadcast to {1, 1, 64, 2048, 32}, which is wrong.

{1, 1, 64, 1, 32} would not work as the instruction before the reduce though? The reduction has to occur on the 2048, so with {1, 1, 64, 1, 32} the reduction shape output would be {1, 1, 1, 1, 1}, which would have to be broadcasted again. I don't get what benefit there would be.

Collaborator commented:

Based on the above, what we want is this:

TEST_CASE(rebase_broadcasted_scalar)
{
    // Taken from bug found when compiling Llama3.2
    auto base_desc =
        make_simple_descriptor({1, 1, 1},
                               make_op("unsqueeze", {{"axes", {2, 4}}}),
                               make_op("multibroadcast", {{"out_lens", {1, 1, 64, 1, 32}}}));

    {
        auto desc = base_desc.rebase({1, 1, 2048});
        EXPECT(not desc.empty());
        EXPECT(get_final_lens(desc) == final_lens{1, 1, 64, 1, 32});
        EXPECT(get_all_lens(desc) == all_lens{{1}, {1}, {64}, {1}, {32}});
        EXPECT(get_all_axes(desc) ==
               all_axes{d_axes{{0}},
                        d_axes{{1}},
                        d_axes{{2, 0}},
                        d_axes{{2, 1}},
                        d_axes{{2, 2}}});
        auto generated = desc.generate();
        EXPECT(generated ==
               ops{
                   make_op("reshape", {{"out_lens", {1, 1, 64, 1, 32}}}),
               });
    }
}

@pfultz2 pfultz2 (Collaborator, Author) commented Dec 5, 2025

> Based on the above, what we want is this:

Yea

@pfultz2 pfultz2 (Collaborator, Author) commented

> {1, 1, 64, 1, 32} would not work as the instruction before the reduce though? The reduction has to occur on the 2048, so with {1, 1, 64, 1, 32} the reduction shape output would be {1, 1, 1, 1, 1}, which would have to be broadcasted again. I don't get what benefit there would be.

Currently there is a reduce -> unsqueeze -> broadcast -> pointwise, which we don't fuse because of the unsqueeze. After the rewrite we will have reshape -> reduce -> broadcast -> pointwise, and we can then fuse the reduce -> broadcast -> pointwise.
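
To make the before/after concrete, here is a hedged MIGraphX-style sketch of the rewritten graph (the builder API and op attributes follow the conventions in this PR's test code; reduce_sum, add, and the parameter names are illustrative assumptions, not the exact graph from Llama3.2):

#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>

// Before the rewrite the graph is
//   reduce_sum({2}) -> unsqueeze({2, 4}) -> multibroadcast -> add
// and the unsqueeze blocks the reduce/pointwise fusion.
migraphx::program make_rewritten_graph()
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    auto x = mm->add_parameter(
        "x", migraphx::shape{migraphx::shape::float_type, {1, 1, 2048}});
    auto other = mm->add_parameter(
        "other", migraphx::shape{migraphx::shape::float_type, {1, 1, 64, 1, 32}});

    // After the rewrite: reshape -> reduce -> broadcast -> pointwise,
    // where the reduce -> broadcast -> pointwise tail can now be fused.
    auto rs = mm->add_instruction(
        migraphx::make_op("reshape", {{"dims", {1, 1, 64, 1, 32}}}), x);
    auto r = mm->add_instruction(
        migraphx::make_op("reduce_sum", {{"axes", {2, 3, 4}}}), rs);
    auto b = mm->add_instruction(
        migraphx::make_op("multibroadcast", {{"out_lens", {1, 1, 64, 1, 32}}}), r);
    mm->add_instruction(migraphx::make_op("add"), b, other);
    return p;
}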

Copilot AI left a comment

Pull request overview

This PR fixes errors in the rewrite_reshapes optimization pass related to shape transformations involving squeeze, unsqueeze, and broadcast operations. The fix addresses edge cases where the rebase operation could produce incorrect results or fail when dealing with dimensions of size 1.

  • Adds early exit optimization path for trivial direct mapping cases in shape transformation
  • Improves broadcast validation logic to correctly handle dimensions with length 1
  • Adds element count validation to prevent incorrect reshape transformations

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated no comments.

Summary per file:
  • src/shape_transform_descriptor.cpp: Adds a try_trivial_direct_mapping() method for optimized handling of trivial cases and extracts a regroup_axes() helper method to reduce duplication
  • src/include/migraphx/rewrite_reshapes.hpp: Adds an element count validation check to prevent incorrect transformations when element counts don't match
  • src/fuse_reduce.cpp: Fixes is_valid_broadcast() to properly filter out axes with dimension length 1 when comparing broadcast and reduce axes (see the sketch after this list)
  • test/shape_transform_descriptor.cpp: Adds a test case for rebase with squeeze/unsqueeze/broadcast operations on high-dimensional tensors
  • test/fuse_reduce.cpp: Adds a test case for the reduce-squeeze-unsqueeze-pointwise fusion pattern
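
The is_valid_broadcast() fix can be pictured with a small hedged sketch (a hypothetical helper, not the exact MIGraphX source): before comparing broadcast axes against reduce axes, drop any axis whose dimension length is 1, since it carries no data and should not affect the comparison.

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <vector>

// Keep only the axes whose dimension length is not 1.
std::vector<std::int64_t> nontrivial_axes(const std::vector<std::int64_t>& axes,
                                          const std::vector<std::size_t>& lens)
{
    std::vector<std::int64_t> result;
    std::copy_if(axes.begin(), axes.end(), std::back_inserter(result),
                 [&](std::int64_t a) {
                     return lens[static_cast<std::size_t>(a)] != 1;
                 });
    return result;
}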

