Skip to content

Commit

Permalink
Implement lowering of ?? and || via ternary operation
Browse files Browse the repository at this point in the history
Signed-off-by: Anton Korobeynikov <[email protected]>
  • Loading branch information
asl committed Jan 31, 2025
1 parent ceec9cf commit 1437d86
Show file tree
Hide file tree
Showing 5 changed files with 180 additions and 2 deletions.
53 changes: 51 additions & 2 deletions include/p4mlir/Dialect/P4HIR/P4HIR_Ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -381,8 +381,8 @@ def ScopeOp : P4HIR_Op<"scope", [
}

def YieldOp : P4HIR_Op<"yield", [ReturnLike, Terminator,
ParentOneOf<["ScopeOp",
// "IfOp", "TernaryOp"
ParentOneOf<["ScopeOp", "TernaryOp",
// "IfOp",
// "SwitchOp", "CaseOp",
// "ForInOp", "ForOp",
// "CallOp"
Expand Down Expand Up @@ -441,5 +441,54 @@ def YieldOp : P4HIR_Op<"yield", [ReturnLike, Terminator,
];
}

def TernaryOp : P4HIR_Op<"ternary",
[DeclareOpInterfaceMethods<RegionBranchOpInterface>,
RecursivelySpeculatable, AutomaticAllocationScope, NoRegionArguments]> {
let summary = "The `cond ? a : b` C/C++ ternary operation";
let description = [{
The `p4hir.ternary` operation represents operation that is absent in P4 language, but otherwise
is useful to represent varios language-level constructs. It is essentialyl a C/C++ ternary. First
argument is a `p4hir.bool` condition to evaluate, followed by two regions to execute (true or false).
This is different from `p4hir.if` since each region is one block sized and the `p4hir.yield` closing the
block scope should have one argument.

Example:

```mlir
// x = a && b;

%x = p4hir.ternary (%a, true_region {
...
p4hir.yield %b : !p4hir.bool
}, false_region {
...
p4hir.yield %a : !p4hir.bool
}) -> !p4hir.bool
```
}];
let arguments = (ins BooleanType:$cond);
let regions = (region AnyRegion:$trueRegion,
AnyRegion:$falseRegion);
let results = (outs Optional<AnyP4Type>:$result);

let skipDefaultBuilders = 1;
let builders = [
OpBuilder<(ins "mlir::Value":$cond,
"llvm::function_ref<void(mlir::OpBuilder &, mlir::Location)>":$trueBuilder,
"llvm::function_ref<void(mlir::OpBuilder &, mlir::Location)>":$falseBuilder)
>
];

// All constraints already verified elsewhere.
let hasVerifier = 0;

let assemblyFormat = [{
`(` $cond `,`
`true` $trueRegion `,`
`false` $falseRegion
`)` `:` functional-type(operands, results) attr-dict
}];
}


#endif // P4MLIR_DIALECT_P4HIR_P4HIR_OPS_TD
34 changes: 34 additions & 0 deletions lib/Dialect/P4HIR/P4HIR_Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,40 @@ static void printOmittedTerminatorRegion(mlir::OpAsmPrinter &printer, P4HIR::Sco
/*printBlockTerminators=*/!omitRegionTerm(scopeRegion));
}

//===----------------------------------------------------------------------===//
// TernaryOp
//===----------------------------------------------------------------------===//

void P4HIR::TernaryOp::getSuccessorRegions(mlir::RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
// The `true` and the `false` region branch back to the parent operation.
if (!point.isParent()) {
regions.push_back(RegionSuccessor(this->getODSResults(0)));
return;
}

// If the condition isn't constant, both regions may be executed.
regions.push_back(RegionSuccessor(&getTrueRegion()));
regions.push_back(RegionSuccessor(&getFalseRegion()));
}

void P4HIR::TernaryOp::build(OpBuilder &builder, OperationState &result, Value cond,
function_ref<void(OpBuilder &, Location)> trueBuilder,
function_ref<void(OpBuilder &, Location)> falseBuilder) {
result.addOperands(cond);
OpBuilder::InsertionGuard guard(builder);
Region *trueRegion = result.addRegion();
auto *block = builder.createBlock(trueRegion);
trueBuilder(builder, result.location);
Region *falseRegion = result.addRegion();
builder.createBlock(falseRegion);
falseBuilder(builder, result.location);

auto yield = dyn_cast<YieldOp>(block->getTerminator());
assert((yield && yield.getNumOperands() <= 1) && "expected zero or one result type");
if (yield.getNumOperands() == 1) result.addTypes(TypeRange{yield.getOperandTypes().front()});
}

void P4HIR::P4HIRDialect::initialize() {
registerTypes();
registerAttributes();
Expand Down
23 changes: 23 additions & 0 deletions test/Dialect/P4HIR/ternary.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// RUN: p4mlir-opt %s -o %t.mlir
// RUN: FileCheck --input-file=%t.mlir %s

module {
// No need to check stuff. If it parses, it's fine.
// CHECK: module
%0 = p4hir.const #p4hir.bool<false> : !p4hir.bool
%1 = p4hir.ternary(%0, true {
%29 = p4hir.const #p4hir.bool<true> : !p4hir.bool
p4hir.yield %29 : !p4hir.bool
}, false {
%29 = p4hir.const #p4hir.bool<false> : !p4hir.bool
p4hir.yield %29 : !p4hir.bool
}) : (!p4hir.bool) -> !p4hir.bool

%2 = p4hir.ternary(%1, true {
%29 = p4hir.const #p4hir.int<42> : !p4hir.int<32>
p4hir.yield %29 : !p4hir.int<32>
}, false {
%29 = p4hir.const #p4hir.int<100500> : !p4hir.int<32>
p4hir.yield %29 : !p4hir.int<32>
}) : (!p4hir.bool) -> !p4hir.int<32>
}
28 changes: 28 additions & 0 deletions test/Translate/Ops/logical.p4
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// RUN: p4mlir-translate --typeinference-only %s | FileCheck %s

action logical() {
// CHECK-LABEL: module
bool flag1; bool flag2; bool flag3;

// CHECK: %[[FLAG1:.*]] = p4hir.alloca !p4hir.bool ["flag1"] : !p4hir.ref<!p4hir.bool>
// CHECK: %[[FLAG2:.*]] = p4hir.alloca !p4hir.bool ["flag2"] : !p4hir.ref<!p4hir.bool>
// CHECK: %[[FLAG3:.*]] = p4hir.alloca !p4hir.bool ["flag3"] : !p4hir.ref<!p4hir.bool>

bool f1 = flag2 || flag3;
// CHECK: %[[FLAG2_VAL:.*]] = p4hir.load %[[FLAG2]] : !p4hir.ref<!p4hir.bool>, !p4hir.bool
// CHECK-NEXT: %[[F1_VAL:.*]] = p4hir.ternary(%[[FLAG2_VAL]], true {
// CHECK-NEXT: %[[TRUE_VAL:.*]] = p4hir.const #p4hir.bool<true> : !p4hir.bool
// CHECK-NEXT: p4hir.yield %[[TRUE_VAL]] : !p4hir.bool
// CHECK-NEXT: }, false {
// CHECK-NEXT: %[[FLAG3_VAL:.*]] = p4hir.load %[[FLAG3]] : !p4hir.ref<!p4hir.bool>, !p4hir.bool
// CHECK-NEXT: p4hir.yield %[[FLAG3_VAL]] : !p4hir.bool
// CHECK-NEXT }) : (!p4hir.bool) -> !p4hir.bool

bool f2 = flag2 && flag3;
bool f3 = flag2 && flag3 || flag3;
bool f7 = flag2 || flag3 || flag3;
bool f8 = flag2 || flag3 && flag3;
bool f5 = flag2 || flag3;
f5 = f1 && f5 || flag1;
bool f6 = flag1 || flag2;
}
44 changes: 44 additions & 0 deletions tools/p4mlir-translate/translate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,12 @@ class P4HIRConverter : public P4::Inspector, public P4::ResolutionContext {
mlir::TypedAttr resolveConstantExpr(const P4::IR::Expression *expr);
mlir::Value resolveReference(const P4::IR::Node *node);

mlir::Value getBoolConstant(mlir::Location loc, bool value) {
auto boolType = P4HIR::BoolType::get(context());
return builder.create<P4HIR::ConstOp>(loc, boolType,
P4HIR::BoolAttr::get(context(), boolType, value));
}

public:
P4HIRConverter(mlir::OpBuilder &builder, const P4::TypeMap *typeMap)
: builder(builder), typeMap(typeMap) {
Expand Down Expand Up @@ -332,6 +338,8 @@ class P4HIRConverter : public P4::Inspector, public P4::ResolutionContext {

bool preorder(const P4::IR::Declaration_Constant *decl) override;
bool preorder(const P4::IR::AssignmentStatement *assign) override;
bool preorder(const P4::IR::LOr *lor) override;
bool preorder(const P4::IR::LAnd *land) override;

mlir::Value emitUnOp(const P4::IR::Operation_Unary *unop, P4HIR::UnaryOpKind kind);
mlir::Value emitBinOp(const P4::IR::Operation_Binary *binop, P4HIR::BinOpKind kind);
Expand Down Expand Up @@ -578,6 +586,42 @@ bool P4HIRConverter::preorder(const P4::IR::AssignmentStatement *assign) {
return false;
}

bool P4HIRConverter::preorder(const P4::IR::LOr *lor) {
// Lower a || b into a ? true : b
visit(lor->left);

auto value = builder.create<P4HIR::TernaryOp>(
getLoc(builder, lor), getValue(lor->left),
[&](mlir::OpBuilder &b, mlir::Location loc) {
b.create<P4HIR::YieldOp>(getEndLoc(builder, lor->left), getBoolConstant(loc, true));
},
[&](mlir::OpBuilder &b, mlir::Location) {
visit(lor->right);
b.create<P4HIR::YieldOp>(getEndLoc(builder, lor->right), getValue(lor->right));
});

setValue(lor, value.getResult());
return false;
}

bool P4HIRConverter::preorder(const P4::IR::LAnd *land) {
// Lower a && b into a ? b : false
visit(land->left);

auto value = builder.create<P4HIR::TernaryOp>(
getLoc(builder, land), getValue(land->left),
[&](mlir::OpBuilder &b, mlir::Location) {
visit(land->right);
b.create<P4HIR::YieldOp>(getEndLoc(builder, land->right), getValue(land->right));
},
[&](mlir::OpBuilder &b, mlir::Location loc) {
b.create<P4HIR::YieldOp>(getEndLoc(builder, land->left), getBoolConstant(loc, false));
});

setValue(land, value.getResult());
return false;
}

} // namespace

namespace P4::P4MLIR {
Expand Down

0 comments on commit 1437d86

Please sign in to comment.