Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions include/circt/Dialect/Comb/CombOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,15 @@ Value createInject(OpBuilder &builder, Location loc, Value value,
LogicalResult convertSubToAdd(comb::SubOp subOp,
mlir::PatternRewriter &rewriter);

/// Convert unsigned division or modulo by a power of two.
/// For division: divu(x, 2^n) -> concat(0...0, extract(x, n, width-n)).
/// For modulo: modu(x, 2^n) -> concat(0...0, extract(x, 0, n))
/// TODO: Support signed division and modulo.
LogicalResult convertDivUByPowerOfTwo(DivUOp divOp,
mlir::PatternRewriter &rewriter);
LogicalResult convertModUByPowerOfTwo(ModUOp modOp,
mlir::PatternRewriter &rewriter);

/// Enum for mux chain folding styles.
enum MuxChainWithComparisonFoldingStyle { None, BalancedMuxTree, ArrayGet };
/// Mux chain folding that converts chains of muxes with index
Expand Down
4 changes: 2 additions & 2 deletions include/circt/Dialect/Comb/Combinational.td
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,11 @@ class UTVariadicOp<string mnemonic, list<Trait> traits = []> :
def AddOp : UTVariadicOp<"add", [Commutative]>;
def MulOp : UTVariadicOp<"mul", [Commutative]>;
let hasFolder = true in {
def DivUOp : UTBinOp<"divu">;
def DivSOp : UTBinOp<"divs">;
def ModUOp : UTBinOp<"modu">;
def ModSOp : UTBinOp<"mods">;
let hasCanonicalizeMethod = true in {
def DivUOp : UTBinOp<"divu">;
def ModUOp : UTBinOp<"modu">;
def ShlOp : UTBinOp<"shl">;
def ShrUOp : UTBinOp<"shru">;
def ShrSOp : UTBinOp<"shrs">;
Expand Down
31 changes: 4 additions & 27 deletions lib/Conversion/CombToSynth/CombToSynth.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -990,20 +990,8 @@ struct CombDivUOpConversion : DivModOpConversionBase<DivUOp> {
matchAndRewrite(DivUOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Check if the divisor is a power of two.
if (auto rhsConstantOp = adaptor.getRhs().getDefiningOp<hw::ConstantOp>())
if (rhsConstantOp.getValue().isPowerOf2()) {
// Extract upper bits.
size_t extractAmount = rhsConstantOp.getValue().ceilLogBase2();
size_t width = op.getType().getIntOrFloatBitWidth();
Value upperBits = rewriter.createOrFold<comb::ExtractOp>(
op.getLoc(), adaptor.getLhs(), extractAmount,
width - extractAmount);
Value constZero = hw::ConstantOp::create(rewriter, op.getLoc(),
APInt::getZero(extractAmount));
replaceOpWithNewOpAndCopyNamehint<comb::ConcatOp>(
rewriter, op, op.getType(), ArrayRef<Value>{constZero, upperBits});
return success();
}
Comment on lines -993 to -1006
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice to have this in the canonicalizer! 🥳

if (llvm::succeeded(comb::convertDivUByPowerOfTwo(op, rewriter)))
return success();

// When rhs is not power of two and the number of unknown bits are small,
// create a mux tree that emulates all possible cases.
Expand All @@ -1024,19 +1012,8 @@ struct CombModUOpConversion : DivModOpConversionBase<ModUOp> {
matchAndRewrite(ModUOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Check if the divisor is a power of two.
if (auto rhsConstantOp = adaptor.getRhs().getDefiningOp<hw::ConstantOp>())
if (rhsConstantOp.getValue().isPowerOf2()) {
// Extract lower bits.
size_t extractAmount = rhsConstantOp.getValue().ceilLogBase2();
size_t width = op.getType().getIntOrFloatBitWidth();
Value lowerBits = rewriter.createOrFold<comb::ExtractOp>(
op.getLoc(), adaptor.getLhs(), 0, extractAmount);
Value constZero = hw::ConstantOp::create(
rewriter, op.getLoc(), APInt::getZero(width - extractAmount));
replaceOpWithNewOpAndCopyNamehint<comb::ConcatOp>(
rewriter, op, op.getType(), ArrayRef<Value>{constZero, lowerBits});
return success();
}
if (llvm::succeeded(comb::convertModUByPowerOfTwo(op, rewriter)))
return success();

// When rhs is not power of two and the number of unknown bits are small,
// create a mux tree that emulates all possible cases.
Expand Down
14 changes: 14 additions & 0 deletions lib/Dialect/Comb/CombFolds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1697,6 +1697,20 @@ OpFoldResult ModSOp::fold(FoldAdaptor adaptor) {
return {};
return foldMod<ModSOp, /*isSigned=*/true>(*this, adaptor.getOperands());
}

LogicalResult DivUOp::canonicalize(DivUOp op, PatternRewriter &rewriter) {
if (isOpTriviallyRecursive(op) || !op.getTwoState())
return failure();
return convertDivUByPowerOfTwo(op, rewriter);
}

LogicalResult ModUOp::canonicalize(ModUOp op, PatternRewriter &rewriter) {
if (isOpTriviallyRecursive(op) || !op.getTwoState())
return failure();

return convertModUByPowerOfTwo(op, rewriter);
}

//===----------------------------------------------------------------------===//
// ConcatOp
//===----------------------------------------------------------------------===//
Expand Down
66 changes: 66 additions & 0 deletions lib/Dialect/Comb/CombOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,72 @@ llvm::LogicalResult comb::convertSubToAdd(comb::SubOp subOp,
return success();
}

static llvm::LogicalResult convertDivModUByPowerOfTwo(PatternRewriter &rewriter,
Operation *op, Value lhs,
Value rhs, bool isDiv) {
// Check if the divisor is a power of two constant.
auto rhsConstantOp = rhs.getDefiningOp<hw::ConstantOp>();
if (!rhsConstantOp)
return failure();

APInt rhsValue = rhsConstantOp.getValue();
if (!rhsValue.isPowerOf2())
return failure();

Location loc = op->getLoc();

unsigned width = lhs.getType().getIntOrFloatBitWidth();
unsigned bitPosition = rhsValue.ceilLogBase2();

if (isDiv) {
// divu(x, 2^n) -> concat(0...0, extract(x, n, width-n))
// This is equivalent to a right shift by n bits.

// Extract the upper bits (equivalent to right shift).
Value upperBits = rewriter.createOrFold<comb::ExtractOp>(
loc, lhs, bitPosition, width - bitPosition);

// Concatenate with zeros on the left.
Value zeros =
hw::ConstantOp::create(rewriter, loc, APInt::getZero(bitPosition));

// use replaceOpWithNewOpAndCopyNamehint?
replaceOpAndCopyNamehint(
rewriter, op,
comb::ConcatOp::create(rewriter, loc,
ArrayRef<Value>{zeros, upperBits}));
return success();
}

// modu(x, 2^n) -> concat(0...0, extract(x, 0, n))
// This extracts the lower n bits (equivalent to bitwise AND with 2^n - 1).

// Extract the lower bits.
Value lowerBits =
rewriter.createOrFold<comb::ExtractOp>(loc, lhs, 0, bitPosition);

// Concatenate with zeros on the left.
Value zeros = hw::ConstantOp::create(rewriter, loc,
APInt::getZero(width - bitPosition));

replaceOpAndCopyNamehint(
rewriter, op,
comb::ConcatOp::create(rewriter, loc, ArrayRef<Value>{zeros, lowerBits}));
return success();
}

LogicalResult comb::convertDivUByPowerOfTwo(DivUOp divOp,
mlir::PatternRewriter &rewriter) {
return convertDivModUByPowerOfTwo(rewriter, divOp, divOp.getLhs(),
divOp.getRhs(), /*isDiv=*/true);
}

LogicalResult comb::convertModUByPowerOfTwo(ModUOp modOp,
mlir::PatternRewriter &rewriter) {
return convertDivModUByPowerOfTwo(rewriter, modOp, modOp.getLhs(),
modOp.getRhs(), /*isDiv=*/false);
}

//===----------------------------------------------------------------------===//
// ICmpOp
//===----------------------------------------------------------------------===//
Expand Down
19 changes: 16 additions & 3 deletions lib/Dialect/Datapath/DatapathFolds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
//===----------------------------------------------------------------------===//

#include "circt/Dialect/Comb/CombOps.h"
#include "circt/Dialect/Datapath/DatapathDialect.h"
#include "circt/Dialect/Datapath/DatapathOps.h"
#include "circt/Dialect/HW/HWOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/KnownBits.h"
#include <algorithm>

Expand Down Expand Up @@ -117,6 +119,9 @@ struct FoldAddIntoCompress : public OpRewritePattern<comb::AddOp> {
using OpRewritePattern::OpRewritePattern;

// add(compress(a,b,c),d) -> add(compress(a,b,c,d))
// FIXME: This should be implemented as a canonicalization pattern for
// compress op. Currently `hasDatapathOperand` flag prevents introducing
// datapath operations from comb operations.
LogicalResult matchAndRewrite(comb::AddOp addOp,
PatternRewriter &rewriter) const override {
// comb.add canonicalization patterns handle folding add operations
Expand All @@ -128,15 +133,23 @@ struct FoldAddIntoCompress : public OpRewritePattern<comb::AddOp> {
llvm::SmallSetVector<Value, 8> processedCompressorResults;
SmallVector<Value, 8> newCompressOperands;
// Only construct compressor if can form a larger compressor than what
// is currently an input of this add
bool shouldFold = false;
// is currently an input of this add. Also check that there is at least
// one datapath operand.
bool shouldFold = false, hasDatapathOperand = false;

for (Value operand : operands) {

// Skip if already processed this compressor
if (processedCompressorResults.contains(operand))
continue;

if (auto *op = operand.getDefiningOp()) {
if (llvm::isa_and_nonnull<datapath::DatapathDialect>(
op->getDialect())) {
hasDatapathOperand = true;
}
}

// If the operand has multiple uses, we do not fold it into a compress
// operation, so we treat it as a regular operand.
if (!operand.hasOneUse()) {
Expand Down Expand Up @@ -174,7 +187,7 @@ struct FoldAddIntoCompress : public OpRewritePattern<comb::AddOp> {

// Only fold if we have constructed a larger compressor than what was
// already there
if (!shouldFold)
if (!shouldFold || !hasDatapathOperand)
return failure();

// Create a new CompressOp with all collected operands
Expand Down
3 changes: 2 additions & 1 deletion lib/Dialect/Synth/Transforms/SynthesisPipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,9 @@ void circt::synth::buildCombLoweringPipeline(
circt::ConvertDatapathToCombOptions datapathOptions;
datapathOptions.timingAware = options.timingAware;
pm.addPass(createConvertDatapathToComb(datapathOptions));
pm.addPass(createSimpleCanonicalizerPass());
}
pm.addPass(createCSEPass());
pm.addPass(createSimpleCanonicalizerPass());
Comment on lines +63 to +64
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is for canonicalizing divu/modu before running CombToSynth.

// Partially legalize Comb, then run CSE and canonicalization.
circt::ConvertCombToSynthOptions convOptions;
addOpName<comb::AndOp, comb::OrOp, comb::XorOp, comb::MuxOp, comb::ICmpOp,
Expand Down
65 changes: 65 additions & 0 deletions test/Dialect/Comb/canonicalization.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1240,6 +1240,71 @@ hw.module @moduloZeroDividend(in %arg0 : i32, out o1: i32, out o2: i32) {
hw.output %0, %1 : i32, i32
}

// CHECK-LABEL: hw.module @divuPowerOfTwo
hw.module @divuPowerOfTwo(in %arg0 : i8, out o1: i8, out o2: i8, out o3: i8, out o4: i8, out o5: i8) {
// divu(x, 2) -> concat(0, extract(x, 1, 7))
// CHECK: [[EXT1:%.+]] = comb.extract %arg0 from 1 : (i8) -> i7
// CHECK-NEXT: [[RES1:%.+]] = comb.concat %false, [[EXT1]] : i1, i7
%c2 = hw.constant 2 : i8
%0 = comb.divu bin %arg0, %c2 : i8

// divu(x, 4) -> concat(00, extract(x, 2, 6))
// CHECK-NEXT: [[EXT2:%.+]] = comb.extract %arg0 from 2 : (i8) -> i6
// CHECK-NEXT: [[RES2:%.+]] = comb.concat %c0_i2, [[EXT2]] : i2, i6
%c4 = hw.constant 4 : i8
%1 = comb.divu bin %arg0, %c4 : i8

// divu(x, 16) -> concat(0000, extract(x, 4, 4))
// CHECK-NEXT: [[EXT4:%.+]] = comb.extract %arg0 from 4 : (i8) -> i4
// CHECK-NEXT: [[RES3:%.+]] = comb.concat %c0_i4, [[EXT4]] : i4, i4
%c16 = hw.constant 16 : i8
%2 = comb.divu bin %arg0, %c16 : i8

// divu(x, 3) -> not canonicalized (not power of two)
// CHECK-NEXT: [[RES4:%.+]] = comb.divu bin %arg0, %c3_i8 : i8
%c3 = hw.constant 3 : i8
%3 = comb.divu bin %arg0, %c3 : i8

// Make sure canonicalization does not happen if there is no bin flag.
// CHECK-NEXT: [[RES5:%.+]] = comb.divu %arg0, %c2_i8 : i8
%4 = comb.divu %arg0, %c2 : i8

// CHECK: hw.output [[RES1]], [[RES2]], [[RES3]], [[RES4]], [[RES5]]
hw.output %0, %1, %2, %3, %4 : i8, i8, i8, i8, i8
}

// CHECK-LABEL: hw.module @moduPowerOfTwo
hw.module @moduPowerOfTwo(in %arg0 : i8, out o1: i8, out o2: i8, out o3: i8, out o4: i8, out o5: i8) {
// modu(x, 2) -> concat(0000000, extract(x, 0, 1))
// CHECK: [[EXT1:%.+]] = comb.extract %arg0 from 0 : (i8) -> i1
// CHECK-NEXT: [[RES1:%.+]] = comb.concat %c0_i7, [[EXT1]] : i7, i1
%c2 = hw.constant 2 : i8
%0 = comb.modu bin %arg0, %c2 : i8

// modu(x, 4) -> concat(000000, extract(x, 0, 2))
// CHECK-NEXT: [[EXT2:%.+]] = comb.extract %arg0 from 0 : (i8) -> i2
// CHECK-NEXT: [[RES2:%.+]] = comb.concat %c0_i6, [[EXT2]] : i6, i2
%c4 = hw.constant 4 : i8
%1 = comb.modu bin %arg0, %c4 : i8

// modu(x, 16) -> concat(0000, extract(x, 0, 4))
// CHECK-NEXT: [[EXT4:%.+]] = comb.extract %arg0 from 0 : (i8) -> i4
// CHECK-NEXT: [[RES3:%.+]] = comb.concat %c0_i4, [[EXT4]] : i4, i4
%c16 = hw.constant 16 : i8
%2 = comb.modu bin %arg0, %c16 : i8

// modu(x, 3) -> not canonicalized (not power of two)
// CHECK-NEXT: [[RES4:%.+]] = comb.modu bin %arg0, %c3_i8 : i8
%c3 = hw.constant 3 : i8
%3 = comb.modu bin %arg0, %c3 : i8

// Make sure canonicalization does not happen if there is no bin flag.
// CHECK-NEXT: [[RES5:%.+]] = comb.modu %arg0, %c2_i8 : i8
%4 = comb.modu %arg0, %c2 : i8
// CHECK: hw.output [[RES1]], [[RES2]], [[RES3]], [[RES4]], [[RES5]]
hw.output %0, %1, %2, %3, %4 : i8, i8, i8, i8, i8
}

// CHECK-LABEL: hw.module @orWithNegation
hw.module @orWithNegation(in %arg0 : i32, out o1: i32) {
// CHECK: [[ALLONES:%.*]] = hw.constant -1 : i32
Expand Down
10 changes: 10 additions & 0 deletions test/Dialect/Datapath/canonicalization.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,13 @@ hw.module @pos_partial_product_do_nothing(in %a : i4, in %b : i4, in %c : i4, ou
%2:4 = datapath.partial_product %0, %1 : (i4, i4) -> (i4, i4, i4, i4)
hw.output %2#0, %2#1, %2#2, %2#3 : i4, i4, i4, i4
}

// CHECK-LABEL: @dont_introduce_compressor
hw.module @dont_introduce_compressor(in %a : i4, in %b : i4, in %c: i4, out sum : i4) {
// CHECK-NOT: datapath.compress
// CHECK-NEXT: comb.add
// CHECK-NEXT: hw.output
%0:4 = datapath.partial_product %a, %b : (i4, i4) -> (i4, i4, i4, i4)
%1 = comb.add %a, %b, %c : i4
hw.output %1 : i4
}
Loading