Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FIRRTL] Add MatchOp statement #5037

Merged
merged 1 commit into from
Apr 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions include/circt/Dialect/FIRRTL/FIRRTLStatements.td
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,50 @@ def WhenOp : FIRRTLOp<"when", [SingleBlock, NoTerminator, NoRegionArguments,
}];
}

def MatchOp : FIRRTLOp<"match", [SingleBlock, NoTerminator,
RecursiveMemoryEffects, RecursivelySpeculatable]> {
let summary = "Match Statement";
let description = [{
The "firrtl.match" operation represents a pattern matching statement on a
enumeration. This operation does not return a value and cannot be used as an
expression. Last connect semantics work similarly to a when statement.

Example:
```mlir
firrtl.match %in : !firrtl.enum<Some: uint<1>, None: uint<0>> {
case Some(%arg0) {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The ( and ) are not technically needed here, but I added them to match the FIRRTL syntax.

I might be able to get rid of the case keyword here but it made parsing easier.

!firrtl.strictconnect %w, %arg0 : !firrtl.uint<1>
}
case None(%arg0) {
!firrt.strictconnect %w, %c1 : !firrtl.uint<1>
}
}
```
}];
let arguments = (ins FEnumType:$input, I32ArrayAttr:$tags);
let results = (outs);
let regions = (region VariadicRegion<SizedRegion<1>>:$regions);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;

let skipDefaultBuilders = 1;
let builders = [
OpBuilder<(ins "::mlir::Value":$input,
"::mlir::ArrayAttr":$tags,
"::llvm::MutableArrayRef<std::unique_ptr<Region>>":$regions)>
];

let extraClassDeclaration = [{
IntegerAttr getFieldIndexAttr(size_t caseIndex) {
return getTags()[caseIndex].cast<IntegerAttr>();
}

uint32_t getFieldIndex(size_t caseIndex) {
return getFieldIndexAttr(caseIndex).getUInt();
}
}];
}

def ForceOp : FIRRTLOp<"force", [SameTypeOperands]> {
let summary = "Force procedural statement";
let description = "Maps to the corresponding `sv.force` operation.";
Expand Down
137 changes: 137 additions & 0 deletions lib/Dialect/FIRRTL/FIRRTLOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2424,6 +2424,143 @@ void WhenOp::build(OpBuilder &builder, OperationState &result, Value condition,
}
}

//===----------------------------------------------------------------------===//
// MatchOp
//===----------------------------------------------------------------------===//

LogicalResult MatchOp::verify() {
auto type = getInput().getType();

// Make sure that the number of tags matches the number of regions.
auto numCases = getTags().size();
auto numRegions = getNumRegions();
if (numRegions != numCases)
return emitOpError("expected ")
<< numRegions << " tags but got " << numCases;

auto numTags = type.getNumElements();

SmallDenseSet<int64_t> seen;
for (const auto &[tag, region] : llvm::zip(getTags(), getRegions())) {
auto tagIndex = size_t(cast<IntegerAttr>(tag).getInt());

// Ensure that the block has a single argument.
if (region.getNumArguments() != 1)
return emitOpError("region should have exactly one argument");

// Make sure that it is a valid tag.
if (tagIndex >= numTags)
return emitOpError("the tag index ")
<< tagIndex << " is out of the range of valid tags in " << type;

// Make sure we have not already matched this tag.
auto [it, inserted] = seen.insert(tagIndex);
if (!inserted)
return emitOpError("the tag ") << type.getElementNameAttr(tagIndex)
<< " is matched more than once";

// Check that the block argument type matches the tag's type.
auto expectedType = type.getElementType(tagIndex);
auto regionType = region.getArgument(0).getType();
if (regionType != expectedType)
return emitOpError("region type ")
<< regionType << " does not match the expected type "
<< expectedType;
}

// Check that the match statement is exhaustive.
for (size_t i = 0, e = type.getNumElements(); i < e; ++i)
if (!seen.contains(i))
return emitOpError("missing case for tag ") << type.getElementNameAttr(i);

return success();
}

void MatchOp::print(OpAsmPrinter &p) {
auto input = getInput();
auto type = input.getType();
auto regions = getRegions();
p << " " << input << " : " << type;
SmallVector<StringRef> elided = {"tags"};
p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), elided);
p << " {";
p.increaseIndent();
for (const auto &[tag, region] : llvm::zip(getTags(), regions)) {
p.printNewline();
p << "case ";
p.printKeywordOrString(
type.getElementName(tag.cast<IntegerAttr>().getInt()));
p << "(";
p.printRegionArgument(region.front().getArgument(0), /*attrs=*/{},
/*omitType=*/true);
p << ") ";
p.printRegion(region, /*printEntryBlockArgs=*/false);
}
p.decreaseIndent();
p.printNewline();
p << "}";
}

ParseResult MatchOp::parse(OpAsmParser &parser, OperationState &result) {
auto *context = parser.getContext();
OpAsmParser::UnresolvedOperand input;
if (parser.parseOperand(input) || parser.parseColon())
return failure();

auto loc = parser.getCurrentLocation();
Type type;
if (parser.parseType(type))
return failure();
auto enumType = type.dyn_cast<FEnumType>();
if (!enumType)
return parser.emitError(loc, "expected enumeration type but got") << type;

if (parser.resolveOperand(input, type, result.operands) ||
parser.parseOptionalAttrDictWithKeyword(result.attributes) ||
parser.parseLBrace())
return failure();

auto i32Type = IntegerType::get(context, 32);
SmallVector<Attribute> tags;
while (true) {
// Stop parsing when we don't find another "case" keyword.
if (failed(parser.parseOptionalKeyword("case")))
break;

// Parse the tag and region argument.
auto nameLoc = parser.getCurrentLocation();
std::string name;
OpAsmParser::Argument arg;
auto *region = result.addRegion();
if (parser.parseKeywordOrString(&name) || parser.parseLParen() ||
parser.parseArgument(arg) || parser.parseRParen())
return failure();

// Figure out the enum index of the tag.
auto index = enumType.getElementIndex(name);
if (!index)
return parser.emitError(nameLoc, "the tag \"")
<< name << "\" is not a member of the enumeration " << enumType;
tags.push_back(IntegerAttr::get(i32Type, *index));

// Parse the region.
arg.type = enumType.getElementType(*index);
if (parser.parseRegion(*region, arg))
return failure();
}
result.addAttribute("tags", ArrayAttr::get(context, tags));

return parser.parseRBrace();
}

void MatchOp::build(OpBuilder &builder, OperationState &result, Value input,
ArrayAttr tags,
MutableArrayRef<std::unique_ptr<Region>> regions) {
result.addOperands(input);
result.addAttribute("tags", tags);
result.addRegions(regions);
}

//===----------------------------------------------------------------------===//
// Expressions
//===----------------------------------------------------------------------===//
Expand Down
32 changes: 32 additions & 0 deletions test/Dialect/FIRRTL/errors.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -881,6 +881,38 @@ firrtl.circuit "MismatchedRegister" {

// -----

firrtl.circuit "EnumOutOfRange" {
firrtl.module @EnumSameCase(in %enum : !firrtl.enum<a : uint<8>>) {
// expected-error @+1 {{the tag index 1 is out of the range of valid tags in '!firrtl.enum<a: uint<8>>'}}
"firrtl.match"(%enum) ({
^bb0(%arg0: !firrtl.uint<8>):
}) {tags = [1 : i32]} : (!firrtl.enum<a: uint<8>>) -> ()
}
}
// -----

firrtl.circuit "EnumSameCase" {
firrtl.module @EnumSameCase(in %enum : !firrtl.enum<a : uint<8>>) {
// expected-error @+1 {{the tag "a" is matched more than once}}
"firrtl.match"(%enum) ({
^bb0(%arg0: !firrtl.uint<8>):
}, {
^bb0(%arg0: !firrtl.uint<8>):
}) {tags = [0 : i32, 0 : i32]} : (!firrtl.enum<a: uint<8>>) -> ()
}
}

// -----

firrtl.circuit "EnumNonExaustive" {
firrtl.module @EnumNonExaustive(in %enum : !firrtl.enum<a : uint<8>>) {
// expected-error @+1 {{missing case for tag "a"}}
"firrtl.match"(%enum) {tags = []} : (!firrtl.enum<a: uint<8>>) -> ()
}
}

// -----

// expected-error @+1 {{'firrtl.circuit' op main module 'private_main' must be public}}
firrtl.circuit "private_main" {
firrtl.module private @private_main() {}
Expand Down
10 changes: 9 additions & 1 deletion test/Dialect/FIRRTL/test.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,14 @@ firrtl.module @EnumTest(in %in : !firrtl.enum<a: uint<1>, b: uint<2>>,

%c1_u1 = firrtl.constant 0 : !firrtl.uint<8>
%some = firrtl.enumcreate Some(%c1_u1) : !firrtl.enum<None: uint<0>, Some: uint<8>>
}

firrtl.match %in : !firrtl.enum<a: uint<1>, b: uint<2>> {
case a(%arg0) {
%w = firrtl.wire : !firrtl.uint<1>
}
case b(%arg0) {
%x = firrtl.wire : !firrtl.uint<1>
}
}
}
}