Add update_tensor_descriptor operation to Triton/Gluon #8786
Conversation
include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td
//
// Update Tensor Descriptor Op
//
def TT_UpdateTensorDescOp : TT_Op<"update_tensor_descriptor", [
The in-place update kind of forces the underlying implementation to be memory-backed, which may not be the case for all hardware, including pre-Hopper, where we translate tensor descriptors to normal pointer indexing.
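For context, a minimal sketch (hypothetical kernel and parameter names, not the actual lowering code) of what that pre-Hopper lowering amounts to: a descriptor load becomes ordinary masked pointer indexing, so there is no memory-backed descriptor for an in-place update to mutate.

```python
import triton
import triton.language as tl

@triton.jit
def desc_load_as_pointers(base, out, M, N, stride_m, stride_n,
                          moffset, noffset,
                          BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    offs_m = moffset + tl.arange(0, BLOCK_M)
    offs_n = noffset + tl.arange(0, BLOCK_N)
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    ptrs = base + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
    # Roughly the moral equivalent of `a = desc.load([moffset, noffset])`
    # on pre-Hopper targets: no descriptor object exists at runtime.
    tile = tl.load(ptrs, mask=mask, other=0.0)
    tl.store(out + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n,
             tile, mask=mask)
```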
There are several options to fix it, for example lowering it to essentially what make_tensor_descriptor does for older hardware. But overall, I'm actually not sure having this operation in Triton is appropriate, so how about removing the Triton version and moving the Gluon version under the tma, or better, hopper namespace?
Added a commit that removes the Triton version of the operation from the PR and moves the Gluon version under the tma namespace.
a = desc.load([moffset, noffset])
tl.update_tensor_descriptor(desc, base=b_ptr)
I think this is illegal. We're passing the descriptor in param space, which should be constant.
I also think this will break if your launch grid is larger than the number of SMs: I'd expect the second program scheduled onto a given SM to see the already-updated descriptor.
Uhm, indeed. Would it be acceptable to limit the operation to descriptors created from within the kernel, and thus avoid both problems you pointed out?
Added another commit to implement the proposed change.
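A minimal sketch of the restricted usage, assuming the tl.-level spelling shown in the diff above (which a later commit in this PR moves under the Gluon tma namespace); the kernel and parameter names are illustrative:

```python
import triton
import triton.language as tl

@triton.jit
def in_kernel_update(a_ptr, b_ptr, out_ptr, M, N,
                     BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # The descriptor is created inside the kernel, so it is not in (read-only)
    # param space and is private to this program instance -- which avoids both
    # objections raised above.
    desc = tl.make_tensor_descriptor(a_ptr, shape=[M, N], strides=[N, 1],
                                     block_shape=[BLOCK_M, BLOCK_N])
    a = desc.load([0, 0])
    tl.update_tensor_descriptor(desc, base=b_ptr)  # retarget the same descriptor
    b = desc.load([0, 0])
    offs = tl.arange(0, BLOCK_M)[:, None] * N + tl.arange(0, BLOCK_N)[None, :]
    tl.store(out_ptr + offs, a + b)
```

(As with any in-kernel descriptor creation, running this requires a host-side allocator via triton.set_allocator for the descriptor's GMEM backing.)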
Force-pushed from 0ff51d0 to f4f5db2.
Force-pushed from f4f5db2 to 00f04d7.
This PR adds the `update_tensor_descriptor` operation, to simplify writing kernels like grouped MM (e.g. pytorch/pytorch#166063, in particular to avoid handling special cases like this). It also matches the `update_tensormap` op in the CuTe DSL, as used here.

The operation reads the existing descriptor from GMEM into SMEM, performs the updates in SMEM, and writes the updated descriptor back to GMEM. The rationale for using this operation instead of creating a new descriptor is to save a GMEM allocation (more precisely, to trade it for a read from GMEM), and to emit `tensormap.replace.tile.*` PTX instructions only for the descriptor fields that actually change. Otherwise, the implementation closely follows the `make_tensor_descriptor` implementation. The end-to-end performance improvement is minor; the main advantage is that the code for cases like the kernel linked above is cleaner. This PR makes it possible to change the tensor base pointer, shape, and strides fields in the descriptor; support for changing other fields could be added in the future if there is a need.
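To illustrate the rationale, here is a hedged sketch of the intended pattern (hypothetical kernel; the tl.-level spelling predates the move to the Gluon tma namespace, and the group layout is simplified to contiguous equal-sized buffers):

```python
import triton
import triton.language as tl

@triton.jit
def sum_groups(a_ptr, out_ptr, n_groups, M, N,
               BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # One descriptor, created once. Re-creating it per group would cost a GMEM
    # allocation each time; updating it in place only re-reads it and emits
    # tensormap.replace.tile.* instructions for the fields that actually change.
    desc = tl.make_tensor_descriptor(a_ptr, shape=[M, N], strides=[N, 1],
                                     block_shape=[BLOCK_M, BLOCK_N])
    acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    for g in range(n_groups):
        if g > 0:
            # Only the base pointer changes between groups, so only the
            # base-address field of the descriptor needs to be replaced.
            tl.update_tensor_descriptor(desc, base=a_ptr + g * M * N)
        acc += desc.load([0, 0])
    offs = tl.arange(0, BLOCK_M)[:, None] * N + tl.arange(0, BLOCK_N)[None, :]
    tl.store(out_ptr + offs, acc)
```

The same pattern would extend to shape and stride updates for ragged group sizes, since the PR supports changing those fields as well.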