From 5a8381c0af4998017b363542a081c3cf56638e41 Mon Sep 17 00:00:00 2001 From: Hideto Ueno Date: Mon, 3 Nov 2025 13:33:44 -0800 Subject: [PATCH 1/2] [Comb] Canonicalize power-of-two unsigned div/mod. --- include/circt/Dialect/Comb/CombOps.h | 9 +++ include/circt/Dialect/Comb/Combinational.td | 4 +- lib/Conversion/CombToSynth/CombToSynth.cpp | 31 ++------- lib/Dialect/Comb/CombFolds.cpp | 15 +++++ lib/Dialect/Comb/CombOps.cpp | 66 +++++++++++++++++++ .../Synth/Transforms/SynthesisPipeline.cpp | 3 +- test/Dialect/Comb/canonicalization.mlir | 58 ++++++++++++++++ 7 files changed, 156 insertions(+), 30 deletions(-) diff --git a/include/circt/Dialect/Comb/CombOps.h b/include/circt/Dialect/Comb/CombOps.h index 77681f1c8a3d..ea2edfc240b4 100644 --- a/include/circt/Dialect/Comb/CombOps.h +++ b/include/circt/Dialect/Comb/CombOps.h @@ -90,6 +90,15 @@ Value createInject(OpBuilder &builder, Location loc, Value value, LogicalResult convertSubToAdd(comb::SubOp subOp, mlir::PatternRewriter &rewriter); +/// Convert unsigned division or modulo by a power of two. +/// For division: divu(x, 2^n) -> concat(0...0, extract(x, n, width-n)). +/// For modulo: modu(x, 2^n) -> concat(0...0, extract(x, 0, n)) +/// TODO: Support signed division and modulo. +LogicalResult convertDivUByPowerOfTwo(DivUOp divOp, + mlir::PatternRewriter &rewriter); +LogicalResult convertModUByPowerOfTwo(ModUOp modOp, + mlir::PatternRewriter &rewriter); + /// Enum for mux chain folding styles. enum MuxChainWithComparisonFoldingStyle { None, BalancedMuxTree, ArrayGet }; /// Mux chain folding that converts chains of muxes with index diff --git a/include/circt/Dialect/Comb/Combinational.td b/include/circt/Dialect/Comb/Combinational.td index 9362371925b2..2069dd3581e5 100644 --- a/include/circt/Dialect/Comb/Combinational.td +++ b/include/circt/Dialect/Comb/Combinational.td @@ -68,11 +68,11 @@ class UTVariadicOp traits = []> : def AddOp : UTVariadicOp<"add", [Commutative]>; def MulOp : UTVariadicOp<"mul", [Commutative]>; let hasFolder = true in { - def DivUOp : UTBinOp<"divu">; def DivSOp : UTBinOp<"divs">; - def ModUOp : UTBinOp<"modu">; def ModSOp : UTBinOp<"mods">; let hasCanonicalizeMethod = true in { + def DivUOp : UTBinOp<"divu">; + def ModUOp : UTBinOp<"modu">; def ShlOp : UTBinOp<"shl">; def ShrUOp : UTBinOp<"shru">; def ShrSOp : UTBinOp<"shrs">; diff --git a/lib/Conversion/CombToSynth/CombToSynth.cpp b/lib/Conversion/CombToSynth/CombToSynth.cpp index f2925266f221..31d6582f8b36 100644 --- a/lib/Conversion/CombToSynth/CombToSynth.cpp +++ b/lib/Conversion/CombToSynth/CombToSynth.cpp @@ -990,20 +990,8 @@ struct CombDivUOpConversion : DivModOpConversionBase { matchAndRewrite(DivUOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Check if the divisor is a power of two. - if (auto rhsConstantOp = adaptor.getRhs().getDefiningOp()) - if (rhsConstantOp.getValue().isPowerOf2()) { - // Extract upper bits. - size_t extractAmount = rhsConstantOp.getValue().ceilLogBase2(); - size_t width = op.getType().getIntOrFloatBitWidth(); - Value upperBits = rewriter.createOrFold( - op.getLoc(), adaptor.getLhs(), extractAmount, - width - extractAmount); - Value constZero = hw::ConstantOp::create(rewriter, op.getLoc(), - APInt::getZero(extractAmount)); - replaceOpWithNewOpAndCopyNamehint( - rewriter, op, op.getType(), ArrayRef{constZero, upperBits}); - return success(); - } + if (llvm::succeeded(comb::convertDivUByPowerOfTwo(op, rewriter))) + return success(); // When rhs is not power of two and the number of unknown bits are small, // create a mux tree that emulates all possible cases. @@ -1024,19 +1012,8 @@ struct CombModUOpConversion : DivModOpConversionBase { matchAndRewrite(ModUOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Check if the divisor is a power of two. - if (auto rhsConstantOp = adaptor.getRhs().getDefiningOp()) - if (rhsConstantOp.getValue().isPowerOf2()) { - // Extract lower bits. - size_t extractAmount = rhsConstantOp.getValue().ceilLogBase2(); - size_t width = op.getType().getIntOrFloatBitWidth(); - Value lowerBits = rewriter.createOrFold( - op.getLoc(), adaptor.getLhs(), 0, extractAmount); - Value constZero = hw::ConstantOp::create( - rewriter, op.getLoc(), APInt::getZero(width - extractAmount)); - replaceOpWithNewOpAndCopyNamehint( - rewriter, op, op.getType(), ArrayRef{constZero, lowerBits}); - return success(); - } + if (llvm::succeeded(comb::convertModUByPowerOfTwo(op, rewriter))) + return success(); // When rhs is not power of two and the number of unknown bits are small, // create a mux tree that emulates all possible cases. diff --git a/lib/Dialect/Comb/CombFolds.cpp b/lib/Dialect/Comb/CombFolds.cpp index 4f66b0657c89..8df08707e41a 100644 --- a/lib/Dialect/Comb/CombFolds.cpp +++ b/lib/Dialect/Comb/CombFolds.cpp @@ -1697,6 +1697,21 @@ OpFoldResult ModSOp::fold(FoldAdaptor adaptor) { return {}; return foldMod(*this, adaptor.getOperands()); } + +LogicalResult DivUOp::canonicalize(DivUOp op, PatternRewriter &rewriter) { + if (isOpTriviallyRecursive(op)) + return failure(); + + return convertDivUByPowerOfTwo(op, rewriter); +} + +LogicalResult ModUOp::canonicalize(ModUOp op, PatternRewriter &rewriter) { + if (isOpTriviallyRecursive(op)) + return failure(); + + return convertModUByPowerOfTwo(op, rewriter); +} + //===----------------------------------------------------------------------===// // ConcatOp //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/Comb/CombOps.cpp b/lib/Dialect/Comb/CombOps.cpp index 3fd5a9547b46..e0638797bfcd 100644 --- a/lib/Dialect/Comb/CombOps.cpp +++ b/lib/Dialect/Comb/CombOps.cpp @@ -234,6 +234,72 @@ llvm::LogicalResult comb::convertSubToAdd(comb::SubOp subOp, return success(); } +static llvm::LogicalResult convertDivModUByPowerOfTwo(PatternRewriter &rewriter, + Operation *op, Value lhs, + Value rhs, bool isDiv) { + // Check if the divisor is a power of two constant. + auto rhsConstantOp = rhs.getDefiningOp(); + if (!rhsConstantOp) + return failure(); + + APInt rhsValue = rhsConstantOp.getValue(); + if (!rhsValue.isPowerOf2()) + return failure(); + + Location loc = op->getLoc(); + + unsigned width = lhs.getType().getIntOrFloatBitWidth(); + unsigned bitPosition = rhsValue.ceilLogBase2(); + + if (isDiv) { + // divu(x, 2^n) -> concat(0...0, extract(x, n, width-n)) + // This is equivalent to a right shift by n bits. + + // Extract the upper bits (equivalent to right shift). + Value upperBits = rewriter.createOrFold( + loc, lhs, bitPosition, width - bitPosition); + + // Concatenate with zeros on the left. + Value zeros = + hw::ConstantOp::create(rewriter, loc, APInt::getZero(bitPosition)); + + // use replaceOpWithNewOpAndCopyNamehint? + replaceOpAndCopyNamehint( + rewriter, op, + comb::ConcatOp::create(rewriter, loc, + ArrayRef{zeros, upperBits})); + return success(); + } + + // modu(x, 2^n) -> concat(0...0, extract(x, 0, n)) + // This extracts the lower n bits (equivalent to bitwise AND with 2^n - 1). + + // Extract the lower bits. + Value lowerBits = + rewriter.createOrFold(loc, lhs, 0, bitPosition); + + // Concatenate with zeros on the left. + Value zeros = hw::ConstantOp::create(rewriter, loc, + APInt::getZero(width - bitPosition)); + + replaceOpAndCopyNamehint( + rewriter, op, + comb::ConcatOp::create(rewriter, loc, ArrayRef{zeros, lowerBits})); + return success(); +} + +LogicalResult comb::convertDivUByPowerOfTwo(DivUOp divOp, + mlir::PatternRewriter &rewriter) { + return convertDivModUByPowerOfTwo(rewriter, divOp, divOp.getLhs(), + divOp.getRhs(), /*isDiv=*/true); +} + +LogicalResult comb::convertModUByPowerOfTwo(ModUOp modOp, + mlir::PatternRewriter &rewriter) { + return convertDivModUByPowerOfTwo(rewriter, modOp, modOp.getLhs(), + modOp.getRhs(), /*isDiv=*/false); +} + //===----------------------------------------------------------------------===// // ICmpOp //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/Synth/Transforms/SynthesisPipeline.cpp b/lib/Dialect/Synth/Transforms/SynthesisPipeline.cpp index 3caced28a6ac..70a16959b59c 100644 --- a/lib/Dialect/Synth/Transforms/SynthesisPipeline.cpp +++ b/lib/Dialect/Synth/Transforms/SynthesisPipeline.cpp @@ -59,8 +59,9 @@ void circt::synth::buildCombLoweringPipeline( circt::ConvertDatapathToCombOptions datapathOptions; datapathOptions.timingAware = options.timingAware; pm.addPass(createConvertDatapathToComb(datapathOptions)); - pm.addPass(createSimpleCanonicalizerPass()); } + pm.addPass(createCSEPass()); + pm.addPass(createSimpleCanonicalizerPass()); // Partially legalize Comb, then run CSE and canonicalization. circt::ConvertCombToSynthOptions convOptions; addOpName concat(0, extract(x, 1, 7)) + // CHECK: [[EXT1:%.+]] = comb.extract %arg0 from 1 : (i8) -> i7 + // CHECK-NEXT: [[RES1:%.+]] = comb.concat %false, [[EXT1]] : i1, i7 + %c2 = hw.constant 2 : i8 + %0 = comb.divu %arg0, %c2 : i8 + + // divu(x, 4) -> concat(00, extract(x, 2, 6)) + // CHECK-NEXT: [[EXT2:%.+]] = comb.extract %arg0 from 2 : (i8) -> i6 + // CHECK-NEXT: [[RES2:%.+]] = comb.concat %c0_i2, [[EXT2]] : i2, i6 + %c4 = hw.constant 4 : i8 + %1 = comb.divu %arg0, %c4 : i8 + + // divu(x, 16) -> concat(0000, extract(x, 4, 4)) + // CHECK-NEXT: [[EXT4:%.+]] = comb.extract %arg0 from 4 : (i8) -> i4 + // CHECK-NEXT: [[RES4:%.+]] = comb.concat %c0_i4, [[EXT4]] : i4, i4 + %c16 = hw.constant 16 : i8 + %2 = comb.divu %arg0, %c16 : i8 + + // divu(x, 3) -> not canonicalized (not power of two) + // CHECK-NEXT: [[RES3:%.+]] = comb.divu %arg0, %c3_i8 : i8 + %c3 = hw.constant 3 : i8 + %3 = comb.divu %arg0, %c3 : i8 + + // CHECK: hw.output [[RES1]], [[RES2]], [[RES4]], [[RES3]] + hw.output %0, %1, %2, %3 : i8, i8, i8, i8 +} + +// CHECK-LABEL: hw.module @moduPowerOfTwo +hw.module @moduPowerOfTwo(in %arg0 : i8, out o1: i8, out o2: i8, out o3: i8, out o4: i8) { + // modu(x, 2) -> concat(0000000, extract(x, 0, 1)) + // CHECK: [[EXT1:%.+]] = comb.extract %arg0 from 0 : (i8) -> i1 + // CHECK-NEXT: [[RES1:%.+]] = comb.concat %c0_i7, [[EXT1]] : i7, i1 + %c2 = hw.constant 2 : i8 + %0 = comb.modu %arg0, %c2 : i8 + + // modu(x, 4) -> concat(000000, extract(x, 0, 2)) + // CHECK-NEXT: [[EXT2:%.+]] = comb.extract %arg0 from 0 : (i8) -> i2 + // CHECK-NEXT: [[RES2:%.+]] = comb.concat %c0_i6, [[EXT2]] : i6, i2 + %c4 = hw.constant 4 : i8 + %1 = comb.modu %arg0, %c4 : i8 + + // modu(x, 16) -> concat(0000, extract(x, 0, 4)) + // CHECK-NEXT: [[EXT4:%.+]] = comb.extract %arg0 from 0 : (i8) -> i4 + // CHECK-NEXT: [[RES4:%.+]] = comb.concat %c0_i4, [[EXT4]] : i4, i4 + %c16 = hw.constant 16 : i8 + %2 = comb.modu %arg0, %c16 : i8 + + // modu(x, 3) -> not canonicalized (not power of two) + // CHECK-NEXT: [[RES3:%.+]] = comb.modu %arg0, %c3_i8 : i8 + %c3 = hw.constant 3 : i8 + %3 = comb.modu %arg0, %c3 : i8 + + // CHECK: hw.output [[RES1]], [[RES2]], [[RES4]], [[RES3]] + hw.output %0, %1, %2, %3 : i8, i8, i8, i8 +} + // CHECK-LABEL: hw.module @orWithNegation hw.module @orWithNegation(in %arg0 : i32, out o1: i32) { // CHECK: [[ALLONES:%.*]] = hw.constant -1 : i32 From f45dc7f4f0c28fd3f28d3f55f14852a951613843 Mon Sep 17 00:00:00 2001 From: Hideto Ueno Date: Mon, 3 Nov 2025 15:38:48 -0800 Subject: [PATCH 2/2] Add bin flag --- lib/Dialect/Comb/CombFolds.cpp | 5 +-- lib/Dialect/Datapath/DatapathFolds.cpp | 19 +++++++-- test/Dialect/Comb/canonicalization.mlir | 43 ++++++++++++--------- test/Dialect/Datapath/canonicalization.mlir | 10 +++++ 4 files changed, 53 insertions(+), 24 deletions(-) diff --git a/lib/Dialect/Comb/CombFolds.cpp b/lib/Dialect/Comb/CombFolds.cpp index 8df08707e41a..8244bb62c935 100644 --- a/lib/Dialect/Comb/CombFolds.cpp +++ b/lib/Dialect/Comb/CombFolds.cpp @@ -1699,14 +1699,13 @@ OpFoldResult ModSOp::fold(FoldAdaptor adaptor) { } LogicalResult DivUOp::canonicalize(DivUOp op, PatternRewriter &rewriter) { - if (isOpTriviallyRecursive(op)) + if (isOpTriviallyRecursive(op) || !op.getTwoState()) return failure(); - return convertDivUByPowerOfTwo(op, rewriter); } LogicalResult ModUOp::canonicalize(ModUOp op, PatternRewriter &rewriter) { - if (isOpTriviallyRecursive(op)) + if (isOpTriviallyRecursive(op) || !op.getTwoState()) return failure(); return convertModUByPowerOfTwo(op, rewriter); diff --git a/lib/Dialect/Datapath/DatapathFolds.cpp b/lib/Dialect/Datapath/DatapathFolds.cpp index 55bf98ddded9..51b7db95a920 100644 --- a/lib/Dialect/Datapath/DatapathFolds.cpp +++ b/lib/Dialect/Datapath/DatapathFolds.cpp @@ -7,10 +7,12 @@ //===----------------------------------------------------------------------===// #include "circt/Dialect/Comb/CombOps.h" +#include "circt/Dialect/Datapath/DatapathDialect.h" #include "circt/Dialect/Datapath/DatapathOps.h" #include "circt/Dialect/HW/HWOps.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/KnownBits.h" #include @@ -117,6 +119,9 @@ struct FoldAddIntoCompress : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; // add(compress(a,b,c),d) -> add(compress(a,b,c,d)) + // FIXME: This should be implemented as a canonicalization pattern for + // compress op. Currently `hasDatapathOperand` flag prevents introducing + // datapath operations from comb operations. LogicalResult matchAndRewrite(comb::AddOp addOp, PatternRewriter &rewriter) const override { // comb.add canonicalization patterns handle folding add operations @@ -128,8 +133,9 @@ struct FoldAddIntoCompress : public OpRewritePattern { llvm::SmallSetVector processedCompressorResults; SmallVector newCompressOperands; // Only construct compressor if can form a larger compressor than what - // is currently an input of this add - bool shouldFold = false; + // is currently an input of this add. Also check that there is at least + // one datapath operand. + bool shouldFold = false, hasDatapathOperand = false; for (Value operand : operands) { @@ -137,6 +143,13 @@ struct FoldAddIntoCompress : public OpRewritePattern { if (processedCompressorResults.contains(operand)) continue; + if (auto *op = operand.getDefiningOp()) { + if (llvm::isa_and_nonnull( + op->getDialect())) { + hasDatapathOperand = true; + } + } + // If the operand has multiple uses, we do not fold it into a compress // operation, so we treat it as a regular operand. if (!operand.hasOneUse()) { @@ -174,7 +187,7 @@ struct FoldAddIntoCompress : public OpRewritePattern { // Only fold if we have constructed a larger compressor than what was // already there - if (!shouldFold) + if (!shouldFold || !hasDatapathOperand) return failure(); // Create a new CompressOp with all collected operands diff --git a/test/Dialect/Comb/canonicalization.mlir b/test/Dialect/Comb/canonicalization.mlir index 15c1912335b8..8ff6d883a3da 100644 --- a/test/Dialect/Comb/canonicalization.mlir +++ b/test/Dialect/Comb/canonicalization.mlir @@ -1241,61 +1241,68 @@ hw.module @moduloZeroDividend(in %arg0 : i32, out o1: i32, out o2: i32) { } // CHECK-LABEL: hw.module @divuPowerOfTwo -hw.module @divuPowerOfTwo(in %arg0 : i8, out o1: i8, out o2: i8, out o3: i8, out o4: i8) { +hw.module @divuPowerOfTwo(in %arg0 : i8, out o1: i8, out o2: i8, out o3: i8, out o4: i8, out o5: i8) { // divu(x, 2) -> concat(0, extract(x, 1, 7)) // CHECK: [[EXT1:%.+]] = comb.extract %arg0 from 1 : (i8) -> i7 // CHECK-NEXT: [[RES1:%.+]] = comb.concat %false, [[EXT1]] : i1, i7 %c2 = hw.constant 2 : i8 - %0 = comb.divu %arg0, %c2 : i8 + %0 = comb.divu bin %arg0, %c2 : i8 // divu(x, 4) -> concat(00, extract(x, 2, 6)) // CHECK-NEXT: [[EXT2:%.+]] = comb.extract %arg0 from 2 : (i8) -> i6 // CHECK-NEXT: [[RES2:%.+]] = comb.concat %c0_i2, [[EXT2]] : i2, i6 %c4 = hw.constant 4 : i8 - %1 = comb.divu %arg0, %c4 : i8 + %1 = comb.divu bin %arg0, %c4 : i8 // divu(x, 16) -> concat(0000, extract(x, 4, 4)) // CHECK-NEXT: [[EXT4:%.+]] = comb.extract %arg0 from 4 : (i8) -> i4 - // CHECK-NEXT: [[RES4:%.+]] = comb.concat %c0_i4, [[EXT4]] : i4, i4 + // CHECK-NEXT: [[RES3:%.+]] = comb.concat %c0_i4, [[EXT4]] : i4, i4 %c16 = hw.constant 16 : i8 - %2 = comb.divu %arg0, %c16 : i8 + %2 = comb.divu bin %arg0, %c16 : i8 // divu(x, 3) -> not canonicalized (not power of two) - // CHECK-NEXT: [[RES3:%.+]] = comb.divu %arg0, %c3_i8 : i8 + // CHECK-NEXT: [[RES4:%.+]] = comb.divu bin %arg0, %c3_i8 : i8 %c3 = hw.constant 3 : i8 - %3 = comb.divu %arg0, %c3 : i8 + %3 = comb.divu bin %arg0, %c3 : i8 - // CHECK: hw.output [[RES1]], [[RES2]], [[RES4]], [[RES3]] - hw.output %0, %1, %2, %3 : i8, i8, i8, i8 + // Make sure canonicalization does not happen if there is no bin flag. + // CHECK-NEXT: [[RES5:%.+]] = comb.divu %arg0, %c2_i8 : i8 + %4 = comb.divu %arg0, %c2 : i8 + + // CHECK: hw.output [[RES1]], [[RES2]], [[RES3]], [[RES4]], [[RES5]] + hw.output %0, %1, %2, %3, %4 : i8, i8, i8, i8, i8 } // CHECK-LABEL: hw.module @moduPowerOfTwo -hw.module @moduPowerOfTwo(in %arg0 : i8, out o1: i8, out o2: i8, out o3: i8, out o4: i8) { +hw.module @moduPowerOfTwo(in %arg0 : i8, out o1: i8, out o2: i8, out o3: i8, out o4: i8, out o5: i8) { // modu(x, 2) -> concat(0000000, extract(x, 0, 1)) // CHECK: [[EXT1:%.+]] = comb.extract %arg0 from 0 : (i8) -> i1 // CHECK-NEXT: [[RES1:%.+]] = comb.concat %c0_i7, [[EXT1]] : i7, i1 %c2 = hw.constant 2 : i8 - %0 = comb.modu %arg0, %c2 : i8 + %0 = comb.modu bin %arg0, %c2 : i8 // modu(x, 4) -> concat(000000, extract(x, 0, 2)) // CHECK-NEXT: [[EXT2:%.+]] = comb.extract %arg0 from 0 : (i8) -> i2 // CHECK-NEXT: [[RES2:%.+]] = comb.concat %c0_i6, [[EXT2]] : i6, i2 %c4 = hw.constant 4 : i8 - %1 = comb.modu %arg0, %c4 : i8 + %1 = comb.modu bin %arg0, %c4 : i8 // modu(x, 16) -> concat(0000, extract(x, 0, 4)) // CHECK-NEXT: [[EXT4:%.+]] = comb.extract %arg0 from 0 : (i8) -> i4 - // CHECK-NEXT: [[RES4:%.+]] = comb.concat %c0_i4, [[EXT4]] : i4, i4 + // CHECK-NEXT: [[RES3:%.+]] = comb.concat %c0_i4, [[EXT4]] : i4, i4 %c16 = hw.constant 16 : i8 - %2 = comb.modu %arg0, %c16 : i8 + %2 = comb.modu bin %arg0, %c16 : i8 // modu(x, 3) -> not canonicalized (not power of two) - // CHECK-NEXT: [[RES3:%.+]] = comb.modu %arg0, %c3_i8 : i8 + // CHECK-NEXT: [[RES4:%.+]] = comb.modu bin %arg0, %c3_i8 : i8 %c3 = hw.constant 3 : i8 - %3 = comb.modu %arg0, %c3 : i8 + %3 = comb.modu bin %arg0, %c3 : i8 - // CHECK: hw.output [[RES1]], [[RES2]], [[RES4]], [[RES3]] - hw.output %0, %1, %2, %3 : i8, i8, i8, i8 + // Make sure canonicalization does not happen if there is no bin flag. + // CHECK-NEXT: [[RES5:%.+]] = comb.modu %arg0, %c2_i8 : i8 + %4 = comb.modu %arg0, %c2 : i8 + // CHECK: hw.output [[RES1]], [[RES2]], [[RES3]], [[RES4]], [[RES5]] + hw.output %0, %1, %2, %3, %4 : i8, i8, i8, i8, i8 } // CHECK-LABEL: hw.module @orWithNegation diff --git a/test/Dialect/Datapath/canonicalization.mlir b/test/Dialect/Datapath/canonicalization.mlir index b25534446b67..20160bbfa9aa 100644 --- a/test/Dialect/Datapath/canonicalization.mlir +++ b/test/Dialect/Datapath/canonicalization.mlir @@ -115,3 +115,13 @@ hw.module @pos_partial_product_do_nothing(in %a : i4, in %b : i4, in %c : i4, ou %2:4 = datapath.partial_product %0, %1 : (i4, i4) -> (i4, i4, i4, i4) hw.output %2#0, %2#1, %2#2, %2#3 : i4, i4, i4, i4 } + +// CHECK-LABEL: @dont_introduce_compressor +hw.module @dont_introduce_compressor(in %a : i4, in %b : i4, in %c: i4, out sum : i4) { + // CHECK-NOT: datapath.compress + // CHECK-NEXT: comb.add + // CHECK-NEXT: hw.output + %0:4 = datapath.partial_product %a, %b : (i4, i4) -> (i4, i4, i4, i4) + %1 = comb.add %a, %b, %c : i4 + hw.output %1 : i4 +}