From f96c76fbcfb3832d282125eccf0a16fa2ac12f6a Mon Sep 17 00:00:00 2001 From: Neil Dhar Date: Mon, 17 Nov 2025 22:10:45 -0800 Subject: [PATCH] Make WarpSpecializePartitionsOp implement RegionBranchInterface Existing IR traversals for things like dataflow analysis cannot automatically descend into the non-default partitions of a WarpSpecializeOp because we do not encode the edges to these regions through RegionBranchInterface. Changing the warpspec operations to reflect these edges allows us to clean up code in the dataflow analyses that walks the IR looking for `WarpSpecializePartitionsOp`, and populates its block args. --- .../Dialect/TritonGPU/IR/TritonGPUOps.td | 10 ++++++---- lib/Analysis/Allocation.cpp | 12 +++--------- lib/Analysis/AxisInfo.cpp | 18 +----------------- lib/Analysis/Membar.cpp | 17 ++++++++--------- lib/Dialect/TritonGPU/IR/Ops.cpp | 16 +++++++++++++--- 5 files changed, 31 insertions(+), 42 deletions(-) diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td index 04943280fcfe..4bff108a9742 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td @@ -527,10 +527,12 @@ def TTG_WarpSpecializeOp : TTG_Op<"warp_specialize", [ let hasCanonicalizeMethod = 1; } -def TTG_WarpSpecializePartitionsOp : TTG_Op<"warp_specialize.partitions", [ - IsolatedFromAbove, RecursiveMemoryEffects, RecursivelySpeculatable, - Terminator, HasParent<"WarpSpecializeOp"> -]> { +def TTG_WarpSpecializePartitionsOp + : TTG_Op<"warp_specialize.partitions", + [IsolatedFromAbove, RecursiveMemoryEffects, + RecursivelySpeculatable, Terminator, + HasParent<"WarpSpecializeOp">, + DeclareOpInterfaceMethods]> { let summary = "container op for `ttg.warp_specialize`"; let description = [{ Because MLIR requires entire operations be isolated from above, this op diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index 983b6645b8c1..2c97a4855288 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -232,15 +232,9 @@ class AllocationAnalysis { std::unique_ptr solver = createDataFlowSolver(); SharedMemoryAliasAnalysis *aliasAnalysis = solver->load(); - // Run the analysis rooted at every isolated from above operation, including - // the top-level function but also any nested regions. - operation->walk([&](Operation *op) { - if (op->hasTrait() && - failed(solver->initializeAndRun(op))) { - // TODO: return error instead of bailing out.. - llvm_unreachable("failed to run SharedMemoryAliasAnalysis"); - } - }); + if (failed(solver->initializeAndRun(operation))) { + llvm_unreachable("failed to run SharedMemoryAliasAnalysis"); + } operation->walk([&](Operation *op) { for (auto operand : op->getOperands()) { getValueAlias(operand, *aliasAnalysis); diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp index 50ded51aa397..5f862b5ff282 100644 --- a/lib/Analysis/AxisInfo.cpp +++ b/lib/Analysis/AxisInfo.cpp @@ -1185,13 +1185,6 @@ void AxisInfo::initDimVectorFromHint(Attribute attr, DimVectorT *vec) { initPessimisticStateFromFunc(blockArg.getArgNumber(), fun, &knownContiguity, &knownDivisibility, &knownConstancy); - } else if (isa(op)) { - // Initialize the arguments to gpu::WarpSpecializePartitionsOp with - // "unknown" state: the maximum possible divisibility, contiguity, and - // constancy. - knownDivisibility = DimVectorT(rank, kMaxDivisor); - knownConstancy = DimVectorT(rank, kMaxDivisor); - knownContiguity = DimVectorT(rank, kMaxDivisor); } } else if (Operation *op = value.getDefiningOp()) { // Other operations are conservatively initialized with the lowest possible @@ -1331,16 +1324,7 @@ void ModuleAxisInfoAnalysis::initialize(FunctionOpInterface funcOp, axisinfo::CallbackType callback) { std::unique_ptr solver = createDataFlowSolver(); AxisInfoAnalysis *analysis = solver->load(callback); - // Walk pre-order so analysis results can be propagated into nested isolated - // regions. - WalkResult result = - funcOp.walk([&](Operation *op) { - if (op->hasTrait() && - failed(solver->initializeAndRun(op))) - return WalkResult::interrupt(); - return WalkResult::advance(); - }); - if (result.wasInterrupted()) + if (failed(solver->initializeAndRun(funcOp))) return; auto *axisInfoMap = getFuncData(funcOp); diff --git a/lib/Analysis/Membar.cpp b/lib/Analysis/Membar.cpp index 400efe58a98b..c8e278fd7f74 100644 --- a/lib/Analysis/Membar.cpp +++ b/lib/Analysis/Membar.cpp @@ -38,13 +38,8 @@ void MembarOrFenceAnalysis::resolve(FunctionOpInterface funcOp, DenseMap inputBlockInfoMap; DenseMap outputBlockInfoMap; std::deque blockList; - funcOp.walk([&](Block *block) { - // Start the analysis from the entry blocks of any nested isolated from - // above regions. - if (block->isEntryBlock() && - !isa(block->getParentOp())) - blockList.emplace_back(block, Block::iterator()); - }); + // Start the analysis from the entry block of the function. + blockList.emplace_back(&funcOp.getBlocks().front(), Block::iterator()); // A fixed point algorithm while (!blockList.empty()) { @@ -56,12 +51,15 @@ void MembarOrFenceAnalysis::resolve(FunctionOpInterface funcOp, Block::iterator startIt = block.second.isValid() ? std::next(block.second) : block.first->begin(); for (Operation &op : llvm::make_range(startIt, block.first->end())) { + // Update inputBlockInfo based on the current operation. Note that we do + // this before we process terminators and branch-like ops, because some of + // them (e.g. WarpSpecializePartitionsOp) may have synchronizing effects. + update(&op, &inputBlockInfo, funcBlockInfoMap, builder); if (op.hasTrait() || isa(op)) { visitTerminator(&op, successors); break; } - update(&op, &inputBlockInfo, funcBlockInfoMap, builder); } // Get the reference because we want to update if it changed if (outputBlockInfoMap.count(block) && @@ -165,7 +163,8 @@ void MembarAnalysis::insertBarrier(Operation *op, OpBuilder *builder) { void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo, FuncBlockInfoMapT *funcBlockInfoMap, OpBuilder *builder) { - if (isa(op)) { + if (isa(op)) { // If the current op is a barrier, we sync previous reads and writes blockInfo->sync(); return; diff --git a/lib/Dialect/TritonGPU/IR/Ops.cpp b/lib/Dialect/TritonGPU/IR/Ops.cpp index 97c351586820..19df4055b7e4 100644 --- a/lib/Dialect/TritonGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonGPU/IR/Ops.cpp @@ -923,14 +923,24 @@ RegionRange WarpSpecializeOp::getPartitionRegions() { void WarpSpecializeOp::getSuccessorRegions( RegionBranchPoint src, SmallVectorImpl &successors) { - // The parent branches transparently into the default region. + // The parent branches into the default region and the partition regions. if (src.isParent()) { successors.emplace_back(&getDefaultRegion()); + successors.emplace_back(&getPartitionOpHolder()); return; } // And the default region branches transparently back to the parent. - assert(src.getRegionOrNull() == &getDefaultRegion()); - successors.push_back(RegionSuccessor(getResults())); + if (src.getRegionOrNull() == &getDefaultRegion()) + successors.push_back(RegionSuccessor(getResults())); +} + +void WarpSpecializePartitionsOp::getSuccessorRegions( + RegionBranchPoint src, SmallVectorImpl &successors) { + // The parent branches to each of the partition regions, but nothing flows out + // of the partition regions. + if (src.isParent()) + for (Region ®ion : getPartitionRegions()) + successors.emplace_back(®ion); } LogicalResult WarpSpecializeOp::verify() {