Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,46 @@ def TTNG_AsyncCopyMbarrierArriveOp : TTNG_Op<"async_copy_mbarrier_arrive"> {
let assemblyFormat = "$barrier attr-dict `:` qualified(type($barrier))";
}

def TTNG_AsyncCLCTryCancelOp : TTNG_Op<"async_clc_try_cancel", []> {
  let summary = "Requests cancellation of cluster which is not launched yet";

  let description = [{
    Requests atomically cancelling the launch of a cluster that has not started running yet.

    This lowers using PTX instruction
    clusterlaunchcontrol.try_cancel.async.shared::cta.mbarrier::complete_tx::bytes.multicast::cluster::all.b128

    It asynchronously writes an opaque response (16-byte CLC response) to the
    shared memory buffer `clcResAlloc`. The completion of the asynchronous
    operation is tracked using the mbarrier object in `mbarAlloc`.

    https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-clusterlaunchcontrol-try-cancel
  }];

  // `mbarAlloc`: mbarrier tracking completion of the async write.
  // `clcResAlloc`: destination of the 16-byte opaque CLC response.
  let arguments = (ins
    Arg<TTG_MemDescType, "", [MemWrite<SharedMemory>]>:$mbarAlloc,
    Arg<TTG_MemDescType, "", [MemWrite<SharedMemory>]>:$clcResAlloc
  );

  let assemblyFormat = "$mbarAlloc`,` $clcResAlloc attr-dict `:` type(operands)";
}

// Decodes the CLC response previously written by ttng.async_clc_try_cancel.
def TTNG_CLCQueryCancelOp : TTNG_Op<"clc_query_cancel", []> {
let summary = "Extract CTA ID from CLC response";

let description = [{
Extract CTA ID from CLC response if try_cancel was successful.
Otherwise, returns -1.

https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-clusterlaunchcontrol-query-cancel
}];

// NOTE(review): the lowering only loads the CLC response from shared memory;
// MemWrite<SharedMemory> looks overly conservative — confirm whether a
// MemRead effect is the intended modeling here.
let arguments = (ins
Arg<TTG_MemDescType, "", [MemWrite<SharedMemory>]>:$clcResAlloc
);

// x-coordinate of the first CTA id on success, -1 otherwise.
let results = (outs I32:$ctaId);

let assemblyFormat = "$clcResAlloc attr-dict `:` functional-type(operands, $ctaId)";
}

def TTNG_AsyncTMACopyGlobalToLocalOp : TTNG_Op<"async_tma_copy_global_to_local"> {
let summary = "copy data based on descriptor from global memory to local memory asynchronously";
Expand Down
28 changes: 28 additions & 0 deletions test/Conversion/tritonnvidiagpu_to_llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,34 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
}


// -----

// Lowering test: ttng.async_clc_try_cancel must emit the multicast CLC
// try_cancel PTX instruction targeting the mbarrier + response buffers.
#shared0 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
// CHECK-LABEL: async_clc_try_cancel
// CHECK: clusterlaunchcontrol.try_cancel.async.shared::cta.mbarrier::complete_tx::bytes.multicast::cluster::all.b128
tt.func @async_clc_try_cancel(%alloc: !ttg.memdesc<1xi64, #shared0, #smem, mutable>, %clc_response: !ttg.memdesc<1xui128, #shared0, #smem, mutable>) {
ttng.async_clc_try_cancel %alloc, %clc_response : !ttg.memdesc<1xi64, #shared0, #smem, mutable>, !ttg.memdesc<1xui128, #shared0, #smem, mutable>
tt.return
}
}

// -----

// Lowering test: ttng.clc_query_cancel must emit both CLC query instructions
// (is_canceled predicate check, then the predicated first-ctaid extraction).
#shared0 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
// CHECK-LABEL: clc_query_cancel
// CHECK: clusterlaunchcontrol.query_cancel.is_canceled.pred.b128
// CHECK: clusterlaunchcontrol.query_cancel.get_first_ctaid.v4.b32.b128
tt.func @clc_query_cancel(%clc_response: !ttg.memdesc<1xui128, #shared0, #smem, mutable>) {
%x = ttng.clc_query_cancel %clc_response : (!ttg.memdesc<1xui128, #shared0, #smem, mutable>) -> i32
tt.return
}
}


// -----

#shared0 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,90 @@ struct ArriveBarrierOpConversion
return success();
}
};

// Lowers ttng.async_clc_try_cancel to the CLC try_cancel PTX instruction.
// Issues the request from a single thread of the first CTA in the cluster.
struct AsyncCLCTryCancelOpConversion
    : public ConvertOpToLLVMPattern<triton::nvidia_gpu::AsyncCLCTryCancelOp> {
  // TODO: check target info for compute capability >= 100.
  using ConvertOpToLLVMPattern<
      triton::nvidia_gpu::AsyncCLCTryCancelOp>::ConvertOpToLLVMPattern;

  // The CLC response is a 16-byte opaque object written to the 16-byte-wide
  // shared memory address given as the 1st operand of the PTX instruction.
  LogicalResult
  matchAndRewrite(triton::nvidia_gpu::AsyncCLCTryCancelOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();

    // Only thread 0 issues the request; the predicate is AND-ed with the
    // "first CTA in cluster" check computed inside the PTX below.
    auto tid = getThreadId(rewriter, loc);
    TritonLLVMOpBuilder b(loc, rewriter);
    Value pred = b.icmp_eq(tid, b.i32_val(0));

    // $0 = CLC response shared-memory address, $1 = mbarrier address,
    // $2 = per-thread issue predicate.
    // Fixed: `setp.u32.eq` is not valid PTX — the comparison operator
    // precedes the type (`setp.eq.u32`).
    std::string ptx = R"(
{
.reg .u32 first_cta_in_cluster;
.reg .pred pred_first_cta_in_cluster;
.reg .pred pred_issue;
mov.u32 first_cta_in_cluster, %cluster_ctaid.x;
setp.eq.u32 pred_first_cta_in_cluster, first_cta_in_cluster, 0x0;
and.pred pred_issue, $2, pred_first_cta_in_cluster;
@pred_issue clusterlaunchcontrol.try_cancel.async.shared::cta.mbarrier::complete_tx::bytes.multicast::cluster::all.b128 [$0], [$1];
}
)";

    PTXBuilder ptxBuilder;
    SmallVector<PTXBuilder::Operand *, 3> operands = {
        ptxBuilder.newOperand(adaptor.getClcResAlloc(), "r"),
        ptxBuilder.newOperand(adaptor.getMbarAlloc(), "r"),
        ptxBuilder.newOperand(pred, "b")};

    auto clcOp = *ptxBuilder.create<>(ptx);
    clcOp(operands, /*onlyAttachMLIRArgs=*/true);
    auto voidTy = void_ty(getContext());
    ptxBuilder.launch(rewriter, loc, voidTy);

    // The op produces no results; the async completion is tracked via the
    // mbarrier, so the op is simply erased after emitting the inline PTX.
    rewriter.eraseOp(op);
    return success();
  }
};

// Lowers ttng.clc_query_cancel: loads the 16-byte CLC response from shared
// memory and extracts the first CTA id (x) on success, or -1 otherwise.
struct CLCQueryCancelOpConversion
    : public ConvertOpToLLVMPattern<triton::nvidia_gpu::CLCQueryCancelOp> {
  using ConvertOpToLLVMPattern<
      triton::nvidia_gpu::CLCQueryCancelOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(triton::nvidia_gpu::CLCQueryCancelOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // $0 = i32 result (pre-set to -1, overwritten only when the cancel
    // request succeeded), $1 = CLC response shared-memory address.
    std::string ptx = R"(
{
.reg .b128 clc_result;
.reg .pred p1;
mov.s32 $0, -1;
ld.shared.b128 clc_result, [$1];
clusterlaunchcontrol.query_cancel.is_canceled.pred.b128 p1, clc_result;
@p1 clusterlaunchcontrol.query_cancel.get_first_ctaid.v4.b32.b128 {$0, _, _, _}, clc_result;
}
)";

    PTXBuilder builder;
    auto queryOp = *builder.create<>(ptx);

    SmallVector<PTXBuilder::Operand *, 2> operands = {
        builder.newOperand("=r", true),
        builder.newOperand(adaptor.getClcResAlloc(), "r")};
    queryOp(operands, /*onlyAttachMLIRArgs=*/false);

    Value ctaId = builder.launch(rewriter, op.getLoc(), i32_ty, false);

    rewriter.replaceOp(op, ctaId);

    return success();
  }
};
} // namespace

void mlir::triton::NVIDIA::populateBarrierOpToLLVMPatterns(
Expand All @@ -272,4 +356,6 @@ void mlir::triton::NVIDIA::populateBarrierOpToLLVMPatterns(
patterns.add<WaitBarrierOpConversion>(typeConverter, benefit, targetInfo);
patterns.add<BarrierExpectConversion>(typeConverter, benefit);
patterns.add<ArriveBarrierOpConversion>(typeConverter, benefit);
patterns.add<AsyncCLCTryCancelOpConversion>(typeConverter, benefit);
patterns.add<CLCQueryCancelOpConversion>(typeConverter, benefit);
}
Loading