Skip to content

Commit

Permalink
[FIRRTLFolds] Fix crashes in bundle/vector create folders. (#5048)
Browse files Browse the repository at this point in the history
* Bug that did drop_front and drop_begin, instead of just one.
* Don't crash if operand list is empty.

Also test behavior folding these to aggregateconstantop's.
  • Loading branch information
dtzSiFive committed Apr 14, 2023
1 parent 73bd384 commit 336b89a
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 20 deletions.
42 changes: 23 additions & 19 deletions lib/Dialect/FIRRTL/FIRRTLFolds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1997,32 +1997,36 @@ static Attribute collectFields(MLIRContext *context,
OpFoldResult BundleCreateOp::fold(FoldAdaptor adaptor) {
// bundle_create(%foo["a"], %foo["b"]) -> %foo when the type of %foo is
// bundle<a:..., b:...>.
if (SubfieldOp first = getOperand(0).getDefiningOp<SubfieldOp>())
if (first.getFieldIndex() == 0 && first.getInput().getType() == getType() &&
llvm::all_of(
llvm::drop_begin(llvm::enumerate(getOperands().drop_front())),
[&](auto elem) {
auto subindex = elem.value().template getDefiningOp<SubfieldOp>();
return subindex && subindex.getInput() == first.getInput() &&
subindex.getFieldIndex() == elem.index();
}))
return first.getInput();
if (getNumOperands() > 0)
if (SubfieldOp first = getOperand(0).getDefiningOp<SubfieldOp>())
if (first.getFieldIndex() == 0 &&
first.getInput().getType() == getType() &&
llvm::all_of(
llvm::drop_begin(llvm::enumerate(getOperands())), [&](auto elem) {
auto subindex =
elem.value().template getDefiningOp<SubfieldOp>();
return subindex && subindex.getInput() == first.getInput() &&
subindex.getFieldIndex() == elem.index();
}))
return first.getInput();

return collectFields(getContext(), adaptor.getOperands());
}

OpFoldResult VectorCreateOp::fold(FoldAdaptor adaptor) {
// vector_create(%foo[0], %foo[1]) -> %foo when the type of %foo is
// vector<..., 2>.
if (SubindexOp first = getOperand(0).getDefiningOp<SubindexOp>())
if (first.getIndex() == 0 && first.getInput().getType() == getType() &&
llvm::all_of(
llvm::drop_begin(llvm::enumerate(getOperands())), [&](auto elem) {
auto subindex = elem.value().template getDefiningOp<SubindexOp>();
return subindex && subindex.getInput() == first.getInput() &&
subindex.getIndex() == elem.index();
}))
return first.getInput();
if (getNumOperands() > 0)
if (SubindexOp first = getOperand(0).getDefiningOp<SubindexOp>())
if (first.getIndex() == 0 && first.getInput().getType() == getType() &&
llvm::all_of(
llvm::drop_begin(llvm::enumerate(getOperands())), [&](auto elem) {
auto subindex =
elem.value().template getDefiningOp<SubindexOp>();
return subindex && subindex.getInput() == first.getInput() &&
subindex.getIndex() == elem.index();
}))
return first.getInput();

return collectFields(getContext(), adaptor.getOperands());
}
Expand Down
52 changes: 51 additions & 1 deletion test/Dialect/FIRRTL/canonicalization.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2810,7 +2810,7 @@ firrtl.module @RemoveUnusedInvalid() {
}
// CHECK-NEXT: }

// CHECK-LABEL: firrtl.module @AggregateCreate
// CHECK-LABEL: firrtl.module @AggregateCreate(
firrtl.module @AggregateCreate(in %vector_in: !firrtl.vector<uint<1>, 2>,
in %bundle_in: !firrtl.bundle<a: uint<1>, b: uint<1>>,
out %vector_out: !firrtl.vector<uint<1>, 2>,
Expand All @@ -2828,6 +2828,56 @@ firrtl.module @AggregateCreate(in %vector_in: !firrtl.vector<uint<1>, 2>,
// CHECK-NEXT: firrtl.strictconnect %bundle_out, %bundle_in : !firrtl.bundle<a: uint<1>, b: uint<1>>
}

// CHECK-LABEL: firrtl.module @AggregateCreateSingle(
firrtl.module @AggregateCreateSingle(in %vector_in: !firrtl.vector<uint<1>, 1>,
in %bundle_in: !firrtl.bundle<a: uint<1>>,
out %vector_out: !firrtl.vector<uint<1>, 1>,
out %bundle_out: !firrtl.bundle<a: uint<1>>) {

%0 = firrtl.subindex %vector_in[0] : !firrtl.vector<uint<1>, 1>
%vector = firrtl.vectorcreate %0 : (!firrtl.uint<1>) -> !firrtl.vector<uint<1>, 1>
firrtl.strictconnect %vector_out, %vector : !firrtl.vector<uint<1>, 1>

%2 = firrtl.subfield %bundle_in["a"] : !firrtl.bundle<a: uint<1>>
%bundle = firrtl.bundlecreate %2 : (!firrtl.uint<1>) -> !firrtl.bundle<a: uint<1>>
firrtl.strictconnect %bundle_out, %bundle : !firrtl.bundle<a: uint<1>>
// CHECK-NEXT: firrtl.strictconnect %vector_out, %vector_in : !firrtl.vector<uint<1>, 1>
// CHECK-NEXT: firrtl.strictconnect %bundle_out, %bundle_in : !firrtl.bundle<a: uint<1>>
}

// CHECK-LABEL: firrtl.module @AggregateCreateEmpty(
firrtl.module @AggregateCreateEmpty(
out %vector_out: !firrtl.vector<uint<1>, 0>,
out %bundle_out: !firrtl.bundle<>) {

%vector = firrtl.vectorcreate : () -> !firrtl.vector<uint<1>, 0>
firrtl.strictconnect %vector_out, %vector : !firrtl.vector<uint<1>, 0>

%bundle = firrtl.bundlecreate : () -> !firrtl.bundle<>
firrtl.strictconnect %bundle_out, %bundle : !firrtl.bundle<>
// CHECK-DAG: %[[VEC:.+]] = firrtl.aggregateconstant [] : !firrtl.vector<uint<1>, 0>
// CHECK-DAG: %[[BUNDLE:.+]] = firrtl.aggregateconstant [] : !firrtl.bundle<>
// CHECK-DAG: firrtl.strictconnect %vector_out, %[[VEC]] : !firrtl.vector<uint<1>, 0>
// CHECK-DAG: firrtl.strictconnect %bundle_out, %[[BUNDLE]] : !firrtl.bundle<>
}

// CHECK-LABEL: firrtl.module @AggregateCreateConst(
firrtl.module @AggregateCreateConst(
out %vector_out: !firrtl.vector<uint<1>, 2>,
out %bundle_out: !firrtl.bundle<a: uint<1>, b: uint<1>>) {

%const = firrtl.constant 0 : !firrtl.uint<1>
%vector = firrtl.vectorcreate %const, %const : (!firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.vector<uint<1>, 2>
firrtl.strictconnect %vector_out, %vector : !firrtl.vector<uint<1>, 2>

%bundle = firrtl.bundlecreate %const, %const : (!firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.bundle<a: uint<1>, b: uint<1>>
firrtl.strictconnect %bundle_out, %bundle : !firrtl.bundle<a: uint<1>, b: uint<1>>
// CHECK-DAG: %[[VEC:.+]] = firrtl.aggregateconstant [0 : ui1, 0 : ui1] : !firrtl.vector<uint<1>, 2>
// CHECK-DAG: %[[BUNDLE:.+]] = firrtl.aggregateconstant [0 : ui1, 0 : ui1] : !firrtl.bundle<a: uint<1>, b: uint<1>>
// CHECK-DAG: firrtl.strictconnect %vector_out, %[[VEC]] : !firrtl.vector<uint<1>, 2>
// CHECK-DAG: firrtl.strictconnect %bundle_out, %[[BUNDLE]] : !firrtl.bundle<a: uint<1>, b: uint<1>>
}


// CHECK-LABEL: firrtl.module private @RWProbeUnused
firrtl.module private @RWProbeUnused(in %in: !firrtl.uint<4>, in %clk: !firrtl.clock, out %out: !firrtl.uint) {
Expand Down

0 comments on commit 336b89a

Please sign in to comment.