42 changes: 42 additions & 0 deletions include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td
@@ -766,6 +766,22 @@ def TTNG_ReinterpretTensorDescOp : TTNG_Op<"reinterpret_tensor_descriptor", [Pur
}];
}

def TTNG_GetDescriptorPtrOp : TTNG_Op<"get_descriptor_ptr", [Pure]> {
let summary = "Get the underlying pointer from a tensor descriptor";

let description = [{
This op extracts the underlying pointer from a tensor descriptor. It is used
internally by TMA operations that need access to the raw descriptor pointer.
}];

let arguments = (ins TT_TensorDescType:$desc);
let results = (outs TT_PtrType:$result);

let assemblyFormat = [{
$desc attr-dict `:` qualified(type($desc)) `->` qualified(type($result))
}];
}
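
For reference, the assembly format above produces IR along these lines (a sketch; the descriptor block type and value names are illustrative, not taken from this change):

%ptr = ttng.get_descriptor_ptr %desc : !tt.tensordesc<tensor<128x64xf16>> -> !tt.ptr<i8>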

def TTNG_TensormapCreateOp: TTNG_Op<
"tensormap_create",
[
@@ -815,4 +831,30 @@ def TTNG_TensormapFenceproxyAcquireOp: TTNG_Op<
}];
}

def TTNG_TensormapUpdateOp: TTNG_Op<
"tensormap_update",
[
MemoryEffects<[MemRead<GlobalMemory>, MemWrite<GlobalMemory>]>,
AttrSizedOperandSegments
]
> {
let summary = "Update in-place TMA descriptor fields selectively";
let arguments = (ins
TT_PtrType:$desc_ptr,
Optional<TT_PtrType>:$global_address,
Variadic<I32>:$global_dim,
Variadic<I64>:$global_stride
);
let assemblyFormat = [{
$desc_ptr
oilist(
`global_address` `=` $global_address `:` type($global_address) |
`global_dim` `=` `[` $global_dim `]` |
`global_stride` `=` `[` $global_stride `]`
)
attr-dict `:` type($desc_ptr)
}];
let hasVerifier = 1;
}
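
Based on this assembly format, an update touching all three fields would look roughly as follows (a sketch with illustrative value names and pointee types). Note that a rank-2 global_dim is paired with a rank-1 global_stride, matching the verifier added in lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp below:

ttng.tensormap_update %desc_ptr global_address = %base : !tt.ptr<f16> global_dim = [%d0, %d1] global_stride = [%s0] : !tt.ptr<i8>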

#endif
4 changes: 4 additions & 0 deletions lib/Analysis/Allocation.cpp
@@ -108,6 +108,10 @@ unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op) {
constexpr int32_t kTMASize = 128;
return kTMASize;
}
if (isa<ttng::TensormapUpdateOp>(op)) {
constexpr int32_t kTMASize = 128;
return kTMASize;
}
return 0;
}

41 changes: 41 additions & 0 deletions lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp
@@ -26,6 +26,7 @@
#include "mlir/IR/Diagnostics.h"
#include "mlir/Support/LLVM.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h"
@@ -858,6 +859,46 @@ LogicalResult TensormapCreateOp::verify() {
return success();
}

// -- TensormapUpdateOp --
LogicalResult TensormapUpdateOp::verify() {
auto hasAddress = (getGlobalAddress() != nullptr);
auto hasDim = !getGlobalDim().empty();
auto hasStride = !getGlobalStride().empty();
if (!hasAddress && !hasDim && !hasStride) {
return emitError("Must update at least one descriptor field");
}
if (hasStride && !hasDim) {
return emitError("Cannot update global stride without dim specified");
}
if (hasDim && hasStride) {
if (getGlobalStride().size() + 1 != getGlobalDim().size()) {
return emitError("Rank mismatch for global stride. Got ")
<< getGlobalStride().size() << " but expected "
<< getGlobalDim().size() - 1;
}
}

// Verify that the descriptor being updated is not shared across CTAs.
// Descriptors passed as kernel parameters (from TensorDescriptor.from_tensor())
// are stored in constant memory and shared across all CTAs, so updating them
// would cause race conditions. Descriptors created in-kernel with
// make_tensor_descriptor() are per-CTA and safe to update.
auto descPtr = getDescPtr();
auto getDescPtrOp = descPtr.getDefiningOp<GetDescriptorPtrOp>();
if (getDescPtrOp) {
auto desc = getDescPtrOp.getDesc();

// Use the existing utility to check if this is a host-side descriptor
if (triton::isHostSideDescriptor(desc)) {
return emitError("Descriptor must be created within the kernel using "
"make_tensor_descriptor. Updating descriptors passed as "
"kernel parameters would cause race conditions across CTAs.");
}
}

return success();
}

} // namespace nvidia_gpu
} // namespace triton
} // namespace mlir
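
To make the verifier's constraints concrete, here is a hedged sketch of forms it would accept and reject (illustrative names; assumes %ptr comes from a descriptor created in-kernel with make_tensor_descriptor):

// Accepted: at least one field updated; stride rank equals dim rank minus one.
ttng.tensormap_update %ptr global_dim = [%d0, %d1] global_stride = [%s0] : !tt.ptr<i8>
// Rejected: no field updated.
ttng.tensormap_update %ptr : !tt.ptr<i8>
// Rejected: global stride updated without global dim.
ttng.tensormap_update %ptr global_stride = [%s0] : !tt.ptr<i8>

If %ptr instead came from ttng.get_descriptor_ptr of a host-side descriptor (a kernel parameter), the host-side-descriptor check above would reject the op regardless of which fields are updated.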
42 changes: 42 additions & 0 deletions python/src/gluon_ir.cc
@@ -871,6 +871,48 @@ void init_gluon_ir(py::module &&m) {
return self.create<tt::MakeTensorDescOp>(resultTy, base, shape,
strides, paddingOption);
})
.def("create_update_tensor_descriptor",
[](GluonOpBuilder &self, Value &desc,
std::optional<Value> base,
std::vector<Value> shape,
std::vector<Value> strides) -> void {
auto &builder = self.getBuilder();

auto ptrType = tt::PointerType::get(
builder.getIntegerType(8), /*addressSpace=*/1);
Value descPtr = self.create<ttng::GetDescriptorPtrOp>(ptrType, desc);

std::vector<Value> byteStrides;
if (!strides.empty()) {
auto descType = mlir::cast<tt::TensorDescType>(desc.getType());
auto elemType = descType.getBlockType().getElementType();
auto elemSize = elemType.getIntOrFloatBitWidth() / 8;
Value elemSizeVal = self.create<arith::ConstantIntOp>(
builder.getI64Type(), elemSize);

for (Value stride : strides) {
byteStrides.push_back(self.create<arith::MulIOp>(stride, elemSizeVal));
}
strides = byteStrides;
}

std::vector<Value> reversedShape(shape.rbegin(), shape.rend());
std::vector<Value> reversedStrides;
if (!strides.empty()) {
for (int k = strides.size() - 2; k >= 0; --k) {
reversedStrides.push_back(strides[k]);
}
}

Value baseVal = base.has_value() ? base.value() : Value();
self.create<ttng::TensormapUpdateOp>(descPtr, baseVal, reversedShape, reversedStrides);

self.create<ttng::TensormapFenceproxyAcquireOp>(descPtr);
},
py::arg("desc"),
py::arg("base") = std::nullopt,
py::arg("shape") = std::vector<Value>{},
py::arg("strides") = std::vector<Value>{})
.def("create_async_tdm_copy_global_to_local",
[](GluonOpBuilder &self, Value descPtr, std::vector<Value> &indices,
Value result, Value pred, Value barrier) {
Expand Down
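
Putting the pieces together, create_update_tensor_descriptor above emits roughly the following sequence for a 2-D f16 descriptor with shape = [%dim0, %dim1] and element strides = [%stride0, %stride1] (a sketch; SSA names, types, and the exact printed syntax of the pre-existing fence op are illustrative). Element strides are scaled to bytes, dims and strides are reversed to innermost-first order with the innermost stride dropped, and the update is followed by a tensormap fence:

%ptr = ttng.get_descriptor_ptr %desc : !tt.tensordesc<tensor<128x64xf16>> -> !tt.ptr<i8>
%c2 = arith.constant 2 : i64
%bstride0 = arith.muli %stride0, %c2 : i64
ttng.tensormap_update %ptr global_address = %base : !tt.ptr<f16> global_dim = [%dim1, %dim0] global_stride = [%bstride0] : !tt.ptr<i8>
ttng.tensormap_fenceproxy_acquire %ptr : !tt.ptr<i8>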