From 223aa5553b008d6f2c525800a21480ee727dc333 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleksandar=20Samard=C5=BEi=C4=87?= Date: Thu, 20 Nov 2025 17:36:10 +0000 Subject: [PATCH 1/3] Add update_tensor_descriptor operation to Triton/Gluon --- include/triton/Dialect/Triton/IR/TritonOps.td | 35 +++ .../TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td | 42 +++ lib/Analysis/Allocation.cpp | 4 + lib/Dialect/Triton/IR/Ops.cpp | 15 ++ lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp | 21 ++ .../Transforms/TMALowering.cpp | 55 +++- python/src/gluon_ir.cc | 12 + python/src/ir.cc | 15 +- python/test/gluon/test_core.py | 172 ++++++++++++ .../unit/language/test_tensor_descriptor.py | 246 ++++++++++++++++++ .../gluon/language/nvidia/blackwell/tma.py | 2 + .../gluon/language/nvidia/hopper/tma.py | 81 +++++- python/triton/language/__init__.py | 10 +- python/triton/language/core.py | 112 +++++++- python/triton/language/semantic.py | 2 +- test/Conversion/tritonnvidiagpu_to_llvm.mlir | 55 ++++ test/TritonNvidiaGPU/tma_lowering.mlir | 15 ++ .../lib/TritonNVIDIAGPUToLLVM/TMAToLLVM.cpp | 172 ++++++++++-- 18 files changed, 1024 insertions(+), 42 deletions(-) diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index cfb31995bc3a..c377ceff3aca 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -1061,6 +1061,41 @@ def TT_MakeTensorDescOp : TT_Op<"make_tensor_descriptor", [ }]; } +// +// Update Tensor Descriptor Op +// +def TT_UpdateTensorDescOp : TT_Op<"update_tensor_descriptor", [ + MemoryEffects<[MemRead, MemWrite]>, + AttrSizedOperandSegments, +]> { + let summary = "Update an existing tensor descriptor"; + + let description = [{ + `tt.update_tensor_descriptor` updates one or more fields of an existing tensor descriptor in-place. + + At the moment, it allows for updating the base pointer, shape and strides of a tensor in global memory. + }]; + + let arguments = (ins + AnyTypeOf<[TT_TensorDescType]>:$desc, + Optional:$base, + Variadic:$shape, + Variadic:$strides + ); + + let assemblyFormat = [{ + $desc + oilist( + `base` `=` $base `:` type($base) | + `shape` `=` `[` $shape `]` | + `strides` `=` `[` $strides `]` + ) + attr-dict `:` type($desc) + }]; + + let hasVerifier = 1; +} + // The following ops, including `call`, `func`, and `return` are copied and modified from // https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Func/IR/FuncOps.td // We could revert it back once MLIR has a better inliner interface. diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td index fde58b6bfc78..d595b0dde5bb 100644 --- a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td +++ b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td @@ -766,6 +766,22 @@ def TTNG_ReinterpretTensorDescOp : TTNG_Op<"reinterpret_tensor_descriptor", [Pur }]; } +def TTNG_GetDescriptorPtrOp : TTNG_Op<"get_descriptor_ptr", [Pure]> { + let summary = "Get the underlying pointer from a tensor descriptor"; + + let description = [{ + This Op extracts the underlying pointer from a tensor descriptor. + Used internally for TMA operations that need to access the raw descriptor pointer. + }]; + + let arguments = (ins TT_TensorDescType:$desc); + let results = (outs TT_PtrType:$result); + + let assemblyFormat = [{ + $desc attr-dict `:` qualified(type($desc)) `->` qualified(type($result)) + }]; +} + def TTNG_TensormapCreateOp: TTNG_Op< "tensormap_create", [ @@ -815,4 +831,30 @@ def TTNG_TensormapFenceproxyAcquireOp: TTNG_Op< }]; } +def TTNG_TensormapUpdateOp: TTNG_Op< + "tensormap_update", + [ + MemoryEffects<[MemRead, MemWrite]>, + AttrSizedOperandSegments + ] +> { + let summary = "Update in-place TMA descriptor fields selectively"; + let arguments = (ins + TT_PtrType:$desc_ptr, + Optional:$global_address, + Variadic:$global_dim, + Variadic:$global_stride + ); + let assemblyFormat = [{ + $desc_ptr + oilist( + `global_address` `=` $global_address `:` type($global_address) | + `global_dim` `=` `[` $global_dim `]` | + `global_stride` `=` `[` $global_stride `]` + ) + attr-dict `:` type($desc_ptr) + }]; + let hasVerifier = 1; +} + #endif diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index 983b6645b8c1..c5d6697c9cc2 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -108,6 +108,10 @@ unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op) { constexpr int32_t kTMASize = 128; return kTMASize; } + if (isa(op)) { + constexpr int32_t kTMASize = 128; + return kTMASize; + } return 0; } diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index 7be62c73407b..7514bc139472 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -1032,6 +1032,21 @@ void MakeTensorPtrOp::build(OpBuilder &builder, OperationState &state, builder.getDenseI32ArrayAttr(order)); } +LogicalResult UpdateTensorDescOp::verify() { + if (!llvm::isa(getDesc().getType())) + return emitOpError("first operand must be a !tt.tensordesc"); + bool hasBase = (getBase() != nullptr); + bool hasShape = !getShape().empty(); + bool hasStrides = !getStrides().empty(); + if (!hasBase && !hasShape && !hasStrides) + return emitOpError("must update at least one of base/shape/strides"); + if (hasStrides && !hasShape) + return emitOpError("cannot update strides without shape"); + if (hasShape && hasStrides && getShape().size() != getStrides().size()) + return emitOpError("shape and strides must have the same length"); + return success(); +} + //-- AddPtrOp -- OpFoldResult AddPtrOp::fold(FoldAdaptor adaptor) { // addptr(ptr, 0) -> ptr diff --git a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp index 678cf98e5287..b96392747a3f 100644 --- a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp @@ -858,6 +858,27 @@ LogicalResult TensormapCreateOp::verify() { return success(); } +// -- TensormapUpdateOp -- +LogicalResult TensormapUpdateOp::verify() { + auto hasAddress = (getGlobalAddress() != nullptr); + auto hasDim = !getGlobalDim().empty(); + auto hasStride = !getGlobalStride().empty(); + if (!hasAddress && !hasDim && !hasStride) { + return emitError("Must update at least one descriptor field"); + } + if (hasStride && !hasDim) { + return emitError("Cannot update global stride without dim specified"); + } + if (hasDim && hasStride) { + if (getGlobalStride().size() + 1 != getGlobalDim().size()) { + return emitError("Rank mismatch for global stride. Got ") + << getGlobalStride().size() << " but expected " + << getGlobalDim().size() - 1; + } + } + return success(); +} + } // namespace nvidia_gpu } // namespace triton } // namespace mlir diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp index dc129ad14309..1304d8b5a9ac 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp @@ -187,6 +187,58 @@ class TMACreateDescLowering : public OpRewritePattern { } }; +class TMAUpdateDescLowering : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(UpdateTensorDescOp op, + PatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Value desc = op.getDesc(); + + // Get the descriptor pointer + Value descPtr = rewriter.create(loc, getPointerType(rewriter.getI8Type()), desc); + + ValueRange shape = op.getShape(); + ValueRange strides = op.getStrides(); + + // Convert element strides to byte strides (same as in createTMADesc) + if (!strides.empty()) { + auto descType = mlir::cast(desc.getType()); + auto elemType = descType.getBlockType().getElementType(); + auto elemSize = elemType.getIntOrFloatBitWidth() / 8; + Value elemSizeVal = rewriter.create( + loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(elemSize)); + + SmallVector byteStrides; + for (Value stride : strides) { + byteStrides.push_back(rewriter.create(loc, stride, elemSizeVal)); + } + strides = byteStrides; + } + + // Reverse shape and strides to match TMA descriptor layout (same + // as in createTMADesc) + SmallVector reversedShape(llvm::reverse(shape)); + SmallVector reversedStrides; + if (!strides.empty()) { + for (int k = strides.size() - 2; k >= 0; --k) { + reversedStrides.push_back(strides[k]); + } + } + rewriter.create(loc, descPtr, op.getBase(), + reversedShape, reversedStrides); + + // The fence ensures that that memory ordering is correct for + // subsequent TMA operations + TensormapFenceproxyAcquireOp::create(rewriter, loc, descPtr); + + rewriter.eraseOp(op); + return success(); + } +}; + + } // anonymous namespace class TritonNvidiaGPUTMALoweringPass @@ -199,7 +251,8 @@ class TritonNvidiaGPUTMALoweringPass mlir::RewritePatternSet patterns(context); patterns.add( + TMAScatterLowering, TMAReduceLowering, TMACreateDescLowering, + TMAUpdateDescLowering>( context); if (applyPatternsGreedily(m, std::move(patterns)).failed()) signalPassFailure(); diff --git a/python/src/gluon_ir.cc b/python/src/gluon_ir.cc index 481d9ab5f677..0cb3efaef2da 100644 --- a/python/src/gluon_ir.cc +++ b/python/src/gluon_ir.cc @@ -866,6 +866,18 @@ void init_gluon_ir(py::module &&m) { return self.create(resultTy, base, shape, strides, paddingOption); }) + .def("create_update_tensor_descriptor", + [](TritonOpBuilder &self, Value &desc, + std::optional base, + std::vector shape, + std::vector strides) -> void { + Value baseVal = base.has_value() ? base.value() : Value(); + self.create(desc, baseVal, shape, strides); + }, + py::arg("desc"), + py::arg("base") = std::nullopt, + py::arg("shape") = std::vector{}, + py::arg("strides") = std::vector{}) .def("create_async_tdm_copy_global_to_local", [](GluonOpBuilder &self, Value descPtr, std::vector &indices, Value result, Value pred, Value barrier) { diff --git a/python/src/ir.cc b/python/src/ir.cc index 7b02040d3cf2..804ef94550db 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -1831,7 +1831,20 @@ void init_triton_ir(py::module &&m) { return self.create(base, shape, strides, tensorShape, isSignedInteger, paddingOption); - }); + }) + // Update a tensor descriptor + .def("create_update_tensor_descriptor", + [](TritonOpBuilder &self, Value &desc, + std::optional base, + std::vector shape, + std::vector strides) { + Value baseVal = base.has_value() ? base.value() : Value(); + self.create(desc, baseVal, shape, strides); + }, + py::arg("desc"), + py::arg("base") = std::nullopt, + py::arg("shape") = std::vector{}, + py::arg("strides") = std::vector{}); py::class_(m, "pass_manager", py::module_local()) .def(py::init()) diff --git a/python/test/gluon/test_core.py b/python/test/gluon/test_core.py index 66daf1082b40..7c02060d4770 100644 --- a/python/test/gluon/test_core.py +++ b/python/test/gluon/test_core.py @@ -1687,3 +1687,175 @@ def kernel(in_ptr, out_ptr, # XBLOCK, YBLOCK, num_warps=4) torch.testing.assert_close(output, ref) + + +@pytest.mark.skipif(not is_hopper_or_newer(), reason="Requires Hopper") +def test_update_tensor_descriptor_base(): + @gluon.jit + def kernel(input_ptr, output_ptr, XBLOCK: ttgl.constexpr, smem_layout: ttgl.constexpr): + # Create descriptor for input + desc = tma.make_tensor_descriptor( + input_ptr, + shape=[XBLOCK, XBLOCK], + strides=[XBLOCK, 1], + block_shape=[XBLOCK, XBLOCK], + layout=smem_layout, + ) + + smem = ttgl.allocate_shared_memory(ttgl.float16, [XBLOCK, XBLOCK], smem_layout) + bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout()) + mbarrier.init(bar, count=1) + + mbarrier.expect(bar, desc.block_type.nbytes) + tma.async_copy_global_to_shared(desc, [0, 0], bar, smem) + mbarrier.wait(bar, 0) + mbarrier.invalidate(bar) + + tma.update_tensor_descriptor(desc, base=output_ptr) + + block_layout: ttgl.constexpr = ttgl.BlockedLayout([1, 2], [4, 8], [4, 1], [1, 0]) + data = smem.load(block_layout) + data_smem = ttgl.allocate_shared_memory(ttgl.float16, [XBLOCK, XBLOCK], smem_layout, data) + tma.async_copy_shared_to_global(desc, [0, 0], data_smem) + tma.store_wait(0) + data_smem._keep_alive() + + XBLOCK = 16 + input_data = torch.randn((XBLOCK, XBLOCK), device="cuda", dtype=torch.float16) + output = torch.zeros_like(input_data) + smem_layout = ttgl.NVMMASharedLayout( + swizzle_byte_width=32, + element_bitwidth=16, + rank=2, + transposed=False, + fp4_padded=False, + ) + + def alloc_fn(size: int, alignment: int, stream: int): + return torch.empty(size, device="cuda", dtype=torch.int8) + + triton.set_allocator(alloc_fn) + + kernel[(1,)](input_data, output, XBLOCK, smem_layout) + torch.testing.assert_close(input_data, output) + + +@pytest.mark.skipif(not is_hopper_or_newer(), reason="Requires Hopper") +def test_update_tensor_descriptor_shape(): + @gluon.jit + def kernel(in_ptr, out_ptr, M1: ttgl.constexpr, N1: ttgl.constexpr, + M2: ttgl.constexpr, N2: ttgl.constexpr, smem_layout: ttgl.constexpr): + desc = tma.make_tensor_descriptor( + in_ptr, + shape=[M1, N1], + strides=[N1, 1], + block_shape=[M1, N1], + layout=smem_layout, + ) + + smem_small = ttgl.allocate_shared_memory(ttgl.float16, [M1, N1], smem_layout) + bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout()) + mbarrier.init(bar, count=1) + + mbarrier.expect(bar, desc.block_type.nbytes) + tma.async_copy_global_to_shared(desc, [0, 0], bar, smem_small) + mbarrier.wait(bar, 0) + + tma.update_tensor_descriptor(desc, base=out_ptr, shape=[M2, N2], strides=[N2, 1]) + + block_layout: ttgl.constexpr = ttgl.BlockedLayout([1, 2], [8, 4], [4, 1], [1, 0]) + data = smem_small.load(block_layout) + smem_large_layout: ttgl.constexpr = ttgl.NVMMASharedLayout( + swizzle_byte_width=32, + element_bitwidth=16, + rank=2, + transposed=False, + fp4_padded=False, + ) + data_smem = ttgl.allocate_shared_memory(ttgl.float16, [M1, N1], smem_large_layout, data) + tma.async_copy_shared_to_global(desc, [0, 0], data_smem) + tma.store_wait(0) + data_smem._keep_alive() + + M1, N1 = 16, 16 + M2, N2 = 32, 32 + input_data = torch.randn((M1, N1), device="cuda", dtype=torch.float16) + output = torch.zeros((M2, N2), device="cuda", dtype=torch.float16) + + smem_layout = ttgl.NVMMASharedLayout( + swizzle_byte_width=32, + element_bitwidth=16, + rank=2, + transposed=False, + fp4_padded=False, + ) + + def alloc_fn(size: int, alignment: int, stream: int): + return torch.empty(size, device="cuda", dtype=torch.int8) + + triton.set_allocator(alloc_fn) + + kernel[(1,)](input_data, output, M1, N1, M2, N2, smem_layout) + torch.testing.assert_close(output[:M1, :N1], input_data) + torch.testing.assert_close(output[M1:, :], torch.zeros((M2 - M1, N2), device="cuda", dtype=torch.float16)) + + +@pytest.mark.skipif(not is_hopper_or_newer(), reason="Requires Hopper") +def test_update_tensor_descriptor_loop(): + @gluon.jit + def kernel(tensors_ptr, M: ttgl.constexpr, N: ttgl.constexpr, + num_batches: ttgl.constexpr, smem_layout: ttgl.constexpr): + desc = tma.make_tensor_descriptor( + tensors_ptr, + shape=[M, N], + strides=[N, 1], + block_shape=[M, N], + layout=smem_layout, + ) + + smem = ttgl.allocate_shared_memory(ttgl.float16, [M, N], smem_layout) + bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout()) + + block_layout: ttgl.constexpr = ttgl.BlockedLayout([1, 2], [4, 8], [4, 1], [1, 0]) + + for i in range(num_batches): + batch_offset = i * M * N + new_base = tensors_ptr + batch_offset + tma.update_tensor_descriptor(desc, base=new_base) + + # Load, process, and store + mbarrier.init(bar, count=1) + mbarrier.expect(bar, desc.block_type.nbytes) + tma.async_copy_global_to_shared(desc, [0, 0], bar, smem) + mbarrier.wait(bar, 0) + + data = smem.load(block_layout) + data = data + ttgl.full([M, N], i + 1, ttgl.float16, block_layout) + data_smem = ttgl.allocate_shared_memory(ttgl.float16, [M, N], smem_layout, data) + tma.async_copy_shared_to_global(desc, [0, 0], data_smem) + tma.store_wait(0) + data_smem._keep_alive() + + M, N = 16, 16 + num_batches = 3 + tensors = torch.randn((num_batches, M, N), device="cuda", dtype=torch.float16) + ref = tensors.clone() + + smem_layout = ttgl.NVMMASharedLayout( + swizzle_byte_width=32, + element_bitwidth=16, + rank=2, + transposed=False, + fp4_padded=False, + ) + + def alloc_fn(size: int, alignment: int, stream: int): + return torch.empty(size, device="cuda", dtype=torch.int8) + + triton.set_allocator(alloc_fn) + + kernel[(1,)](tensors, M, N, num_batches, smem_layout) + + for i in range(num_batches): + expected = ref[i] + (i + 1) + torch.testing.assert_close(tensors[i], expected) diff --git a/python/test/unit/language/test_tensor_descriptor.py b/python/test/unit/language/test_tensor_descriptor.py index 19835ead21f9..7d8db4442714 100644 --- a/python/test/unit/language/test_tensor_descriptor.py +++ b/python/test/unit/language/test_tensor_descriptor.py @@ -1760,3 +1760,249 @@ def kernel(desc, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr): kernel[(grid_m, grid_n)](desc, M, N, M_BLOCK=M_BLOCK, N_BLOCK=N_BLOCK) ref = torch.arange(M * N, dtype=torch.float32, device=device).reshape(M, N).to(torch_dtype) torch.testing.assert_close(out, ref) + + +@triton.jit +def kernel_update_tensor_descriptor_base(desc, a_ptr, b_ptr, M, N, MBLOCK: tl.constexpr, NBLOCK: tl.constexpr): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + moffset = MBLOCK * pid_m + noffset = NBLOCK * pid_n + + a = desc.load([moffset, noffset]) + + tl.update_tensor_descriptor(desc, base=b_ptr) + + desc.store([moffset, noffset], a + 10) + + +@pytest.mark.interpreter +def test_update_tensor_descriptor_base(device): + M, N = 64, 128 + MBLOCK, NBLOCK = 16, 32 + + torch.manual_seed(42) + A = torch.randn((M, N), dtype=torch.float32, device=device) + B = torch.zeros((M, N), dtype=torch.float32, device=device) + + desc = TensorDescriptor.from_tensor(A, [MBLOCK, NBLOCK]) + + grid = (triton.cdiv(M, MBLOCK), triton.cdiv(N, NBLOCK)) + kernel_update_tensor_descriptor_base[grid]( + desc, A, B, M, N, MBLOCK=MBLOCK, NBLOCK=NBLOCK + ) + + ref_out = A + 10 + torch.testing.assert_close(B, ref_out) + + +@triton.jit +def kernel_update_tensor_descriptor_shape(a_ptr, b_ptr, M1, N1, M2, N2, + MBLOCK: tl.constexpr, NBLOCK: tl.constexpr): + desc = tl.make_tensor_descriptor( + a_ptr, + shape=[M1, N1], + strides=[N1, 1], + block_shape=[MBLOCK, NBLOCK], + ) + + a = desc.load([0, 0]) + + tl.update_tensor_descriptor(desc, base=b_ptr, shape=[M2, N2], strides=[N2, 1]) + + desc.store([0, 0], a * 2) + + +@pytest.mark.interpreter +def test_update_tensor_descriptor_shape(device): + M1, N1 = 32, 64 + M2, N2 = 64, 128 + MBLOCK, NBLOCK = 16, 32 + + torch.manual_seed(42) + A = torch.randn((M1, N1), dtype=torch.float32, device=device) + B = torch.zeros((M2, N2), dtype=torch.float32, device=device) + + def alloc_fn(size: int, align: int, stream: Optional[int]): + return torch.empty(size, dtype=torch.int8, device=device) + triton.set_allocator(alloc_fn) + + kernel_update_tensor_descriptor_shape[(1,)]( + A, B, M1, N1, M2, N2, MBLOCK=MBLOCK, NBLOCK=NBLOCK + ) + + ref_B = torch.zeros((M2, N2), dtype=torch.float32, device=device) + ref_B[:MBLOCK, :NBLOCK] = A[:MBLOCK, :NBLOCK] * 2 + torch.testing.assert_close(B, ref_B) + + +@triton.jit +def kernel_update_tensor_descriptor_strides(a_ptr, b_ptr, M, N, + MBLOCK: tl.constexpr, NBLOCK: tl.constexpr): + desc = tl.make_tensor_descriptor( + a_ptr, + shape=[M, N], + strides=[N, 1], + block_shape=[MBLOCK, NBLOCK], + ) + + a = desc.load([0, 0]) + + tl.update_tensor_descriptor(desc, base=b_ptr, shape=[N, M], strides=[M, 1]) + + desc.store([0, 0], a) + + +@pytest.mark.interpreter +def test_update_tensor_descriptor_strides(device): + M, N = 64, 128 + MBLOCK, NBLOCK = 16, 32 + + torch.manual_seed(42) + A = torch.randn((M, N), dtype=torch.float32, device=device) + B = torch.zeros((N, M), dtype=torch.float32, device=device) + + def alloc_fn(size: int, align: int, stream: Optional[int]): + return torch.empty(size, dtype=torch.int8, device=device) + triton.set_allocator(alloc_fn) + + kernel_update_tensor_descriptor_strides[(1,)]( + A, B, M, N, MBLOCK=MBLOCK, NBLOCK=NBLOCK + ) + + ref_B = torch.zeros((N, M), dtype=torch.float32, device=device) + ref_B[:MBLOCK, :NBLOCK] = A[:MBLOCK, :NBLOCK] + torch.testing.assert_close(B, ref_B) + + +@triton.jit +def kernel_update_tensor_descriptor_loop(ptr, M, N, num_tensors: tl.constexpr, + MBLOCK: tl.constexpr, NBLOCK: tl.constexpr): + pid = tl.program_id(0) + offset = MBLOCK * pid + + desc = tl.make_tensor_descriptor( + ptr, + shape=[M, N], + strides=[N, 1], + block_shape=[MBLOCK, NBLOCK], + ) + + for i in range(num_tensors): + tensor_offset = i * M * N + new_base = ptr + tensor_offset + tl.update_tensor_descriptor(desc, base=new_base) + + data = desc.load([offset, 0]) + data = data + (i + 1) * 10 + desc.store([offset, 0], data) + + +@pytest.mark.interpreter +def test_update_tensor_descriptor_loop(device): + M, N = 64, 128 + MBLOCK, NBLOCK = 16, 128 + num_tensors = 3 + + torch.manual_seed(42) + tensors = torch.randn((num_tensors, M, N), dtype=torch.float32, device=device) + + def alloc_fn(size: int, align: int, stream: Optional[int]): + return torch.empty(size, dtype=torch.int8, device=device) + triton.set_allocator(alloc_fn) + + grid = (triton.cdiv(M, MBLOCK),) + kernel_update_tensor_descriptor_loop[grid]( + tensors, M, N, num_tensors=num_tensors, MBLOCK=MBLOCK, NBLOCK=NBLOCK + ) + + torch.manual_seed(42) + ref_tensors = torch.randn((num_tensors, M, N), dtype=torch.float32, device=device) + for i in range(num_tensors): + ref_tensors[i] = ref_tensors[i] + (i + 1) * 10 + + torch.testing.assert_close(tensors, ref_tensors) + + +@triton.jit +def kernel_update_tensor_descriptor_mixed(in_ptr, out_ptr, M, N, new_M, new_N, + MBLOCK: tl.constexpr, NBLOCK: tl.constexpr): + desc = tl.make_tensor_descriptor( + in_ptr, + shape=[M, N], + strides=[N, 1], + block_shape=[MBLOCK, NBLOCK], + ) + + data = desc.load([0, 0]) + + tl.update_tensor_descriptor( + desc, + base=out_ptr, + shape=[new_M, new_N], + strides=[new_N, 1] + ) + + desc.store([0, 0], data * 3) + + +@pytest.mark.interpreter +def test_update_tensor_descriptor_all_fields(device): + M, N = 32, 64 + new_M, new_N = 64, 128 + MBLOCK, NBLOCK = 16, 32 + + torch.manual_seed(42) + A = torch.randn((M, N), dtype=torch.float32, device=device) + B = torch.zeros((new_M, new_N), dtype=torch.float32, device=device) + + def alloc_fn(size: int, align: int, stream: Optional[int]): + return torch.empty(size, dtype=torch.int8, device=device) + triton.set_allocator(alloc_fn) + + kernel_update_tensor_descriptor_mixed[(1,)]( + A, B, M, N, new_M, new_N, + MBLOCK=MBLOCK, NBLOCK=NBLOCK + ) + + ref_B = torch.zeros((new_M, new_N), dtype=torch.float32, device=device) + ref_B[:MBLOCK, :NBLOCK] = A[:MBLOCK, :NBLOCK] * 3 + torch.testing.assert_close(B, ref_B) + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype_str", ["float16", "float32", "bfloat16"]) +def test_update_tensor_descriptor_dtypes(dtype_str, device): + @triton.jit + def kernel(desc, new_ptr, M, N, MBLOCK: tl.constexpr, NBLOCK: tl.constexpr): + data = desc.load([0, 0]) + tl.update_tensor_descriptor(desc, base=new_ptr) + desc.store([0, 0], data + 1) + + M, N = 32, 64 + MBLOCK, NBLOCK = 16, 32 + torch_dtype = getattr(torch, dtype_str) + + torch.manual_seed(42) + A = torch.randn((M, N), dtype=torch.float32, device=device).to(torch_dtype) + B = torch.zeros((M, N), dtype=torch_dtype, device=device) + + desc = TensorDescriptor.from_tensor(A, [MBLOCK, NBLOCK]) + kernel[(1,)](desc, B, M, N, MBLOCK=MBLOCK, NBLOCK=NBLOCK) + + ref_B = torch.zeros((M, N), dtype=torch_dtype, device=device) + ref_B[:MBLOCK, :NBLOCK] = A[:MBLOCK, :NBLOCK] + 1 + torch.testing.assert_close(B, ref_B) + +@triton.jit +def kernel_update_tensor_descriptor_invalid_strides(desc, ptr, M, N, MBLOCK: tl.constexpr, NBLOCK: tl.constexpr): + tl.update_tensor_descriptor(desc, base=ptr, strides=[N, 1]) + +@pytest.mark.interpreter +def test_update_tensor_descriptor_invalid_strides_compile_error(device): + M, N = 32, 64 + MBLOCK, NBLOCK = 16, 32 + A = torch.empty((M, N), dtype=torch.float32, device=device) + desc = TensorDescriptor.from_tensor(A, [MBLOCK, NBLOCK]) + with pytest.raises(triton.CompilationError): + kernel_update_tensor_descriptor_invalid_strides[(1,)](desc, A, M, N, MBLOCK=MBLOCK, NBLOCK=NBLOCK) diff --git a/python/triton/experimental/gluon/language/nvidia/blackwell/tma.py b/python/triton/experimental/gluon/language/nvidia/blackwell/tma.py index 717331e53c04..5fe53ee0b07f 100644 --- a/python/triton/experimental/gluon/language/nvidia/blackwell/tma.py +++ b/python/triton/experimental/gluon/language/nvidia/blackwell/tma.py @@ -6,6 +6,7 @@ tensor_descriptor, tensor_descriptor_type, make_tensor_descriptor, + update_tensor_descriptor, ) __all__ = [ @@ -17,6 +18,7 @@ "tensor_descriptor", "tensor_descriptor_type", "make_tensor_descriptor", + "update_tensor_descriptor", ] diff --git a/python/triton/experimental/gluon/language/nvidia/hopper/tma.py b/python/triton/experimental/gluon/language/nvidia/hopper/tma.py index dc4ef3ace295..eca9c5845ed0 100644 --- a/python/triton/experimental/gluon/language/nvidia/hopper/tma.py +++ b/python/triton/experimental/gluon/language/nvidia/hopper/tma.py @@ -15,6 +15,7 @@ @dataclass(eq=True) class tensor_descriptor_type(base_type): block_type: ttgl.block_type + base_type: base_type shape_type: ttgl.tuple_type strides_type: ttgl.tuple_type layout: NVMMASharedLayout @@ -33,9 +34,16 @@ def _to_ir(self, builder: ir.builder) -> ir.type: def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[tensor_descriptor, int]: handle = handles[cursor] cursor += 1 + + # base is not flattened because it's embedded in the + # descriptor handle + + base = None shape, cursor = self.shape_type._unflatten_ir(handles, cursor) strides, cursor = self.strides_type._unflatten_ir(handles, cursor) - value = tensor_descriptor(handle, shape, strides, self.block_type, layout=self.layout) + shape = shape.values + strides = strides.values + value = tensor_descriptor(handle, base, shape, strides, self.block_type, layout=self.layout, base_type=self.base_type) return value, cursor def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None: @@ -46,6 +54,10 @@ def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None: self.layout._to_ir(builder), ) out.append(ty) + + # base_type is not flattened because base is embedded in + # the descriptor handle. + self.shape_type._flatten_ir_types(builder, out) self.strides_type._flatten_ir_types(builder, out) @@ -55,16 +67,26 @@ def mangle(self) -> str: class tensor_descriptor(base_value): - def __init__(self, handle, shape: List[ttgl.tensor], strides: List[ttgl.tensor], block_type: ttgl.block_type, - layout: NVMMASharedLayout): + def __init__(self, handle, base: ttgl.tensor, shape: List[ttgl.tensor], strides: List[ttgl.tensor], + block_type: ttgl.block_type, layout: NVMMASharedLayout, base_type=None): self.handle = handle + self.base = base self.shape = ttgl.tuple(shape) self.strides = ttgl.tuple(strides) - self.type = tensor_descriptor_type(block_type, shape_type=self.shape.type, strides_type=self.strides.type, - layout=layout) + # If base_type is not provided, infer it from base + if base_type is None: + if base is None: + raise ValueError("Either base or base_type must be provided") + base_type = base.type + self.type = tensor_descriptor_type(block_type, base_type, self.shape.type, + self.strides.type, layout) def _flatten_ir(self, handles: List[ir.value]) -> None: handles.append(self.handle) + + # base is not flattened because it's embedded in the + # descriptor handle + self.shape._flatten_ir(handles) self.strides._flatten_ir(handles) @@ -155,7 +177,7 @@ def make_tensor_descriptor( shape_type = ttgl.tuple(shape).type strides_type = ttgl.tuple(strides).type - ty = tensor_descriptor_type(block_type, shape_type, strides_type, layout) + ty = tensor_descriptor_type(block_type, base.type, shape_type, strides_type, layout) if base.type.element_ty.is_int() and padding == ttgl.ir.PADDING_OPTION.PAD_NAN: raise ValueError("Padding option `nan` is not supported for integer blocks") @@ -166,4 +188,49 @@ def make_tensor_descriptor( [s.handle for s in strides], padding, ) - return tensor_descriptor(handle, shape, strides, block_type, layout) + return tensor_descriptor(handle, base, shape, strides, block_type, layout) + +@builtin +def update_tensor_descriptor( + desc: tensor_descriptor, + base: ttgl.tensor = None, + shape: List[ttgl.tensor] = None, + strides: List[ttgl.tensor] = None, + _semantic=None, +) -> None: + if base is None and shape is None and strides is None: + raise ValueError("At least one descriptor field must be updated") + + if shape is not None: + ndim = len(desc.block_shape) + + if len(shape) != ndim: + raise ValueError(f"Expected shape of {ndim} dimensions but got {len(shape)} dimensions") + + if strides is not None: + if len(strides) != ndim: + raise ValueError(f"Expected {ndim} strides but got {len(strides)}") + + last_stride = ttgl._unwrap_if_constexpr(strides[-1]) + if last_stride != 1: + raise ValueError(f"Tensor descriptor last dim must be 1 but got {last_stride}") + + shape = [_semantic.make_scalar(x, ttgl.int32) for x in shape] + + if strides is not None: + strides = [_semantic.make_scalar(ttgl._unwrap_if_constexpr(x), ttgl.int64) for x in strides] + + _semantic.builder.create_update_tensor_descriptor( + desc.handle, + base=base.handle if base is not None else None, + shape=[s.handle for s in shape] if shape is not None else [], + strides=[s.handle for s in strides] if strides is not None else [] + ) + + # Update the Python-side metadata + if base is not None: + desc.base = base + if shape is not None: + desc.shape = ttgl.tuple(shape) + if strides is not None: + desc.strides = ttgl.tuple(strides) diff --git a/python/triton/language/__init__.py b/python/triton/language/__init__.py index 409d337358bd..ad6ddb398b21 100644 --- a/python/triton/language/__init__.py +++ b/python/triton/language/__init__.py @@ -32,6 +32,7 @@ load_tensor_descriptor, store_tensor_descriptor, make_tensor_descriptor, + update_tensor_descriptor, tensor_descriptor, tensor_descriptor_type, add, @@ -142,6 +143,7 @@ "load_tensor_descriptor", "store_tensor_descriptor", "make_tensor_descriptor", + "update_tensor_descriptor", "tensor_descriptor", "abs", "add", @@ -310,6 +312,8 @@ def str_to_ty(name, c): # FIXME: Last dim stride should be constexpr(1) stride_type = tuple_type(([int64] * ndim)) block = block_type(dtype, block_shape) + # Base type is a pointer to the element type + base_type = pointer_type(dtype) if is_gluon: from triton.experimental.gluon.language._layouts import NVMMASharedLayout, PaddedSharedLayout, SwizzledSharedLayout from triton.experimental.gluon.language.nvidia.hopper.tma import tensor_descriptor_type as nvidia_tensor_descriptor_type @@ -319,10 +323,10 @@ def str_to_ty(name, c): dict(NVMMASharedLayout=NVMMASharedLayout, PaddedSharedLayout=PaddedSharedLayout, SwizzledSharedLayout=SwizzledSharedLayout)) if isinstance(layout, NVMMASharedLayout): - return nvidia_tensor_descriptor_type(block, shape_type, stride_type, layout) + return nvidia_tensor_descriptor_type(block, base_type, shape_type, stride_type, layout) else: - return amd_tensor_descriptor_type(block, shape_type, stride_type, layout) - return tensor_descriptor_type(block, shape_type, stride_type) + return amd_tensor_descriptor_type(block, base_type, shape_type, stride_type, layout) + return tensor_descriptor_type(block, base_type, shape_type, stride_type) if name.startswith("constexpr"): return constexpr_type(c) diff --git a/python/triton/language/core.py b/python/triton/language/core.py index e125245d3607..970cfed5714c 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -1456,28 +1456,34 @@ def scatter(self, value, *args, _semantic=None) -> tensor: class tensor_descriptor_type(tensor_descriptor_base_type): - def __init__(self, block_type: block_type, shape_type: tuple_type, strides_type: tuple_type): + def __init__(self, block_type: block_type, base_type: base_type, shape_type: tuple_type, strides_type: tuple_type): self.block_type = block_type + self.base_type = base_type self.shape_type = shape_type self.strides_type = strides_type def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[tensor_descriptor_base, int]: handle = handles[cursor] cursor += 1 + # Note: base is not unflattened, it's embedded in the + # descriptor handle. + base = None shape, cursor = self.shape_type._unflatten_ir(handles, cursor) strides, cursor = self.strides_type._unflatten_ir(handles, cursor) shape = shape.values strides = strides.values - value = tensor_descriptor(handle, shape, strides, self.block_type) + value = tensor_descriptor(handle, base, shape, strides, self.block_type, base_type=self.base_type) return value, cursor def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None: super()._flatten_ir_types(builder, out) + # Note: base_type is not flattened, it's embedded in the + # descriptor handle. self.shape_type._flatten_ir_types(builder, out) self.strides_type._flatten_ir_types(builder, out) def __eq__(self, other): - return super().__eq__(other) and (self.shape_type == other.shape_type) and (self.strides_type + return super().__eq__(other) and (self.base_type == other.base_type) and (self.shape_type == other.shape_type) and (self.strides_type == other.strides_type) @@ -1485,21 +1491,32 @@ class tensor_descriptor(tensor_descriptor_base): """A descriptor representing a tensor in global memory. """ - def __init__(self, handle, shape: List[tensor], strides: List[tensor], block_type: block_type): + def __init__(self, handle, base: tensor, shape: List[tensor], strides: List[tensor], block_type: block_type, base_type=None): """Not called by user code.""" # IR handle super().__init__(handle, block_type) - # Global shape + # Base pointer to global memory tensor + self.base = base + # Global shape and strides self.shape = tuple(shape) self.strides = tuple(strides) + # If base_type is not provided, infer it from base. If base + # is None, base_type must be provided + if base_type is None: + if base is None: + raise ValueError("Either base or base_type must be provided") + base_type = base.type self.type = tensor_descriptor_type( block_type, - shape_type=self.shape.type, - strides_type=self.strides.type, + base_type, + self.shape.type, + self.strides.type, ) def _flatten_ir(self, handles: List[ir.value]) -> None: handles.append(self.handle) + # Note: base is NOT flattened - it's embedded in the descriptor handle + # self.base._flatten_ir(handles) self.shape._flatten_ir(handles) self.strides._flatten_ir(handles) @@ -2349,6 +2366,87 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]): return _semantic.make_tensor_descriptor(base, shape, strides, block_shape, padding_option) +@builtin +def update_tensor_descriptor( + desc: tensor_descriptor, + base: tensor = None, + shape: List[tensor] = None, + strides: List[tensor] = None, + _semantic=None, +) -> None: + """Update an existing TMA descriptor + + Updates one or more fields of an existing TMA descriptor. + + :param desc: The existing tensor descriptor to update + :param base: The new base pointer, must be 16-byte aligned (optional) + :param shape: The new tensor shape (optional) + :param strides: The new tensor strides (optional) + + Notes + ***** + - At least one field (base, shape, or strides) must be provided + - When providing strides, shape must also be provided + - Shape and strides must have the same length + - Same limitation for updates values hold as for `make_tensor_descriptor` + + Example + ******* + .. code-block:: python + + @triton.jit + def kernel(ptr, M: tl.constexpr, N: tl.constexpr): + # Create descriptor + desc = tl.make_tensor_descriptor( + ptr, + shape=[M, N], + strides=[N, 1], + block_shape=[16, 16] + ) + + # Update to new shape + tl.update_tensor_descriptor( + desc, + shape=[M//2, N], + ) + + # Use updated descriptor + data = desc.load([0, 0]) + """ + if base is None and shape is None and strides is None: + raise ValueError("At least one descriptor field must be updated") + + if strides is not None and shape is None: + raise ValueError("Cannot update strides without providing shape") + + if shape is not None and strides is not None: + if len(shape) != len(strides): + raise ValueError(f"Shape and strides must have the same length, got {len(shape)} and {len(strides)}") + + last_stride = _unwrap_if_constexpr(strides[-1]) + if last_stride != 1: + raise ValueError(f"Tensor descriptor last dim must be 1 but got {last_stride}") + + if shape is not None: + shape = [_semantic.make_scalar(s, int32) for s in shape] + if strides is not None: + strides = [_semantic.make_scalar(_unwrap_if_constexpr(s), int64) for s in strides] + + _semantic.builder.create_update_tensor_descriptor( + desc.handle, + base=base.handle if base is not None else None, + shape=[s.handle for s in shape] if shape is not None else [], + strides=[s.handle for s in strides] if strides is not None else [] + ) + + if base is not None: + desc.base = base + if shape is not None: + desc.shape = tuple(shape) + if strides is not None: + desc.strides = tuple(strides) + + # ----------------------- # Atomic Memory Operations # ----------------------- diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 6db9ffb25fec..36ab7f47966b 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1963,4 +1963,4 @@ def make_tensor_descriptor(self, base: TensorTy, shape: List[TensorTy], strides: handle = self.builder.create_make_tensor_descriptor(base_handle, [s.handle for s in shape], [s.handle for s in strides], block_shape, is_signed_int, padding) - return tl.tensor_descriptor(handle, shape, strides, type) + return tl.tensor_descriptor(handle, base, shape, strides, type) diff --git a/test/Conversion/tritonnvidiagpu_to_llvm.mlir b/test/Conversion/tritonnvidiagpu_to_llvm.mlir index f5a0b7eb067a..8c1f7196d549 100644 --- a/test/Conversion/tritonnvidiagpu_to_llvm.mlir +++ b/test/Conversion/tritonnvidiagpu_to_llvm.mlir @@ -210,6 +210,61 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // ----- +// CHECK-LABEL: tensormap_update_address +module attributes {"ttg.num-ctas" = 1 : i32 , "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @tensormap_update_address(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) { + // CHECK: tensormap.replace.tile.global_address + // CHECK-NOT: tensormap.replace.tile.global_dim + // CHECK-NOT: tensormap.replace.tile.global_stride + ttng.tensormap_update %arg0 global_address = %arg1 : !tt.ptr {allocation.offset = 0 : i32} : !tt.ptr + tt.return + } +} + +// ----- + +// CHECK-LABEL: tensormap_update_dim +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @tensormap_update_dim(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32, %arg2: i32) { + // CHECK-NOT: tensormap.replace.tile.global_address + // CHECK: tensormap.replace.tile.global_dim + // CHECK: tensormap.replace.tile.global_dim + // CHECK-NOT: tensormap.replace.tile.global_stride + ttng.tensormap_update %arg0 global_dim = [%arg1, %arg2] {allocation.offset = 0 : i32} : !tt.ptr + tt.return + } +} + +// ----- + +// CHECK-LABEL: tensormap_update_dim_and_stride +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @tensormap_update_dim_and_stride(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32, %arg2: i32, %arg3: i64) { + // CHECK-NOT: tensormap.replace.tile.global_address + // CHECK: tensormap.replace.tile.global_dim + // CHECK: tensormap.replace.tile.global_dim + // CHECK: tensormap.replace.tile.global_stride + ttng.tensormap_update %arg0 global_dim = [%arg1, %arg2] global_stride = [%arg3] {allocation.offset = 0 : i32} : !tt.ptr + tt.return + } +} + +// ----- + +// CHECK-LABEL: tensormap_update_all_fields +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @tensormap_update_all_fields(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32, %arg3: i32, %arg4: i64) { + // CHECK: tensormap.replace.tile.global_address + // CHECK: tensormap.replace.tile.global_dim + // CHECK: tensormap.replace.tile.global_dim + // CHECK: tensormap.replace.tile.global_stride + ttng.tensormap_update %arg0 global_dim = [%arg2, %arg3] global_stride = [%arg4] global_address = %arg1 : !tt.ptr {allocation.offset = 0 : i32} : !tt.ptr + tt.return + } +} + +// ----- + #shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> // CHECK-LABEL: async_copy_mbarrier_arrive diff --git a/test/TritonNvidiaGPU/tma_lowering.mlir b/test/TritonNvidiaGPU/tma_lowering.mlir index c90ee6b28b94..18572e72d86f 100644 --- a/test/TritonNvidiaGPU/tma_lowering.mlir +++ b/test/TritonNvidiaGPU/tma_lowering.mlir @@ -55,6 +55,21 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // ----- +#nvmma_32_update = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = false, elementBitWidth = 8}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: update_tensor_descriptor + // CHECK: %[[DESC_PTR:.+]] = ttng.get_descriptor_ptr %arg0 + // CHECK: ttng.tensormap_update %[[DESC_PTR]] global_address = %arg1 + // CHECK: ttng.tensormap_fenceproxy_acquire %[[DESC_PTR]] : !tt.ptr + tt.func public @update_tensor_descriptor(%arg0: !tt.tensordesc>, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) { + tt.update_tensor_descriptor %arg0 base = %arg1 : !tt.ptr : !tt.tensordesc> + tt.return + } +} + +// ----- + #blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> #blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> #nvmma_128 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TMAToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TMAToLLVM.cpp index 94e7e5af3a1e..725647097759 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TMAToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TMAToLLVM.cpp @@ -27,7 +27,7 @@ void tensormap_cp_fenceproxy(Location loc, MLIRContext *ctx, // prepare asm operands auto *outAddrOpr = ptxBuilder.newAddrOperand(outPtr, "l"); - auto *inAddrOpr = ptxBuilder.newAddrOperand(inPtr, "l"); + auto *inAddrOpr = ptxBuilder.newAddrOperand(inPtr, "r"); auto *sizeOpr = ptxBuilder.newConstantOperand(TMA_SIZE_BYTES); // Define the instruction opcode @@ -46,18 +46,18 @@ void tensormap_cp_fenceproxy(Location loc, MLIRContext *ctx, void tensormap_replace_generic(Location loc, MLIRContext *ctx, ConversionPatternRewriter &rewriter, std::string fieldName, Value descPtr, - int32_t newVal) { + int32_t newVal, bool useSharedMemory = true) { PTXBuilder ptxBuilder; auto b = TritonLLVMOpBuilder(loc, rewriter); // prepare asm operands - auto *descAddrOpr = ptxBuilder.newAddrOperand(descPtr, "l"); + auto *descAddrOpr = ptxBuilder.newAddrOperand(descPtr, useSharedMemory ? "r" : "l"); auto newValOpr = ptxBuilder.newConstantOperand(newVal); // Define the instruction opcode auto &replace = ptxBuilder.create("tensormap.replace.tile") ->o(fieldName) - .o("shared::cta") + .o(useSharedMemory ? "shared::cta" : "global") .o("b1024") .o("b32"); @@ -72,7 +72,8 @@ void tensormap_replace_generic(Location loc, MLIRContext *ctx, ConversionPatternRewriter &rewriter, std::string fieldName, Value descPtr, Value newVal, - std::optional ord = std::nullopt) { + std::optional ord = std::nullopt, + bool useSharedMemory = true) { PTXBuilder ptxBuilder; auto b = TritonLLVMOpBuilder(loc, rewriter); @@ -80,7 +81,7 @@ void tensormap_replace_generic(Location loc, MLIRContext *ctx, int width = 0; // prepare asm operands - auto *descAddrOpr = ptxBuilder.newAddrOperand(descPtr, "l"); + auto *descAddrOpr = ptxBuilder.newAddrOperand(descPtr, useSharedMemory ? "r" : "l"); PTXInstr::Operand *ordOpr = ord ? ptxBuilder.newConstantOperand(*ord) : nullptr; PTXInstr::Operand *newValOpr = nullptr; @@ -96,7 +97,7 @@ void tensormap_replace_generic(Location loc, MLIRContext *ctx, // Define the instruction opcode auto &replace = ptxBuilder.create("tensormap.replace.tile") ->o(fieldName) - .o("shared::cta") + .o(useSharedMemory ? "shared::cta" : "global") .o("b1024") .o("b32", width == 32) .o("b64", width == 64); @@ -115,36 +116,36 @@ void tensormap_replace_generic(Location loc, MLIRContext *ctx, void tensormap_replace_global_address(Location loc, MLIRContext *ctx, ConversionPatternRewriter &rewriter, - Value descPtr, Value newVal) { + Value descPtr, Value newVal, bool useSharedMemory = true) { tensormap_replace_generic(loc, ctx, rewriter, "global_address", descPtr, - newVal); + newVal, std::nullopt, useSharedMemory); } void tensormap_replace_rank(Location loc, MLIRContext *ctx, ConversionPatternRewriter &rewriter, Value descPtr, int32_t newVal) { - tensormap_replace_generic(loc, ctx, rewriter, "rank", descPtr, newVal); + tensormap_replace_generic(loc, ctx, rewriter, "rank", descPtr, newVal, true); } void tensormap_replace_box_dim(Location loc, MLIRContext *ctx, ConversionPatternRewriter &rewriter, Value descPtr, int32_t ord, Value newVal) { tensormap_replace_generic(loc, ctx, rewriter, "box_dim", descPtr, newVal, - ord); + ord, true); } void tensormap_replace_global_dim(Location loc, MLIRContext *ctx, ConversionPatternRewriter &rewriter, - Value descPtr, int32_t ord, Value newVal) { + Value descPtr, int32_t ord, Value newVal, bool useSharedMemory = true) { tensormap_replace_generic(loc, ctx, rewriter, "global_dim", descPtr, newVal, - ord); + ord, useSharedMemory); } void tensormap_replace_global_stride(Location loc, MLIRContext *ctx, ConversionPatternRewriter &rewriter, - Value descPtr, int32_t ord, Value newVal) { + Value descPtr, int32_t ord, Value newVal, bool useSharedMemory = true) { tensormap_replace_generic(loc, ctx, rewriter, "global_stride", descPtr, - newVal, ord); + newVal, ord, useSharedMemory); } void tensormap_replace_element_stride(Location loc, MLIRContext *ctx, @@ -152,33 +153,33 @@ void tensormap_replace_element_stride(Location loc, MLIRContext *ctx, Value descPtr, int32_t ord, Value newVal) { tensormap_replace_generic(loc, ctx, rewriter, "element_stride", descPtr, - newVal, ord); + newVal, ord, true); } void tensormap_replace_elemtype(Location loc, MLIRContext *ctx, ConversionPatternRewriter &rewriter, Value descPtr, int32_t newVal) { - tensormap_replace_generic(loc, ctx, rewriter, "elemtype", descPtr, newVal); + tensormap_replace_generic(loc, ctx, rewriter, "elemtype", descPtr, newVal, true); } void tensormap_replace_interleave_layout(Location loc, MLIRContext *ctx, ConversionPatternRewriter &rewriter, Value descPtr, int32_t newVal) { tensormap_replace_generic(loc, ctx, rewriter, "interleave_layout", descPtr, - newVal); + newVal, true); } void tensormap_replace_swizzle_mode(Location loc, MLIRContext *ctx, ConversionPatternRewriter &rewriter, Value descPtr, int32_t newVal) { tensormap_replace_generic(loc, ctx, rewriter, "swizzle_mode", descPtr, - newVal); + newVal, true); } void tensormap_replace_fill_mode(Location loc, MLIRContext *ctx, ConversionPatternRewriter &rewriter, Value descPtr, int32_t newVal) { - tensormap_replace_generic(loc, ctx, rewriter, "fill_mode", descPtr, newVal); + tensormap_replace_generic(loc, ctx, rewriter, "fill_mode", descPtr, newVal, true); } struct TensormapFenceproxyAcquireOpConversion @@ -313,12 +314,139 @@ struct ReinterpretTensorDescOpConversion } }; +struct GetDescriptorPtrOpConversion + : public ConvertOpToLLVMPattern { + + GetDescriptorPtrOpConversion(LLVMTypeConverter &converter, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit) {} + + LogicalResult + matchAndRewrite(ttng::GetDescriptorPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type resultType = getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, resultType, + adaptor.getDesc()); + return success(); + } +}; + +struct TensormapUpdateOpConversion + : public ConvertOpToLLVMPattern { + const NVIDIA::TargetInfo &targetInfo; + + TensormapUpdateOpConversion(LLVMTypeConverter &converter, + const NVIDIA::TargetInfo &targetInfo, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(ttng::TensormapUpdateOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto ctx = getContext(); + + Value descPtr = adaptor.getDescPtr(); + + // Detect address space + auto ptrType = mlir::cast(descPtr.getType()); + auto addrSpace = ptrType.getAddressSpace(); + + // For descriptors in global memory, copy to SMEM, update, and copy back + if (addrSpace != 3) { + // Get shared memory workspace for updating + auto smemBase = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op); + + // Copy descriptor from GMEM to SMEM (128 bytes) + constexpr int kWarpSize = 32; + constexpr int kTMASize = 128; + Value threadId = getThreadId(rewriter, loc); + Value pred = b.icmp_slt(threadId, b.i32_val(kWarpSize)); + + // Cast pointers for 32-bit (4-byte) loads/stores + auto i32PtrTyGmem = ptr_ty(ctx, 1); + auto i32PtrTySmem = ptr_ty(ctx, 3); + Value gmemI32 = b.bitcast(descPtr, i32PtrTyGmem); + Value smemI32 = b.bitcast(smemBase, i32PtrTySmem); + + // Use threadId clamped to 0 for threads that shouldn't participate + Value clampedThreadId = b.select(pred, threadId, b.i32_val(0)); + + // Each thread copies one i32 (4 bytes) + Value gmemAddr = b.gep(i32PtrTyGmem, i32_ty, gmemI32, clampedThreadId); + Value smemAddr = b.gep(i32PtrTySmem, i32_ty, smemI32, clampedThreadId); + + // Load from GMEM and store to SMEM + Value data = b.load(i32_ty, gmemAddr); + targetInfo.storeShared(rewriter, loc, smemAddr, data, pred); + LLVM::NVIDIA::createSyncWarp(loc, rewriter); + + // Perform all updates in SMEM using fast shared::cta + // instructions + if (adaptor.getGlobalAddress()) { + tensormap_replace_global_address(loc, ctx, rewriter, smemBase, + adaptor.getGlobalAddress(), true); + } + if (adaptor.getGlobalDim().size() > 0) { + for (int i = 0; i < adaptor.getGlobalDim().size(); ++i) { + tensormap_replace_global_dim(loc, ctx, rewriter, smemBase, i, + adaptor.getGlobalDim()[i], true); + } + } + if (adaptor.getGlobalStride().size() > 0) { + bool needsStrideWorkaround = targetInfo.getPtxVersion() <= 85; + for (int i = 0; i < adaptor.getGlobalStride().size(); ++i) { + auto strideVal = adaptor.getGlobalStride()[i]; + if (needsStrideWorkaround) { + strideVal = b.ashr(strideVal, b.i64_val(4)); + } + tensormap_replace_global_stride(loc, ctx, rewriter, smemBase, i, + strideVal, true); + } + } + + tensormap_cp_fenceproxy(loc, ctx, rewriter, descPtr, smemBase); + } else { + // Descriptor already in SMEM, update in place + if (adaptor.getGlobalAddress()) { + tensormap_replace_global_address(loc, ctx, rewriter, descPtr, + adaptor.getGlobalAddress(), true); + } + if (adaptor.getGlobalDim().size() > 0) { + for (int i = 0; i < adaptor.getGlobalDim().size(); ++i) { + tensormap_replace_global_dim(loc, ctx, rewriter, descPtr, i, + adaptor.getGlobalDim()[i], true); + } + } + if (adaptor.getGlobalStride().size() > 0) { + bool needsStrideWorkaround = targetInfo.getPtxVersion() <= 85; + for (int i = 0; i < adaptor.getGlobalStride().size(); ++i) { + auto strideVal = adaptor.getGlobalStride()[i]; + if (needsStrideWorkaround) { + strideVal = b.ashr(strideVal, b.i64_val(4)); + } + tensormap_replace_global_stride(loc, ctx, rewriter, descPtr, i, + strideVal, true); + } + } + + LLVM::NVIDIA::createSyncWarp(loc, rewriter); + } + + rewriter.eraseOp(op); + return success(); + } +}; + } // namespace void mlir::triton::NVIDIA::populateTMAToLLVMPatterns( LLVMTypeConverter &typeConverter, const TargetInfo &targetInfo, RewritePatternSet &patterns, PatternBenefit benefit) { - patterns.add(typeConverter, targetInfo, benefit); patterns.add(typeConverter, benefit); + ReinterpretTensorDescOpConversion, + GetDescriptorPtrOpConversion>(typeConverter, benefit); + patterns.add(typeConverter, targetInfo, benefit); + patterns.add(typeConverter, targetInfo, benefit); } From 6ebaf6c6f3dd7216f603c8851d185767ed4d216a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleksandar=20Samard=C5=BEi=C4=87?= Date: Fri, 21 Nov 2025 16:59:14 +0000 Subject: [PATCH 2/3] Removed Triton operation and Gluon operation put under tma namespace --- include/triton/Dialect/Triton/IR/TritonOps.td | 35 --- lib/Dialect/Triton/IR/Ops.cpp | 15 -- .../Transforms/TMALowering.cpp | 55 +--- python/src/gluon_ir.cc | 34 ++- python/src/ir.cc | 15 +- python/test/gluon/test_core.py | 54 ++++ .../unit/language/test_tensor_descriptor.py | 246 ------------------ python/triton/language/__init__.py | 2 - python/triton/language/core.py | 112 +------- python/triton/language/semantic.py | 2 +- test/TritonNvidiaGPU/tma_lowering.mlir | 15 -- 11 files changed, 96 insertions(+), 489 deletions(-) diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index c377ceff3aca..cfb31995bc3a 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -1061,41 +1061,6 @@ def TT_MakeTensorDescOp : TT_Op<"make_tensor_descriptor", [ }]; } -// -// Update Tensor Descriptor Op -// -def TT_UpdateTensorDescOp : TT_Op<"update_tensor_descriptor", [ - MemoryEffects<[MemRead, MemWrite]>, - AttrSizedOperandSegments, -]> { - let summary = "Update an existing tensor descriptor"; - - let description = [{ - `tt.update_tensor_descriptor` updates one or more fields of an existing tensor descriptor in-place. - - At the moment, it allows for updating the base pointer, shape and strides of a tensor in global memory. - }]; - - let arguments = (ins - AnyTypeOf<[TT_TensorDescType]>:$desc, - Optional:$base, - Variadic:$shape, - Variadic:$strides - ); - - let assemblyFormat = [{ - $desc - oilist( - `base` `=` $base `:` type($base) | - `shape` `=` `[` $shape `]` | - `strides` `=` `[` $strides `]` - ) - attr-dict `:` type($desc) - }]; - - let hasVerifier = 1; -} - // The following ops, including `call`, `func`, and `return` are copied and modified from // https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Func/IR/FuncOps.td // We could revert it back once MLIR has a better inliner interface. diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index 7514bc139472..7be62c73407b 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -1032,21 +1032,6 @@ void MakeTensorPtrOp::build(OpBuilder &builder, OperationState &state, builder.getDenseI32ArrayAttr(order)); } -LogicalResult UpdateTensorDescOp::verify() { - if (!llvm::isa(getDesc().getType())) - return emitOpError("first operand must be a !tt.tensordesc"); - bool hasBase = (getBase() != nullptr); - bool hasShape = !getShape().empty(); - bool hasStrides = !getStrides().empty(); - if (!hasBase && !hasShape && !hasStrides) - return emitOpError("must update at least one of base/shape/strides"); - if (hasStrides && !hasShape) - return emitOpError("cannot update strides without shape"); - if (hasShape && hasStrides && getShape().size() != getStrides().size()) - return emitOpError("shape and strides must have the same length"); - return success(); -} - //-- AddPtrOp -- OpFoldResult AddPtrOp::fold(FoldAdaptor adaptor) { // addptr(ptr, 0) -> ptr diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp index 1304d8b5a9ac..dc129ad14309 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp @@ -187,58 +187,6 @@ class TMACreateDescLowering : public OpRewritePattern { } }; -class TMAUpdateDescLowering : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(UpdateTensorDescOp op, - PatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Value desc = op.getDesc(); - - // Get the descriptor pointer - Value descPtr = rewriter.create(loc, getPointerType(rewriter.getI8Type()), desc); - - ValueRange shape = op.getShape(); - ValueRange strides = op.getStrides(); - - // Convert element strides to byte strides (same as in createTMADesc) - if (!strides.empty()) { - auto descType = mlir::cast(desc.getType()); - auto elemType = descType.getBlockType().getElementType(); - auto elemSize = elemType.getIntOrFloatBitWidth() / 8; - Value elemSizeVal = rewriter.create( - loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(elemSize)); - - SmallVector byteStrides; - for (Value stride : strides) { - byteStrides.push_back(rewriter.create(loc, stride, elemSizeVal)); - } - strides = byteStrides; - } - - // Reverse shape and strides to match TMA descriptor layout (same - // as in createTMADesc) - SmallVector reversedShape(llvm::reverse(shape)); - SmallVector reversedStrides; - if (!strides.empty()) { - for (int k = strides.size() - 2; k >= 0; --k) { - reversedStrides.push_back(strides[k]); - } - } - rewriter.create(loc, descPtr, op.getBase(), - reversedShape, reversedStrides); - - // The fence ensures that that memory ordering is correct for - // subsequent TMA operations - TensormapFenceproxyAcquireOp::create(rewriter, loc, descPtr); - - rewriter.eraseOp(op); - return success(); - } -}; - - } // anonymous namespace class TritonNvidiaGPUTMALoweringPass @@ -251,8 +199,7 @@ class TritonNvidiaGPUTMALoweringPass mlir::RewritePatternSet patterns(context); patterns.add( + TMAScatterLowering, TMAReduceLowering, TMACreateDescLowering>( context); if (applyPatternsGreedily(m, std::move(patterns)).failed()) signalPassFailure(); diff --git a/python/src/gluon_ir.cc b/python/src/gluon_ir.cc index 0cb3efaef2da..f980df63578a 100644 --- a/python/src/gluon_ir.cc +++ b/python/src/gluon_ir.cc @@ -867,12 +867,42 @@ void init_gluon_ir(py::module &&m) { strides, paddingOption); }) .def("create_update_tensor_descriptor", - [](TritonOpBuilder &self, Value &desc, + [](GluonOpBuilder &self, Value &desc, std::optional base, std::vector shape, std::vector strides) -> void { + auto &builder = self.getBuilder(); + + auto ptrType = tt::PointerType::get( + builder.getIntegerType(8), /*addressSpace=*/1); + Value descPtr = self.create(ptrType, desc); + + std::vector byteStrides; + if (!strides.empty()) { + auto descType = mlir::cast(desc.getType()); + auto elemType = descType.getBlockType().getElementType(); + auto elemSize = elemType.getIntOrFloatBitWidth() / 8; + Value elemSizeVal = self.create( + builder.getI64Type(), elemSize); + + for (Value stride : strides) { + byteStrides.push_back(self.create(stride, elemSizeVal)); + } + strides = byteStrides; + } + + std::vector reversedShape(shape.rbegin(), shape.rend()); + std::vector reversedStrides; + if (!strides.empty()) { + for (int k = strides.size() - 2; k >= 0; --k) { + reversedStrides.push_back(strides[k]); + } + } + Value baseVal = base.has_value() ? base.value() : Value(); - self.create(desc, baseVal, shape, strides); + self.create(descPtr, baseVal, reversedShape, reversedStrides); + + self.create(descPtr); }, py::arg("desc"), py::arg("base") = std::nullopt, diff --git a/python/src/ir.cc b/python/src/ir.cc index 804ef94550db..7b02040d3cf2 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -1831,20 +1831,7 @@ void init_triton_ir(py::module &&m) { return self.create(base, shape, strides, tensorShape, isSignedInteger, paddingOption); - }) - // Update a tensor descriptor - .def("create_update_tensor_descriptor", - [](TritonOpBuilder &self, Value &desc, - std::optional base, - std::vector shape, - std::vector strides) { - Value baseVal = base.has_value() ? base.value() : Value(); - self.create(desc, baseVal, shape, strides); - }, - py::arg("desc"), - py::arg("base") = std::nullopt, - py::arg("shape") = std::vector{}, - py::arg("strides") = std::vector{}); + }); py::class_(m, "pass_manager", py::module_local()) .def(py::init()) diff --git a/python/test/gluon/test_core.py b/python/test/gluon/test_core.py index 7c02060d4770..b2513136077c 100644 --- a/python/test/gluon/test_core.py +++ b/python/test/gluon/test_core.py @@ -1859,3 +1859,57 @@ def alloc_fn(size: int, alignment: int, stream: int): for i in range(num_batches): expected = ref[i] + (i + 1) torch.testing.assert_close(tensors[i], expected) + + +@pytest.mark.skipif(not is_hopper_or_newer(), reason="Requires Hopper") +def test_update_tensor_descriptor_strides(): + @gluon.jit + def kernel(a_ptr, b_ptr, M: ttgl.constexpr, N: ttgl.constexpr, smem_layout: ttgl.constexpr): + desc = tma.make_tensor_descriptor( + a_ptr, + shape=[M, N], + strides=[N, 1], + block_shape=[M, N], + layout=smem_layout, + ) + + smem = ttgl.allocate_shared_memory(ttgl.float16, [M, N], smem_layout) + bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout()) + mbarrier.init(bar, count=1) + + mbarrier.expect(bar, desc.block_type.nbytes) + tma.async_copy_global_to_shared(desc, [0, 0], bar, smem) + mbarrier.wait(bar, 0) + mbarrier.invalidate(bar) + + tma.update_tensor_descriptor(desc, base=b_ptr, shape=[N, M], strides=[M, 1]) + + block_layout: ttgl.constexpr = ttgl.BlockedLayout([1, 2], [4, 8], [4, 1], [1, 0]) + data = smem.load(block_layout) + data_smem = ttgl.allocate_shared_memory(ttgl.float16, [M, N], smem_layout, data) + tma.async_copy_shared_to_global(desc, [0, 0], data_smem) + tma.store_wait(0) + data_smem._keep_alive() + + M, N = 16, 32 + input_data = torch.randn((M, N), device="cuda", dtype=torch.float16) + output = torch.zeros((N, M), device="cuda", dtype=torch.float16) + + smem_layout = ttgl.NVMMASharedLayout( + swizzle_byte_width=32, + element_bitwidth=16, + rank=2, + transposed=False, + fp4_padded=False, + ) + + def alloc_fn(size: int, alignment: int, stream: int): + return torch.empty(size, device="cuda", dtype=torch.int8) + + triton.set_allocator(alloc_fn) + + kernel[(1,)](input_data, output, M, N, smem_layout) + + ref = torch.zeros((N, M), device="cuda", dtype=torch.float16) + ref[:M, :M] = input_data[:, :M] + torch.testing.assert_close(output, ref) diff --git a/python/test/unit/language/test_tensor_descriptor.py b/python/test/unit/language/test_tensor_descriptor.py index 7d8db4442714..19835ead21f9 100644 --- a/python/test/unit/language/test_tensor_descriptor.py +++ b/python/test/unit/language/test_tensor_descriptor.py @@ -1760,249 +1760,3 @@ def kernel(desc, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr): kernel[(grid_m, grid_n)](desc, M, N, M_BLOCK=M_BLOCK, N_BLOCK=N_BLOCK) ref = torch.arange(M * N, dtype=torch.float32, device=device).reshape(M, N).to(torch_dtype) torch.testing.assert_close(out, ref) - - -@triton.jit -def kernel_update_tensor_descriptor_base(desc, a_ptr, b_ptr, M, N, MBLOCK: tl.constexpr, NBLOCK: tl.constexpr): - pid_m = tl.program_id(0) - pid_n = tl.program_id(1) - moffset = MBLOCK * pid_m - noffset = NBLOCK * pid_n - - a = desc.load([moffset, noffset]) - - tl.update_tensor_descriptor(desc, base=b_ptr) - - desc.store([moffset, noffset], a + 10) - - -@pytest.mark.interpreter -def test_update_tensor_descriptor_base(device): - M, N = 64, 128 - MBLOCK, NBLOCK = 16, 32 - - torch.manual_seed(42) - A = torch.randn((M, N), dtype=torch.float32, device=device) - B = torch.zeros((M, N), dtype=torch.float32, device=device) - - desc = TensorDescriptor.from_tensor(A, [MBLOCK, NBLOCK]) - - grid = (triton.cdiv(M, MBLOCK), triton.cdiv(N, NBLOCK)) - kernel_update_tensor_descriptor_base[grid]( - desc, A, B, M, N, MBLOCK=MBLOCK, NBLOCK=NBLOCK - ) - - ref_out = A + 10 - torch.testing.assert_close(B, ref_out) - - -@triton.jit -def kernel_update_tensor_descriptor_shape(a_ptr, b_ptr, M1, N1, M2, N2, - MBLOCK: tl.constexpr, NBLOCK: tl.constexpr): - desc = tl.make_tensor_descriptor( - a_ptr, - shape=[M1, N1], - strides=[N1, 1], - block_shape=[MBLOCK, NBLOCK], - ) - - a = desc.load([0, 0]) - - tl.update_tensor_descriptor(desc, base=b_ptr, shape=[M2, N2], strides=[N2, 1]) - - desc.store([0, 0], a * 2) - - -@pytest.mark.interpreter -def test_update_tensor_descriptor_shape(device): - M1, N1 = 32, 64 - M2, N2 = 64, 128 - MBLOCK, NBLOCK = 16, 32 - - torch.manual_seed(42) - A = torch.randn((M1, N1), dtype=torch.float32, device=device) - B = torch.zeros((M2, N2), dtype=torch.float32, device=device) - - def alloc_fn(size: int, align: int, stream: Optional[int]): - return torch.empty(size, dtype=torch.int8, device=device) - triton.set_allocator(alloc_fn) - - kernel_update_tensor_descriptor_shape[(1,)]( - A, B, M1, N1, M2, N2, MBLOCK=MBLOCK, NBLOCK=NBLOCK - ) - - ref_B = torch.zeros((M2, N2), dtype=torch.float32, device=device) - ref_B[:MBLOCK, :NBLOCK] = A[:MBLOCK, :NBLOCK] * 2 - torch.testing.assert_close(B, ref_B) - - -@triton.jit -def kernel_update_tensor_descriptor_strides(a_ptr, b_ptr, M, N, - MBLOCK: tl.constexpr, NBLOCK: tl.constexpr): - desc = tl.make_tensor_descriptor( - a_ptr, - shape=[M, N], - strides=[N, 1], - block_shape=[MBLOCK, NBLOCK], - ) - - a = desc.load([0, 0]) - - tl.update_tensor_descriptor(desc, base=b_ptr, shape=[N, M], strides=[M, 1]) - - desc.store([0, 0], a) - - -@pytest.mark.interpreter -def test_update_tensor_descriptor_strides(device): - M, N = 64, 128 - MBLOCK, NBLOCK = 16, 32 - - torch.manual_seed(42) - A = torch.randn((M, N), dtype=torch.float32, device=device) - B = torch.zeros((N, M), dtype=torch.float32, device=device) - - def alloc_fn(size: int, align: int, stream: Optional[int]): - return torch.empty(size, dtype=torch.int8, device=device) - triton.set_allocator(alloc_fn) - - kernel_update_tensor_descriptor_strides[(1,)]( - A, B, M, N, MBLOCK=MBLOCK, NBLOCK=NBLOCK - ) - - ref_B = torch.zeros((N, M), dtype=torch.float32, device=device) - ref_B[:MBLOCK, :NBLOCK] = A[:MBLOCK, :NBLOCK] - torch.testing.assert_close(B, ref_B) - - -@triton.jit -def kernel_update_tensor_descriptor_loop(ptr, M, N, num_tensors: tl.constexpr, - MBLOCK: tl.constexpr, NBLOCK: tl.constexpr): - pid = tl.program_id(0) - offset = MBLOCK * pid - - desc = tl.make_tensor_descriptor( - ptr, - shape=[M, N], - strides=[N, 1], - block_shape=[MBLOCK, NBLOCK], - ) - - for i in range(num_tensors): - tensor_offset = i * M * N - new_base = ptr + tensor_offset - tl.update_tensor_descriptor(desc, base=new_base) - - data = desc.load([offset, 0]) - data = data + (i + 1) * 10 - desc.store([offset, 0], data) - - -@pytest.mark.interpreter -def test_update_tensor_descriptor_loop(device): - M, N = 64, 128 - MBLOCK, NBLOCK = 16, 128 - num_tensors = 3 - - torch.manual_seed(42) - tensors = torch.randn((num_tensors, M, N), dtype=torch.float32, device=device) - - def alloc_fn(size: int, align: int, stream: Optional[int]): - return torch.empty(size, dtype=torch.int8, device=device) - triton.set_allocator(alloc_fn) - - grid = (triton.cdiv(M, MBLOCK),) - kernel_update_tensor_descriptor_loop[grid]( - tensors, M, N, num_tensors=num_tensors, MBLOCK=MBLOCK, NBLOCK=NBLOCK - ) - - torch.manual_seed(42) - ref_tensors = torch.randn((num_tensors, M, N), dtype=torch.float32, device=device) - for i in range(num_tensors): - ref_tensors[i] = ref_tensors[i] + (i + 1) * 10 - - torch.testing.assert_close(tensors, ref_tensors) - - -@triton.jit -def kernel_update_tensor_descriptor_mixed(in_ptr, out_ptr, M, N, new_M, new_N, - MBLOCK: tl.constexpr, NBLOCK: tl.constexpr): - desc = tl.make_tensor_descriptor( - in_ptr, - shape=[M, N], - strides=[N, 1], - block_shape=[MBLOCK, NBLOCK], - ) - - data = desc.load([0, 0]) - - tl.update_tensor_descriptor( - desc, - base=out_ptr, - shape=[new_M, new_N], - strides=[new_N, 1] - ) - - desc.store([0, 0], data * 3) - - -@pytest.mark.interpreter -def test_update_tensor_descriptor_all_fields(device): - M, N = 32, 64 - new_M, new_N = 64, 128 - MBLOCK, NBLOCK = 16, 32 - - torch.manual_seed(42) - A = torch.randn((M, N), dtype=torch.float32, device=device) - B = torch.zeros((new_M, new_N), dtype=torch.float32, device=device) - - def alloc_fn(size: int, align: int, stream: Optional[int]): - return torch.empty(size, dtype=torch.int8, device=device) - triton.set_allocator(alloc_fn) - - kernel_update_tensor_descriptor_mixed[(1,)]( - A, B, M, N, new_M, new_N, - MBLOCK=MBLOCK, NBLOCK=NBLOCK - ) - - ref_B = torch.zeros((new_M, new_N), dtype=torch.float32, device=device) - ref_B[:MBLOCK, :NBLOCK] = A[:MBLOCK, :NBLOCK] * 3 - torch.testing.assert_close(B, ref_B) - - -@pytest.mark.interpreter -@pytest.mark.parametrize("dtype_str", ["float16", "float32", "bfloat16"]) -def test_update_tensor_descriptor_dtypes(dtype_str, device): - @triton.jit - def kernel(desc, new_ptr, M, N, MBLOCK: tl.constexpr, NBLOCK: tl.constexpr): - data = desc.load([0, 0]) - tl.update_tensor_descriptor(desc, base=new_ptr) - desc.store([0, 0], data + 1) - - M, N = 32, 64 - MBLOCK, NBLOCK = 16, 32 - torch_dtype = getattr(torch, dtype_str) - - torch.manual_seed(42) - A = torch.randn((M, N), dtype=torch.float32, device=device).to(torch_dtype) - B = torch.zeros((M, N), dtype=torch_dtype, device=device) - - desc = TensorDescriptor.from_tensor(A, [MBLOCK, NBLOCK]) - kernel[(1,)](desc, B, M, N, MBLOCK=MBLOCK, NBLOCK=NBLOCK) - - ref_B = torch.zeros((M, N), dtype=torch_dtype, device=device) - ref_B[:MBLOCK, :NBLOCK] = A[:MBLOCK, :NBLOCK] + 1 - torch.testing.assert_close(B, ref_B) - -@triton.jit -def kernel_update_tensor_descriptor_invalid_strides(desc, ptr, M, N, MBLOCK: tl.constexpr, NBLOCK: tl.constexpr): - tl.update_tensor_descriptor(desc, base=ptr, strides=[N, 1]) - -@pytest.mark.interpreter -def test_update_tensor_descriptor_invalid_strides_compile_error(device): - M, N = 32, 64 - MBLOCK, NBLOCK = 16, 32 - A = torch.empty((M, N), dtype=torch.float32, device=device) - desc = TensorDescriptor.from_tensor(A, [MBLOCK, NBLOCK]) - with pytest.raises(triton.CompilationError): - kernel_update_tensor_descriptor_invalid_strides[(1,)](desc, A, M, N, MBLOCK=MBLOCK, NBLOCK=NBLOCK) diff --git a/python/triton/language/__init__.py b/python/triton/language/__init__.py index ad6ddb398b21..06acf870fc1b 100644 --- a/python/triton/language/__init__.py +++ b/python/triton/language/__init__.py @@ -32,7 +32,6 @@ load_tensor_descriptor, store_tensor_descriptor, make_tensor_descriptor, - update_tensor_descriptor, tensor_descriptor, tensor_descriptor_type, add, @@ -143,7 +142,6 @@ "load_tensor_descriptor", "store_tensor_descriptor", "make_tensor_descriptor", - "update_tensor_descriptor", "tensor_descriptor", "abs", "add", diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 970cfed5714c..e125245d3607 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -1456,34 +1456,28 @@ def scatter(self, value, *args, _semantic=None) -> tensor: class tensor_descriptor_type(tensor_descriptor_base_type): - def __init__(self, block_type: block_type, base_type: base_type, shape_type: tuple_type, strides_type: tuple_type): + def __init__(self, block_type: block_type, shape_type: tuple_type, strides_type: tuple_type): self.block_type = block_type - self.base_type = base_type self.shape_type = shape_type self.strides_type = strides_type def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[tensor_descriptor_base, int]: handle = handles[cursor] cursor += 1 - # Note: base is not unflattened, it's embedded in the - # descriptor handle. - base = None shape, cursor = self.shape_type._unflatten_ir(handles, cursor) strides, cursor = self.strides_type._unflatten_ir(handles, cursor) shape = shape.values strides = strides.values - value = tensor_descriptor(handle, base, shape, strides, self.block_type, base_type=self.base_type) + value = tensor_descriptor(handle, shape, strides, self.block_type) return value, cursor def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None: super()._flatten_ir_types(builder, out) - # Note: base_type is not flattened, it's embedded in the - # descriptor handle. self.shape_type._flatten_ir_types(builder, out) self.strides_type._flatten_ir_types(builder, out) def __eq__(self, other): - return super().__eq__(other) and (self.base_type == other.base_type) and (self.shape_type == other.shape_type) and (self.strides_type + return super().__eq__(other) and (self.shape_type == other.shape_type) and (self.strides_type == other.strides_type) @@ -1491,32 +1485,21 @@ class tensor_descriptor(tensor_descriptor_base): """A descriptor representing a tensor in global memory. """ - def __init__(self, handle, base: tensor, shape: List[tensor], strides: List[tensor], block_type: block_type, base_type=None): + def __init__(self, handle, shape: List[tensor], strides: List[tensor], block_type: block_type): """Not called by user code.""" # IR handle super().__init__(handle, block_type) - # Base pointer to global memory tensor - self.base = base - # Global shape and strides + # Global shape self.shape = tuple(shape) self.strides = tuple(strides) - # If base_type is not provided, infer it from base. If base - # is None, base_type must be provided - if base_type is None: - if base is None: - raise ValueError("Either base or base_type must be provided") - base_type = base.type self.type = tensor_descriptor_type( block_type, - base_type, - self.shape.type, - self.strides.type, + shape_type=self.shape.type, + strides_type=self.strides.type, ) def _flatten_ir(self, handles: List[ir.value]) -> None: handles.append(self.handle) - # Note: base is NOT flattened - it's embedded in the descriptor handle - # self.base._flatten_ir(handles) self.shape._flatten_ir(handles) self.strides._flatten_ir(handles) @@ -2366,87 +2349,6 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]): return _semantic.make_tensor_descriptor(base, shape, strides, block_shape, padding_option) -@builtin -def update_tensor_descriptor( - desc: tensor_descriptor, - base: tensor = None, - shape: List[tensor] = None, - strides: List[tensor] = None, - _semantic=None, -) -> None: - """Update an existing TMA descriptor - - Updates one or more fields of an existing TMA descriptor. - - :param desc: The existing tensor descriptor to update - :param base: The new base pointer, must be 16-byte aligned (optional) - :param shape: The new tensor shape (optional) - :param strides: The new tensor strides (optional) - - Notes - ***** - - At least one field (base, shape, or strides) must be provided - - When providing strides, shape must also be provided - - Shape and strides must have the same length - - Same limitation for updates values hold as for `make_tensor_descriptor` - - Example - ******* - .. code-block:: python - - @triton.jit - def kernel(ptr, M: tl.constexpr, N: tl.constexpr): - # Create descriptor - desc = tl.make_tensor_descriptor( - ptr, - shape=[M, N], - strides=[N, 1], - block_shape=[16, 16] - ) - - # Update to new shape - tl.update_tensor_descriptor( - desc, - shape=[M//2, N], - ) - - # Use updated descriptor - data = desc.load([0, 0]) - """ - if base is None and shape is None and strides is None: - raise ValueError("At least one descriptor field must be updated") - - if strides is not None and shape is None: - raise ValueError("Cannot update strides without providing shape") - - if shape is not None and strides is not None: - if len(shape) != len(strides): - raise ValueError(f"Shape and strides must have the same length, got {len(shape)} and {len(strides)}") - - last_stride = _unwrap_if_constexpr(strides[-1]) - if last_stride != 1: - raise ValueError(f"Tensor descriptor last dim must be 1 but got {last_stride}") - - if shape is not None: - shape = [_semantic.make_scalar(s, int32) for s in shape] - if strides is not None: - strides = [_semantic.make_scalar(_unwrap_if_constexpr(s), int64) for s in strides] - - _semantic.builder.create_update_tensor_descriptor( - desc.handle, - base=base.handle if base is not None else None, - shape=[s.handle for s in shape] if shape is not None else [], - strides=[s.handle for s in strides] if strides is not None else [] - ) - - if base is not None: - desc.base = base - if shape is not None: - desc.shape = tuple(shape) - if strides is not None: - desc.strides = tuple(strides) - - # ----------------------- # Atomic Memory Operations # ----------------------- diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 36ab7f47966b..6db9ffb25fec 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1963,4 +1963,4 @@ def make_tensor_descriptor(self, base: TensorTy, shape: List[TensorTy], strides: handle = self.builder.create_make_tensor_descriptor(base_handle, [s.handle for s in shape], [s.handle for s in strides], block_shape, is_signed_int, padding) - return tl.tensor_descriptor(handle, base, shape, strides, type) + return tl.tensor_descriptor(handle, shape, strides, type) diff --git a/test/TritonNvidiaGPU/tma_lowering.mlir b/test/TritonNvidiaGPU/tma_lowering.mlir index 18572e72d86f..c90ee6b28b94 100644 --- a/test/TritonNvidiaGPU/tma_lowering.mlir +++ b/test/TritonNvidiaGPU/tma_lowering.mlir @@ -55,21 +55,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // ----- -#nvmma_32_update = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = false, elementBitWidth = 8}> - -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { - // CHECK-LABEL: update_tensor_descriptor - // CHECK: %[[DESC_PTR:.+]] = ttng.get_descriptor_ptr %arg0 - // CHECK: ttng.tensormap_update %[[DESC_PTR]] global_address = %arg1 - // CHECK: ttng.tensormap_fenceproxy_acquire %[[DESC_PTR]] : !tt.ptr - tt.func public @update_tensor_descriptor(%arg0: !tt.tensordesc>, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) { - tt.update_tensor_descriptor %arg0 base = %arg1 : !tt.ptr : !tt.tensordesc> - tt.return - } -} - -// ----- - #blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> #blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> #nvmma_128 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> From 12be2b09777093fcece9bc224f174f634ba42eff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleksandar=20Samard=C5=BEi=C4=87?= Date: Fri, 21 Nov 2025 17:47:06 +0000 Subject: [PATCH 3/3] Make the operation work only on in-kernel created descriptors --- lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp | 20 +++++++ python/test/gluon/test_core.py | 19 +++++++ .../gluon/language/nvidia/hopper/tma.py | 57 +++++++++++++++++++ 3 files changed, 96 insertions(+) diff --git a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp index b96392747a3f..a20b238872fe 100644 --- a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp @@ -26,6 +26,7 @@ #include "mlir/IR/Diagnostics.h" #include "mlir/Support/LLVM.h" #include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Utility.h" #include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h" @@ -876,6 +877,25 @@ LogicalResult TensormapUpdateOp::verify() { << getGlobalDim().size() - 1; } } + + // Verify that the descriptor being updated is not shared across CTAs. + // Descriptors passed as kernel parameters (from TensorDescriptor.from_tensor()) + // are stored in constant memory and shared across all CTAs, so updating them + // would cause race conditions. Descriptors created in-kernel with + // make_tensor_descriptor() are per-CTA and safe to update. + auto descPtr = getDescPtr(); + auto getDescPtrOp = descPtr.getDefiningOp(); + if (getDescPtrOp) { + auto desc = getDescPtrOp.getDesc(); + + // Use the existing utility to check if this is a host-side descriptor + if (triton::isHostSideDescriptor(desc)) { + return emitError("Descriptor must be created within the kernel using " + "make_tensor_descriptor. Updating descriptors passed as " + "kernel parameters would cause race conditions across CTAs."); + } + } + return success(); } diff --git a/python/test/gluon/test_core.py b/python/test/gluon/test_core.py index b2513136077c..ed22a3010623 100644 --- a/python/test/gluon/test_core.py +++ b/python/test/gluon/test_core.py @@ -1913,3 +1913,22 @@ def alloc_fn(size: int, alignment: int, stream: int): ref = torch.zeros((N, M), device="cuda", dtype=torch.float16) ref[:M, :M] = input_data[:, :M] torch.testing.assert_close(output, ref) + + +@pytest.mark.skipif(not is_hopper_or_newer(), reason="Requires Hopper") +def test_update_tensor_descriptor_from_tensor_fails(): + M, N = 16, 16 + input_data = torch.randn((M, N), device="cuda", dtype=torch.float16) + + block_shape = [M, N] + layout = ttgl.NVMMASharedLayout.get_default_for(block_shape, ttgl.float16) + desc = gluon.nvidia.hopper.TensorDescriptor.from_tensor(input_data, block_shape, layout) + + output_data = torch.zeros_like(input_data) + + @gluon.jit + def kernel(desc, new_ptr): + tma.update_tensor_descriptor(desc, base=new_ptr) + + with pytest.raises(Exception): + kernel[(1,)](desc, output_data) diff --git a/python/triton/experimental/gluon/language/nvidia/hopper/tma.py b/python/triton/experimental/gluon/language/nvidia/hopper/tma.py index eca9c5845ed0..2afe63ed222a 100644 --- a/python/triton/experimental/gluon/language/nvidia/hopper/tma.py +++ b/python/triton/experimental/gluon/language/nvidia/hopper/tma.py @@ -198,6 +198,63 @@ def update_tensor_descriptor( strides: List[ttgl.tensor] = None, _semantic=None, ) -> None: + """Update an existing TMA descriptor + + Updates one or more fields of an existing TMA descriptor. + + :param desc: The existing tensor descriptor to update + :param base: The new base pointer, must be 16-byte aligned (optional) + :param shape: The new tensor shape (optional) + :param strides: The new tensor strides (optional) + + Notes + ***** + - At least one field (base, shape, or strides) must be provided + - When providing strides, shape must also be provided + - Shape and strides must have the same length + - Same limitations for updated values hold as for `make_tensor_descriptor` + - The descriptor to be updated must be created within the kernel + using make_tensor_descriptor(). Descriptors passed as kernel + parameters (e.g., from TensorDescriptor.from_tensor()) cannot be + updated, as they reside in constant memory and updating them + would cause race conditions when the grid size exceeds the + number of SMs. + + Example + ******* + .. code-block:: python + + @gluon.jit + def kernel(ptr, M: ttgl.constexpr, N: ttgl.constexpr, smem_layout: ttgl.constexpr): + # Create descriptor in-kernel + desc = tma.make_tensor_descriptor( + ptr, + shape=[M, N], + strides=[N, 1], + block_shape=[16, 16], + layout=smem_layout + ) + + # ... + + # Later, update to point to second half, along M, of the tensor + new_ptr = ptr + M // 2 * N + tma.update_tensor_descriptor( + desc, + base=new_ptr, + shape=[M//2, N], + strides=[N, 1] + ) + + # Using updated descriptor, copy from second half to SMEM + smem = ttgl.allocate_shared_memory(ttgl.float16, [16, 16], smem_layout) + bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout()) + mbarrier.init(bar, count=1) + mbarrier.expect(bar, desc.block_type.nbytes) + tma.async_copy_global_to_shared(desc, [0, 0], bar, smem) + mbarrier.wait(bar, 0) + + """ if base is None and shape is None and strides is None: raise ValueError("At least one descriptor field must be updated")