Fix error with rewrite_reshapes #4482
base: develop
Conversation
Codecov Report
Additional details and impacted files
@@ Coverage Diff @@
## develop #4482 +/- ##
===========================================
+ Coverage 92.21% 92.22% +0.01%
===========================================
Files 561 561
Lines 27228 27275 +47
===========================================
+ Hits 25108 25154 +46
- Misses 2120 2121 +1
if(desc.empty())
    return;
...
if(desc.elements() != elements(dims2))
For the Llama3.2 issue, it looks like this line just prevents the rewrite. The gather before fused_reduce is {1, 1, 2048}, {2048, 2048, 1}, while the output of the fused_reduce + reshapes is {1, 1, 64, 1, 32}, {1, 1, 0, 1, 0}. To move the reshape instructions from after the fused_reduce to before it, we would have to unsqueeze and broadcast the gather to something like {1, 1, 64, 2048, 32}, {0, 0, 0, 1, 0}.
The shape_transform_descriptor after rebase with the bugged code is {[batch_size: 0], [1: 1], [64,:], [2048:2], [32:]}. This was mentioned as being incorrect, but doesn't it look right, since the 64 and 32 dimensions are broadcasted dimensions?
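For reference, a rough sketch of that unsqueeze + broadcast shape computation (a minimal sketch assuming the internal `operation::compute_shape` interface; the op names and axes here are chosen for illustration, not taken from the actual rewrite):

```cpp
#include <migraphx/make_op.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/shape.hpp>

migraphx::shape sketch_gather_rebroadcast()
{
    // Output of the gather feeding the fused_reduce: {1, 1, 2048}, strides {2048, 2048, 1}.
    migraphx::shape gather_out{migraphx::shape::float_type, {1, 1, 2048}};

    // Unsqueeze to rank 5 so the data-carrying axis lines up with the reshaped output.
    auto unsqueezed =
        migraphx::make_op("unsqueeze", {{"axes", {2, 4}}}).compute_shape({gather_out});
    // unsqueezed lens: {1, 1, 1, 2048, 1}

    // Broadcast across the new 64 and 32 axes; the data-carrying axis keeps stride 1 and
    // the broadcast axes get stride 0, i.e. roughly {1, 1, 64, 2048, 32}, {0, 0, 0, 1, 0}.
    return migraphx::make_op("multibroadcast", {{"out_lens", {1, 1, 64, 2048, 32}}})
        .compute_shape({unsqueezed});
}
```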
> For the Llama3.2 issue, it looks like this line just prevents the rewrite.

Yes, it will. Once we have the logger, I would like to log a message for these cases because it means we are missing some perf opportunities.

> The gather before fused_reduce is {1, 1, 2048}, {2048, 2048, 1}, while the output of the fused_reduce + reshapes is {1, 1, 64, 1, 32}, {1, 1, 0, 1, 0}.

What are the reshape ops being used? You can print the ops vector. I think we might need that to reproduce this issue.

> This was mentioned as being incorrect, but doesn't it look right, since the 64 and 32 dimensions are broadcasted dimensions?

No, it's not right, because we aren't broadcasting on the input to the pointwise.
We start with {1, 1, 2048}; after the reduce it's {1, 1, 1}, then it's reshaped (or unsqueezed) to {1, 1, 1, 1, 1}, which is then broadcasted to {1, 1, 64, 1, 32}.
We start with {1, 1, 1} arriving at {1, 1, 64, 1, 32} with the shape transform descriptor, and then we rebase it with {1, 1, 2048} so it starts with that instead of {1, 1, 1}, because we want to move the transformations before the reduce so that we can fuse them together.
So we want {1, 1, 2048} reshaped to {1, 1, 64, 1, 32} and the reduction to happen along the last 3 axes, but the descriptor is showing a reshape to {1, 1, 1, 2048, 1} that is then broadcasted to {1, 1, 64, 2048, 32}, which is wrong.
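To make the "which is wrong" concrete, here is a minimal arithmetic sketch (the `count_elements` helper below is written only for illustration and is not the `elements` function referenced in the diff): the rebased input has only 2048 data elements, so the expected single reshape preserves the element count, while the buggy reshape + multibroadcast duplicates the data 64×32 times, which is the kind of mismatch the new `desc.elements() != elements(dims2)` guard appears to reject.

```cpp
#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

// Illustration-only helper: product of the dimension lengths.
std::size_t count_elements(const std::vector<std::size_t>& lens)
{
    return std::accumulate(lens.begin(), lens.end(), std::size_t{1}, std::multiplies<>{});
}

// Rebased input:        {1, 1, 2048}                            -> 2048 elements
// Expected generate():  reshape to {1, 1, 64, 1, 32}            -> 2048 elements (preserved)
// Buggy generate():     reshape to {1, 1, 1, 2048, 1}, then
//                       multibroadcast to {1, 1, 64, 2048, 32}  -> 4,194,304 elements
```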
> So we want {1, 1, 2048} reshaped to {1, 1, 64, 1, 32} and the reduction to happen along the last 3 axes, but the descriptor is showing a reshape to {1, 1, 1, 2048, 1} that is then broadcasted to {1, 1, 64, 2048, 32}, which is wrong.

{1, 1, 64, 1, 32} would not work as the instruction before the reduce, though, would it? The reduction has to occur over the 2048 elements, so with {1, 1, 64, 1, 32} the reduction output would be {1, 1, 1, 1, 1}, which would have to be broadcasted again. I don't see what the benefit would be.
Based on the above, is this what we want?
TEST_CASE(rebase_broadcasted_scalar)
{
    // Taken from a bug found when compiling Llama3.2
    auto base_desc =
        make_simple_descriptor({1, 1, 1},
                               make_op("unsqueeze", {{"axes", {2, 4}}}),
                               make_op("multibroadcast", {{"out_lens", {1, 1, 64, 1, 32}}}));
    {
        auto desc = base_desc.rebase({1, 1, 2048});
        EXPECT(not desc.empty());
        EXPECT(get_final_lens(desc) == final_lens{1, 1, 64, 1, 32});
        EXPECT(get_all_lens(desc) == all_lens{{1}, {1}, {64}, {1}, {32}});
        EXPECT(get_all_axes(desc) ==
               all_axes{d_axes{{0}}, d_axes{{1}}, d_axes{{2, 0}}, d_axes{{2, 1}}, d_axes{{2, 2}}});
        auto generated = desc.generate();
        EXPECT(generated == ops{
                                make_op("reshape", {{"dims", {1, 1, 64, 1, 32}}}),
                            });
    }
}
> Based on the above, is this what we want?

Yeah.
> {1, 1, 64, 1, 32} would not work as the instruction before the reduce, though, would it? The reduction has to occur over the 2048 elements, so with {1, 1, 64, 1, 32} the reduction output would be {1, 1, 1, 1, 1}, which would have to be broadcasted again. I don't see what the benefit would be.

Currently there is reduce -> unsqueeze -> broadcast -> pointwise, which we don't fuse because of the unsqueeze. After the rewrite we will have reshape -> reduce -> broadcast -> pointwise, and we can then fuse the reduce -> broadcast -> pointwise part.
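To make that before/after concrete, here is a rough sketch of the two graphs using the internal C++ API (a sketch only: the shapes are the ones from this thread, while `reduce_sum`, the axes, and the `add` standing in for the pointwise op are assumptions for illustration):

```cpp
#include <migraphx/make_op.hpp>
#include <migraphx/module.hpp>
#include <migraphx/program.hpp>
#include <migraphx/shape.hpp>

// Current graph: reduce -> unsqueeze -> broadcast -> pointwise.
// The unsqueeze between the reduce and the broadcast blocks the fusion.
migraphx::program before_rewrite()
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    migraphx::shape xs{migraphx::shape::float_type, {1, 1, 2048}};
    migraphx::shape ys{migraphx::shape::float_type, {1, 1, 64, 1, 32}};
    auto x = mm->add_parameter("x", xs);
    auto y = mm->add_parameter("y", ys);
    auto r = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), x);   // {1, 1, 1}
    auto u = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2, 4}}}), r); // {1, 1, 1, 1, 1}
    auto b = mm->add_instruction(
        migraphx::make_op("multibroadcast", {{"out_lens", {1, 1, 64, 1, 32}}}), u);
    mm->add_instruction(migraphx::make_op("add"), b, y); // stand-in for the pointwise op
    return p;
}

// After the rewrite: reshape -> reduce -> broadcast -> pointwise.
// Now the reduce -> broadcast -> pointwise tail can be fused.
migraphx::program after_rewrite()
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    migraphx::shape xs{migraphx::shape::float_type, {1, 1, 2048}};
    migraphx::shape ys{migraphx::shape::float_type, {1, 1, 64, 1, 32}};
    auto x  = mm->add_parameter("x", xs);
    auto y  = mm->add_parameter("y", ys);
    auto rs = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {1, 1, 64, 1, 32}}}), x);
    auto r  = mm->add_instruction(
        migraphx::make_op("reduce_sum", {{"axes", {2, 3, 4}}}), rs);                     // {1, 1, 1, 1, 1}
    auto b = mm->add_instruction(
        migraphx::make_op("multibroadcast", {{"out_lens", {1, 1, 64, 1, 32}}}), r);
    mm->add_instruction(migraphx::make_op("add"), b, y); // same pointwise stand-in
    return p;
}
```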
Pull request overview
This PR fixes errors in the rewrite_reshapes optimization pass related to shape transformations involving squeeze, unsqueeze, and broadcast operations. The fix addresses edge cases where the rebase operation could produce incorrect results or fail when dealing with dimensions of size 1.
- Adds early exit optimization path for trivial direct mapping cases in shape transformation
- Improves broadcast validation logic to correctly handle dimensions with length 1
- Adds element count validation to prevent incorrect reshape transformations
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated no comments.
Show a summary per file
| File | Description |
|---|---|
| src/shape_transform_descriptor.cpp | Adds try_trivial_direct_mapping() method for optimized handling of trivial cases and extracts regroup_axes() helper method to reduce duplication |
| src/include/migraphx/rewrite_reshapes.hpp | Adds element count validation check to prevent incorrect transformations when element counts don't match |
| src/fuse_reduce.cpp | Fixes is_valid_broadcast() to properly filter out axes with dimension length 1 when comparing broadcast and reduce axes (see the sketch after this table) |
| test/shape_transform_descriptor.cpp | Adds test case for rebase with squeeze/unsqueeze/broadcast operations on high-dimensional tensors |
| test/fuse_reduce.cpp | Adds test case for reduce-squeeze-unsqueeze-pointwise fusion pattern |
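The fuse_reduce.cpp row is the subtle one. The following is a hypothetical standalone sketch of the idea only; it is not the actual is_valid_broadcast() from src/fuse_reduce.cpp (which works on MIGraphX instructions), and it assumes the broadcast input/output lens and the reduce axes are already known and of equal rank:

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// A broadcast is compatible with a preceding reduce only if every axis it actually
// expands is one of the reduced axes. Axes whose output length is 1 carry no data,
// so they are skipped instead of being miscounted as broadcast axes.
bool broadcast_matches_reduce(const std::vector<std::size_t>& in_lens,
                              const std::vector<std::size_t>& out_lens,
                              const std::vector<std::size_t>& reduce_axes)
{
    std::vector<std::size_t> broadcast_axes;
    for(std::size_t i = 0; i < out_lens.size(); i++)
    {
        if(out_lens[i] == 1)
            continue; // ignore length-1 axes when comparing
        if(in_lens[i] == 1)
            broadcast_axes.push_back(i);
    }
    return std::all_of(broadcast_axes.begin(), broadcast_axes.end(), [&](std::size_t axis) {
        return std::find(reduce_axes.begin(), reduce_axes.end(), axis) != reduce_axes.end();
    });
}
```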
Motivation
Technical Details
Changelog Category