Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -527,10 +527,12 @@ def TTG_WarpSpecializeOp : TTG_Op<"warp_specialize", [
let hasCanonicalizeMethod = 1;
}

def TTG_WarpSpecializePartitionsOp : TTG_Op<"warp_specialize.partitions", [
IsolatedFromAbove, RecursiveMemoryEffects, RecursivelySpeculatable,
Terminator, HasParent<"WarpSpecializeOp">
]> {
def TTG_WarpSpecializePartitionsOp
: TTG_Op<"warp_specialize.partitions",
[IsolatedFromAbove, RecursiveMemoryEffects,
RecursivelySpeculatable, Terminator,
HasParent<"WarpSpecializeOp">,
DeclareOpInterfaceMethods<RegionBranchOpInterface>]> {
let summary = "container op for `ttg.warp_specialize`";
let description = [{
Because MLIR requires entire operations be isolated from above, this op
Expand Down
12 changes: 3 additions & 9 deletions lib/Analysis/Allocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -232,15 +232,9 @@ class AllocationAnalysis {
std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
SharedMemoryAliasAnalysis *aliasAnalysis =
solver->load<SharedMemoryAliasAnalysis>();
// Run the analysis rooted at every isolated from above operation, including
// the top-level function but also any nested regions.
operation->walk<mlir::WalkOrder::PreOrder>([&](Operation *op) {
if (op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
failed(solver->initializeAndRun(op))) {
// TODO: return error instead of bailing out..
llvm_unreachable("failed to run SharedMemoryAliasAnalysis");
}
});
if (failed(solver->initializeAndRun(operation))) {
llvm_unreachable("failed to run SharedMemoryAliasAnalysis");
}
operation->walk<WalkOrder::PreOrder>([&](Operation *op) {
for (auto operand : op->getOperands()) {
getValueAlias(operand, *aliasAnalysis);
Expand Down
18 changes: 1 addition & 17 deletions lib/Analysis/AxisInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1185,13 +1185,6 @@ void AxisInfo::initDimVectorFromHint(Attribute attr, DimVectorT *vec) {
initPessimisticStateFromFunc(blockArg.getArgNumber(), fun,
&knownContiguity, &knownDivisibility,
&knownConstancy);
} else if (isa<gpu::WarpSpecializePartitionsOp>(op)) {
// Initialize the arguments to gpu::WarpSpecializePartitionsOp with
// "unknown" state: the maximum possible divisibility, contiguity, and
// constancy.
knownDivisibility = DimVectorT(rank, kMaxDivisor);
knownConstancy = DimVectorT(rank, kMaxDivisor);
knownContiguity = DimVectorT(rank, kMaxDivisor);
}
} else if (Operation *op = value.getDefiningOp()) {
// Other operations are conservatively initialized with the lowest possible
Expand Down Expand Up @@ -1331,16 +1324,7 @@ void ModuleAxisInfoAnalysis::initialize(FunctionOpInterface funcOp,
axisinfo::CallbackType callback) {
std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
AxisInfoAnalysis *analysis = solver->load<AxisInfoAnalysis>(callback);
// Walk pre-order so analysis results can be propagated into nested isolated
// regions.
WalkResult result =
funcOp.walk<mlir::WalkOrder::PreOrder>([&](Operation *op) {
if (op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
failed(solver->initializeAndRun(op)))
return WalkResult::interrupt();
return WalkResult::advance();
});
if (result.wasInterrupted())
if (failed(solver->initializeAndRun(funcOp)))
return;

auto *axisInfoMap = getFuncData(funcOp);
Expand Down
17 changes: 8 additions & 9 deletions lib/Analysis/Membar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,8 @@ void MembarOrFenceAnalysis::resolve(FunctionOpInterface funcOp,
DenseMap<VirtualBlock, BlockInfo> inputBlockInfoMap;
DenseMap<VirtualBlock, BlockInfo> outputBlockInfoMap;
std::deque<VirtualBlock> blockList;
funcOp.walk<WalkOrder::PreOrder>([&](Block *block) {
// Start the analysis from the entry blocks of any nested isolated from
// above regions.
if (block->isEntryBlock() &&
!isa<RegionBranchOpInterface>(block->getParentOp()))
blockList.emplace_back(block, Block::iterator());
});
// Start the analysis from the entry block of the function.
blockList.emplace_back(&funcOp.getBlocks().front(), Block::iterator());

// A fixed point algorithm
while (!blockList.empty()) {
Expand All @@ -56,12 +51,15 @@ void MembarOrFenceAnalysis::resolve(FunctionOpInterface funcOp,
Block::iterator startIt =
block.second.isValid() ? std::next(block.second) : block.first->begin();
for (Operation &op : llvm::make_range(startIt, block.first->end())) {
// Update inputBlockInfo based on the current operation. Note that we do
// this before we process terminators and branch-like ops, because some of
// them (e.g. WarpSpecializePartitionsOp) may have synchronizing effects.
update(&op, &inputBlockInfo, funcBlockInfoMap, builder);
if (op.hasTrait<OpTrait::IsTerminator>() ||
isa<RegionBranchOpInterface>(op)) {
visitTerminator(&op, successors);
break;
}
update(&op, &inputBlockInfo, funcBlockInfoMap, builder);
}
// Get the reference because we want to update if it changed
if (outputBlockInfoMap.count(block) &&
Expand Down Expand Up @@ -165,7 +163,8 @@ void MembarAnalysis::insertBarrier(Operation *op, OpBuilder *builder) {
void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
FuncBlockInfoMapT *funcBlockInfoMap,
OpBuilder *builder) {
if (isa<gpu::BarrierOp, triton::gpu::LocalBarrierOp>(op)) {
if (isa<gpu::BarrierOp, triton::gpu::LocalBarrierOp,
triton::gpu::WarpSpecializePartitionsOp>(op)) {
// If the current op is a barrier, we sync previous reads and writes
blockInfo->sync();
return;
Expand Down
16 changes: 13 additions & 3 deletions lib/Dialect/TritonGPU/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -923,14 +923,24 @@ RegionRange WarpSpecializeOp::getPartitionRegions() {

void WarpSpecializeOp::getSuccessorRegions(
RegionBranchPoint src, SmallVectorImpl<RegionSuccessor> &successors) {
// The parent branches transparently into the default region.
// The parent branches into the default region and the partition regions.
if (src.isParent()) {
successors.emplace_back(&getDefaultRegion());
successors.emplace_back(&getPartitionOpHolder());
return;
}
// And the default region branches transparently back to the parent.
assert(src.getRegionOrNull() == &getDefaultRegion());
successors.push_back(RegionSuccessor(getResults()));
if (src.getRegionOrNull() == &getDefaultRegion())
successors.push_back(RegionSuccessor(getResults()));
}

void WarpSpecializePartitionsOp::getSuccessorRegions(
RegionBranchPoint src, SmallVectorImpl<RegionSuccessor> &successors) {
// The parent branches to each of the partition regions, but nothing flows out
// of the partition regions.
if (src.isParent())
for (Region &region : getPartitionRegions())
successors.emplace_back(&region);
}

LogicalResult WarpSpecializeOp::verify() {
Expand Down