diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td index 04943280fcfe..4bff108a9742 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td @@ -527,10 +527,12 @@ def TTG_WarpSpecializeOp : TTG_Op<"warp_specialize", [ let hasCanonicalizeMethod = 1; } -def TTG_WarpSpecializePartitionsOp : TTG_Op<"warp_specialize.partitions", [ - IsolatedFromAbove, RecursiveMemoryEffects, RecursivelySpeculatable, - Terminator, HasParent<"WarpSpecializeOp"> -]> { +def TTG_WarpSpecializePartitionsOp + : TTG_Op<"warp_specialize.partitions", + [IsolatedFromAbove, RecursiveMemoryEffects, + RecursivelySpeculatable, Terminator, + HasParent<"WarpSpecializeOp">, + DeclareOpInterfaceMethods]> { let summary = "container op for `ttg.warp_specialize`"; let description = [{ Because MLIR requires entire operations be isolated from above, this op diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index 983b6645b8c1..2c97a4855288 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -232,15 +232,9 @@ class AllocationAnalysis { std::unique_ptr solver = createDataFlowSolver(); SharedMemoryAliasAnalysis *aliasAnalysis = solver->load(); - // Run the analysis rooted at every isolated from above operation, including - // the top-level function but also any nested regions. - operation->walk([&](Operation *op) { - if (op->hasTrait() && - failed(solver->initializeAndRun(op))) { - // TODO: return error instead of bailing out.. - llvm_unreachable("failed to run SharedMemoryAliasAnalysis"); - } - }); + if (failed(solver->initializeAndRun(operation))) { + llvm_unreachable("failed to run SharedMemoryAliasAnalysis"); + } operation->walk([&](Operation *op) { for (auto operand : op->getOperands()) { getValueAlias(operand, *aliasAnalysis); diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp index 50ded51aa397..5f862b5ff282 100644 --- a/lib/Analysis/AxisInfo.cpp +++ b/lib/Analysis/AxisInfo.cpp @@ -1185,13 +1185,6 @@ void AxisInfo::initDimVectorFromHint(Attribute attr, DimVectorT *vec) { initPessimisticStateFromFunc(blockArg.getArgNumber(), fun, &knownContiguity, &knownDivisibility, &knownConstancy); - } else if (isa(op)) { - // Initialize the arguments to gpu::WarpSpecializePartitionsOp with - // "unknown" state: the maximum possible divisibility, contiguity, and - // constancy. - knownDivisibility = DimVectorT(rank, kMaxDivisor); - knownConstancy = DimVectorT(rank, kMaxDivisor); - knownContiguity = DimVectorT(rank, kMaxDivisor); } } else if (Operation *op = value.getDefiningOp()) { // Other operations are conservatively initialized with the lowest possible @@ -1331,16 +1324,7 @@ void ModuleAxisInfoAnalysis::initialize(FunctionOpInterface funcOp, axisinfo::CallbackType callback) { std::unique_ptr solver = createDataFlowSolver(); AxisInfoAnalysis *analysis = solver->load(callback); - // Walk pre-order so analysis results can be propagated into nested isolated - // regions. - WalkResult result = - funcOp.walk([&](Operation *op) { - if (op->hasTrait() && - failed(solver->initializeAndRun(op))) - return WalkResult::interrupt(); - return WalkResult::advance(); - }); - if (result.wasInterrupted()) + if (failed(solver->initializeAndRun(funcOp))) return; auto *axisInfoMap = getFuncData(funcOp); diff --git a/lib/Analysis/Membar.cpp b/lib/Analysis/Membar.cpp index 400efe58a98b..c8e278fd7f74 100644 --- a/lib/Analysis/Membar.cpp +++ b/lib/Analysis/Membar.cpp @@ -38,13 +38,8 @@ void MembarOrFenceAnalysis::resolve(FunctionOpInterface funcOp, DenseMap inputBlockInfoMap; DenseMap outputBlockInfoMap; std::deque blockList; - funcOp.walk([&](Block *block) { - // Start the analysis from the entry blocks of any nested isolated from - // above regions. - if (block->isEntryBlock() && - !isa(block->getParentOp())) - blockList.emplace_back(block, Block::iterator()); - }); + // Start the analysis from the entry block of the function. + blockList.emplace_back(&funcOp.getBlocks().front(), Block::iterator()); // A fixed point algorithm while (!blockList.empty()) { @@ -56,12 +51,15 @@ void MembarOrFenceAnalysis::resolve(FunctionOpInterface funcOp, Block::iterator startIt = block.second.isValid() ? std::next(block.second) : block.first->begin(); for (Operation &op : llvm::make_range(startIt, block.first->end())) { + // Update inputBlockInfo based on the current operation. Note that we do + // this before we process terminators and branch-like ops, because some of + // them (e.g. WarpSpecializePartitionsOp) may have synchronizing effects. + update(&op, &inputBlockInfo, funcBlockInfoMap, builder); if (op.hasTrait() || isa(op)) { visitTerminator(&op, successors); break; } - update(&op, &inputBlockInfo, funcBlockInfoMap, builder); } // Get the reference because we want to update if it changed if (outputBlockInfoMap.count(block) && @@ -165,7 +163,8 @@ void MembarAnalysis::insertBarrier(Operation *op, OpBuilder *builder) { void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo, FuncBlockInfoMapT *funcBlockInfoMap, OpBuilder *builder) { - if (isa(op)) { + if (isa(op)) { // If the current op is a barrier, we sync previous reads and writes blockInfo->sync(); return; diff --git a/lib/Dialect/TritonGPU/IR/Ops.cpp b/lib/Dialect/TritonGPU/IR/Ops.cpp index 97c351586820..19df4055b7e4 100644 --- a/lib/Dialect/TritonGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonGPU/IR/Ops.cpp @@ -923,14 +923,24 @@ RegionRange WarpSpecializeOp::getPartitionRegions() { void WarpSpecializeOp::getSuccessorRegions( RegionBranchPoint src, SmallVectorImpl &successors) { - // The parent branches transparently into the default region. + // The parent branches into the default region and the partition regions. if (src.isParent()) { successors.emplace_back(&getDefaultRegion()); + successors.emplace_back(&getPartitionOpHolder()); return; } // And the default region branches transparently back to the parent. - assert(src.getRegionOrNull() == &getDefaultRegion()); - successors.push_back(RegionSuccessor(getResults())); + if (src.getRegionOrNull() == &getDefaultRegion()) + successors.push_back(RegionSuccessor(getResults())); +} + +void WarpSpecializePartitionsOp::getSuccessorRegions( + RegionBranchPoint src, SmallVectorImpl &successors) { + // The parent branches to each of the partition regions, but nothing flows out + // of the partition regions. + if (src.isParent()) + for (Region ®ion : getPartitionRegions()) + successors.emplace_back(®ion); } LogicalResult WarpSpecializeOp::verify() {