Skip to content

Commit 5f978b4

Browse files
authored
[Comb] Canonicalize power-of-two unsigned div/mod. (#9177)
This commit adds canonicalization patterns for unsigned division and modulo operations when the divisor is a power of two constant. These operations can be efficiently lowered to bit extraction and concatenation operations. This is done in CombToSynth already but would be good to perform as part of canonicalization.
1 parent 4b301e0 commit 5f978b4

File tree

9 files changed

+188
-33
lines changed

9 files changed

+188
-33
lines changed

include/circt/Dialect/Comb/CombOps.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,15 @@ Value createInject(OpBuilder &builder, Location loc, Value value,
9090
LogicalResult convertSubToAdd(comb::SubOp subOp,
9191
mlir::PatternRewriter &rewriter);
9292

93+
/// Convert unsigned division or modulo by a power of two.
94+
/// For division: divu(x, 2^n) -> concat(0...0, extract(x, n, width-n)).
95+
/// For modulo: modu(x, 2^n) -> concat(0...0, extract(x, 0, n))
96+
/// TODO: Support signed division and modulo.
97+
LogicalResult convertDivUByPowerOfTwo(DivUOp divOp,
98+
mlir::PatternRewriter &rewriter);
99+
LogicalResult convertModUByPowerOfTwo(ModUOp modOp,
100+
mlir::PatternRewriter &rewriter);
101+
93102
/// Enum for mux chain folding styles.
94103
enum MuxChainWithComparisonFoldingStyle { None, BalancedMuxTree, ArrayGet };
95104
/// Mux chain folding that converts chains of muxes with index

include/circt/Dialect/Comb/Combinational.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,11 +68,11 @@ class UTVariadicOp<string mnemonic, list<Trait> traits = []> :
6868
def AddOp : UTVariadicOp<"add", [Commutative]>;
6969
def MulOp : UTVariadicOp<"mul", [Commutative]>;
7070
let hasFolder = true in {
71-
def DivUOp : UTBinOp<"divu">;
7271
def DivSOp : UTBinOp<"divs">;
73-
def ModUOp : UTBinOp<"modu">;
7472
def ModSOp : UTBinOp<"mods">;
7573
let hasCanonicalizeMethod = true in {
74+
def DivUOp : UTBinOp<"divu">;
75+
def ModUOp : UTBinOp<"modu">;
7676
def ShlOp : UTBinOp<"shl">;
7777
def ShrUOp : UTBinOp<"shru">;
7878
def ShrSOp : UTBinOp<"shrs">;

lib/Conversion/CombToSynth/CombToSynth.cpp

Lines changed: 4 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -990,20 +990,8 @@ struct CombDivUOpConversion : DivModOpConversionBase<DivUOp> {
990990
matchAndRewrite(DivUOp op, OpAdaptor adaptor,
991991
ConversionPatternRewriter &rewriter) const override {
992992
// Check if the divisor is a power of two.
993-
if (auto rhsConstantOp = adaptor.getRhs().getDefiningOp<hw::ConstantOp>())
994-
if (rhsConstantOp.getValue().isPowerOf2()) {
995-
// Extract upper bits.
996-
size_t extractAmount = rhsConstantOp.getValue().ceilLogBase2();
997-
size_t width = op.getType().getIntOrFloatBitWidth();
998-
Value upperBits = rewriter.createOrFold<comb::ExtractOp>(
999-
op.getLoc(), adaptor.getLhs(), extractAmount,
1000-
width - extractAmount);
1001-
Value constZero = hw::ConstantOp::create(rewriter, op.getLoc(),
1002-
APInt::getZero(extractAmount));
1003-
replaceOpWithNewOpAndCopyNamehint<comb::ConcatOp>(
1004-
rewriter, op, op.getType(), ArrayRef<Value>{constZero, upperBits});
1005-
return success();
1006-
}
993+
if (llvm::succeeded(comb::convertDivUByPowerOfTwo(op, rewriter)))
994+
return success();
1007995

1008996
// When rhs is not power of two and the number of unknown bits are small,
1009997
// create a mux tree that emulates all possible cases.
@@ -1024,19 +1012,8 @@ struct CombModUOpConversion : DivModOpConversionBase<ModUOp> {
10241012
matchAndRewrite(ModUOp op, OpAdaptor adaptor,
10251013
ConversionPatternRewriter &rewriter) const override {
10261014
// Check if the divisor is a power of two.
1027-
if (auto rhsConstantOp = adaptor.getRhs().getDefiningOp<hw::ConstantOp>())
1028-
if (rhsConstantOp.getValue().isPowerOf2()) {
1029-
// Extract lower bits.
1030-
size_t extractAmount = rhsConstantOp.getValue().ceilLogBase2();
1031-
size_t width = op.getType().getIntOrFloatBitWidth();
1032-
Value lowerBits = rewriter.createOrFold<comb::ExtractOp>(
1033-
op.getLoc(), adaptor.getLhs(), 0, extractAmount);
1034-
Value constZero = hw::ConstantOp::create(
1035-
rewriter, op.getLoc(), APInt::getZero(width - extractAmount));
1036-
replaceOpWithNewOpAndCopyNamehint<comb::ConcatOp>(
1037-
rewriter, op, op.getType(), ArrayRef<Value>{constZero, lowerBits});
1038-
return success();
1039-
}
1015+
if (llvm::succeeded(comb::convertModUByPowerOfTwo(op, rewriter)))
1016+
return success();
10401017

10411018
// When rhs is not power of two and the number of unknown bits are small,
10421019
// create a mux tree that emulates all possible cases.

lib/Dialect/Comb/CombFolds.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1697,6 +1697,20 @@ OpFoldResult ModSOp::fold(FoldAdaptor adaptor) {
16971697
return {};
16981698
return foldMod<ModSOp, /*isSigned=*/true>(*this, adaptor.getOperands());
16991699
}
1700+
1701+
LogicalResult DivUOp::canonicalize(DivUOp op, PatternRewriter &rewriter) {
1702+
if (isOpTriviallyRecursive(op) || !op.getTwoState())
1703+
return failure();
1704+
return convertDivUByPowerOfTwo(op, rewriter);
1705+
}
1706+
1707+
LogicalResult ModUOp::canonicalize(ModUOp op, PatternRewriter &rewriter) {
1708+
if (isOpTriviallyRecursive(op) || !op.getTwoState())
1709+
return failure();
1710+
1711+
return convertModUByPowerOfTwo(op, rewriter);
1712+
}
1713+
17001714
//===----------------------------------------------------------------------===//
17011715
// ConcatOp
17021716
//===----------------------------------------------------------------------===//

lib/Dialect/Comb/CombOps.cpp

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,72 @@ llvm::LogicalResult comb::convertSubToAdd(comb::SubOp subOp,
234234
return success();
235235
}
236236

237+
static llvm::LogicalResult convertDivModUByPowerOfTwo(PatternRewriter &rewriter,
238+
Operation *op, Value lhs,
239+
Value rhs, bool isDiv) {
240+
// Check if the divisor is a power of two constant.
241+
auto rhsConstantOp = rhs.getDefiningOp<hw::ConstantOp>();
242+
if (!rhsConstantOp)
243+
return failure();
244+
245+
APInt rhsValue = rhsConstantOp.getValue();
246+
if (!rhsValue.isPowerOf2())
247+
return failure();
248+
249+
Location loc = op->getLoc();
250+
251+
unsigned width = lhs.getType().getIntOrFloatBitWidth();
252+
unsigned bitPosition = rhsValue.ceilLogBase2();
253+
254+
if (isDiv) {
255+
// divu(x, 2^n) -> concat(0...0, extract(x, n, width-n))
256+
// This is equivalent to a right shift by n bits.
257+
258+
// Extract the upper bits (equivalent to right shift).
259+
Value upperBits = rewriter.createOrFold<comb::ExtractOp>(
260+
loc, lhs, bitPosition, width - bitPosition);
261+
262+
// Concatenate with zeros on the left.
263+
Value zeros =
264+
hw::ConstantOp::create(rewriter, loc, APInt::getZero(bitPosition));
265+
266+
// use replaceOpWithNewOpAndCopyNamehint?
267+
replaceOpAndCopyNamehint(
268+
rewriter, op,
269+
comb::ConcatOp::create(rewriter, loc,
270+
ArrayRef<Value>{zeros, upperBits}));
271+
return success();
272+
}
273+
274+
// modu(x, 2^n) -> concat(0...0, extract(x, 0, n))
275+
// This extracts the lower n bits (equivalent to bitwise AND with 2^n - 1).
276+
277+
// Extract the lower bits.
278+
Value lowerBits =
279+
rewriter.createOrFold<comb::ExtractOp>(loc, lhs, 0, bitPosition);
280+
281+
// Concatenate with zeros on the left.
282+
Value zeros = hw::ConstantOp::create(rewriter, loc,
283+
APInt::getZero(width - bitPosition));
284+
285+
replaceOpAndCopyNamehint(
286+
rewriter, op,
287+
comb::ConcatOp::create(rewriter, loc, ArrayRef<Value>{zeros, lowerBits}));
288+
return success();
289+
}
290+
291+
LogicalResult comb::convertDivUByPowerOfTwo(DivUOp divOp,
292+
mlir::PatternRewriter &rewriter) {
293+
return convertDivModUByPowerOfTwo(rewriter, divOp, divOp.getLhs(),
294+
divOp.getRhs(), /*isDiv=*/true);
295+
}
296+
297+
LogicalResult comb::convertModUByPowerOfTwo(ModUOp modOp,
298+
mlir::PatternRewriter &rewriter) {
299+
return convertDivModUByPowerOfTwo(rewriter, modOp, modOp.getLhs(),
300+
modOp.getRhs(), /*isDiv=*/false);
301+
}
302+
237303
//===----------------------------------------------------------------------===//
238304
// ICmpOp
239305
//===----------------------------------------------------------------------===//

lib/Dialect/Datapath/DatapathFolds.cpp

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,12 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "circt/Dialect/Comb/CombOps.h"
10+
#include "circt/Dialect/Datapath/DatapathDialect.h"
1011
#include "circt/Dialect/Datapath/DatapathOps.h"
1112
#include "circt/Dialect/HW/HWOps.h"
1213
#include "mlir/IR/Matchers.h"
1314
#include "mlir/IR/PatternMatch.h"
15+
#include "llvm/Support/Casting.h"
1416
#include "llvm/Support/KnownBits.h"
1517
#include <algorithm>
1618

@@ -117,6 +119,9 @@ struct FoldAddIntoCompress : public OpRewritePattern<comb::AddOp> {
117119
using OpRewritePattern::OpRewritePattern;
118120

119121
// add(compress(a,b,c),d) -> add(compress(a,b,c,d))
122+
// FIXME: This should be implemented as a canonicalization pattern for
123+
// compress op. Currently `hasDatapathOperand` flag prevents introducing
124+
// datapath operations from comb operations.
120125
LogicalResult matchAndRewrite(comb::AddOp addOp,
121126
PatternRewriter &rewriter) const override {
122127
// comb.add canonicalization patterns handle folding add operations
@@ -128,15 +133,23 @@ struct FoldAddIntoCompress : public OpRewritePattern<comb::AddOp> {
128133
llvm::SmallSetVector<Value, 8> processedCompressorResults;
129134
SmallVector<Value, 8> newCompressOperands;
130135
// Only construct compressor if can form a larger compressor than what
131-
// is currently an input of this add
132-
bool shouldFold = false;
136+
// is currently an input of this add. Also check that there is at least
137+
// one datapath operand.
138+
bool shouldFold = false, hasDatapathOperand = false;
133139

134140
for (Value operand : operands) {
135141

136142
// Skip if already processed this compressor
137143
if (processedCompressorResults.contains(operand))
138144
continue;
139145

146+
if (auto *op = operand.getDefiningOp()) {
147+
if (llvm::isa_and_nonnull<datapath::DatapathDialect>(
148+
op->getDialect())) {
149+
hasDatapathOperand = true;
150+
}
151+
}
152+
140153
// If the operand has multiple uses, we do not fold it into a compress
141154
// operation, so we treat it as a regular operand.
142155
if (!operand.hasOneUse()) {
@@ -174,7 +187,7 @@ struct FoldAddIntoCompress : public OpRewritePattern<comb::AddOp> {
174187

175188
// Only fold if we have constructed a larger compressor than what was
176189
// already there
177-
if (!shouldFold)
190+
if (!shouldFold || !hasDatapathOperand)
178191
return failure();
179192

180193
// Create a new CompressOp with all collected operands

lib/Dialect/Synth/Transforms/SynthesisPipeline.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,9 @@ void circt::synth::buildCombLoweringPipeline(
5959
circt::ConvertDatapathToCombOptions datapathOptions;
6060
datapathOptions.timingAware = options.timingAware;
6161
pm.addPass(createConvertDatapathToComb(datapathOptions));
62-
pm.addPass(createSimpleCanonicalizerPass());
6362
}
63+
pm.addPass(createCSEPass());
64+
pm.addPass(createSimpleCanonicalizerPass());
6465
// Partially legalize Comb, then run CSE and canonicalization.
6566
circt::ConvertCombToSynthOptions convOptions;
6667
addOpName<comb::AndOp, comb::OrOp, comb::XorOp, comb::MuxOp, comb::ICmpOp,

test/Dialect/Comb/canonicalization.mlir

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1240,6 +1240,71 @@ hw.module @moduloZeroDividend(in %arg0 : i32, out o1: i32, out o2: i32) {
12401240
hw.output %0, %1 : i32, i32
12411241
}
12421242

1243+
// CHECK-LABEL: hw.module @divuPowerOfTwo
1244+
hw.module @divuPowerOfTwo(in %arg0 : i8, out o1: i8, out o2: i8, out o3: i8, out o4: i8, out o5: i8) {
1245+
// divu(x, 2) -> concat(0, extract(x, 1, 7))
1246+
// CHECK: [[EXT1:%.+]] = comb.extract %arg0 from 1 : (i8) -> i7
1247+
// CHECK-NEXT: [[RES1:%.+]] = comb.concat %false, [[EXT1]] : i1, i7
1248+
%c2 = hw.constant 2 : i8
1249+
%0 = comb.divu bin %arg0, %c2 : i8
1250+
1251+
// divu(x, 4) -> concat(00, extract(x, 2, 6))
1252+
// CHECK-NEXT: [[EXT2:%.+]] = comb.extract %arg0 from 2 : (i8) -> i6
1253+
// CHECK-NEXT: [[RES2:%.+]] = comb.concat %c0_i2, [[EXT2]] : i2, i6
1254+
%c4 = hw.constant 4 : i8
1255+
%1 = comb.divu bin %arg0, %c4 : i8
1256+
1257+
// divu(x, 16) -> concat(0000, extract(x, 4, 4))
1258+
// CHECK-NEXT: [[EXT4:%.+]] = comb.extract %arg0 from 4 : (i8) -> i4
1259+
// CHECK-NEXT: [[RES3:%.+]] = comb.concat %c0_i4, [[EXT4]] : i4, i4
1260+
%c16 = hw.constant 16 : i8
1261+
%2 = comb.divu bin %arg0, %c16 : i8
1262+
1263+
// divu(x, 3) -> not canonicalized (not power of two)
1264+
// CHECK-NEXT: [[RES4:%.+]] = comb.divu bin %arg0, %c3_i8 : i8
1265+
%c3 = hw.constant 3 : i8
1266+
%3 = comb.divu bin %arg0, %c3 : i8
1267+
1268+
// Make sure canonicalization does not happen if there is no bin flag.
1269+
// CHECK-NEXT: [[RES5:%.+]] = comb.divu %arg0, %c2_i8 : i8
1270+
%4 = comb.divu %arg0, %c2 : i8
1271+
1272+
// CHECK: hw.output [[RES1]], [[RES2]], [[RES3]], [[RES4]], [[RES5]]
1273+
hw.output %0, %1, %2, %3, %4 : i8, i8, i8, i8, i8
1274+
}
1275+
1276+
// CHECK-LABEL: hw.module @moduPowerOfTwo
1277+
hw.module @moduPowerOfTwo(in %arg0 : i8, out o1: i8, out o2: i8, out o3: i8, out o4: i8, out o5: i8) {
1278+
// modu(x, 2) -> concat(0000000, extract(x, 0, 1))
1279+
// CHECK: [[EXT1:%.+]] = comb.extract %arg0 from 0 : (i8) -> i1
1280+
// CHECK-NEXT: [[RES1:%.+]] = comb.concat %c0_i7, [[EXT1]] : i7, i1
1281+
%c2 = hw.constant 2 : i8
1282+
%0 = comb.modu bin %arg0, %c2 : i8
1283+
1284+
// modu(x, 4) -> concat(000000, extract(x, 0, 2))
1285+
// CHECK-NEXT: [[EXT2:%.+]] = comb.extract %arg0 from 0 : (i8) -> i2
1286+
// CHECK-NEXT: [[RES2:%.+]] = comb.concat %c0_i6, [[EXT2]] : i6, i2
1287+
%c4 = hw.constant 4 : i8
1288+
%1 = comb.modu bin %arg0, %c4 : i8
1289+
1290+
// modu(x, 16) -> concat(0000, extract(x, 0, 4))
1291+
// CHECK-NEXT: [[EXT4:%.+]] = comb.extract %arg0 from 0 : (i8) -> i4
1292+
// CHECK-NEXT: [[RES3:%.+]] = comb.concat %c0_i4, [[EXT4]] : i4, i4
1293+
%c16 = hw.constant 16 : i8
1294+
%2 = comb.modu bin %arg0, %c16 : i8
1295+
1296+
// modu(x, 3) -> not canonicalized (not power of two)
1297+
// CHECK-NEXT: [[RES4:%.+]] = comb.modu bin %arg0, %c3_i8 : i8
1298+
%c3 = hw.constant 3 : i8
1299+
%3 = comb.modu bin %arg0, %c3 : i8
1300+
1301+
// Make sure canonicalization does not happen if there is no bin flag.
1302+
// CHECK-NEXT: [[RES5:%.+]] = comb.modu %arg0, %c2_i8 : i8
1303+
%4 = comb.modu %arg0, %c2 : i8
1304+
// CHECK: hw.output [[RES1]], [[RES2]], [[RES3]], [[RES4]], [[RES5]]
1305+
hw.output %0, %1, %2, %3, %4 : i8, i8, i8, i8, i8
1306+
}
1307+
12431308
// CHECK-LABEL: hw.module @orWithNegation
12441309
hw.module @orWithNegation(in %arg0 : i32, out o1: i32) {
12451310
// CHECK: [[ALLONES:%.*]] = hw.constant -1 : i32

test/Dialect/Datapath/canonicalization.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,3 +115,13 @@ hw.module @pos_partial_product_do_nothing(in %a : i4, in %b : i4, in %c : i4, ou
115115
%2:4 = datapath.partial_product %0, %1 : (i4, i4) -> (i4, i4, i4, i4)
116116
hw.output %2#0, %2#1, %2#2, %2#3 : i4, i4, i4, i4
117117
}
118+
119+
// CHECK-LABEL: @dont_introduce_compressor
120+
hw.module @dont_introduce_compressor(in %a : i4, in %b : i4, in %c: i4, out sum : i4) {
121+
// CHECK-NOT: datapath.compress
122+
// CHECK-NEXT: comb.add
123+
// CHECK-NEXT: hw.output
124+
%0:4 = datapath.partial_product %a, %b : (i4, i4) -> (i4, i4, i4, i4)
125+
%1 = comb.add %a, %b, %c : i4
126+
hw.output %1 : i4
127+
}

0 commit comments

Comments
 (0)