diff --git a/include/p4mlir/Dialect/P4HIR/P4HIR_Ops.td b/include/p4mlir/Dialect/P4HIR/P4HIR_Ops.td index 3ee75d6..e18f470 100644 --- a/include/p4mlir/Dialect/P4HIR/P4HIR_Ops.td +++ b/include/p4mlir/Dialect/P4HIR/P4HIR_Ops.td @@ -816,7 +816,9 @@ def CallOp : P4HIR_Op<"call", }]; } -def StructOp : P4HIR_Op<"struct", [Pure]> { +def StructOp : P4HIR_Op<"struct", + [Pure, + DeclareOpInterfaceMethods]> { let summary = "Create a struct from constituent parts."; // FIXME: Better constraint type let arguments = (ins Variadic:$input); @@ -825,10 +827,9 @@ def StructOp : P4HIR_Op<"struct", [Pure]> { let hasVerifier = 1; } -// Extract the value of a field of a structure. def StructExtractOp : P4HIR_Op<"struct_extract", [Pure, - DeclareOpInterfaceMethods + DeclareOpInterfaceMethods ]> { let summary = "Extract a named field from a struct."; let description = [{ @@ -868,12 +869,11 @@ def StructExtractOp : P4HIR_Op<"struct_extract", }]; } -// Extract the value of a field of a structure. def StructExtractRefOp : P4HIR_Op<"struct_extract_ref", [Pure, DeclareOpInterfaceMethods ]> { - let summary = "Create a reference to a struct field"; + let summary = "Project from a struct reference to a reference to a named struct field"; let description = [{ ``` %result = p4hir.struct_extract_ref %input["field"] : > diff --git a/lib/Dialect/P4HIR/P4HIR_Ops.cpp b/lib/Dialect/P4HIR/P4HIR_Ops.cpp index f21d108..2241762 100644 --- a/lib/Dialect/P4HIR/P4HIR_Ops.cpp +++ b/lib/Dialect/P4HIR/P4HIR_Ops.cpp @@ -1,5 +1,6 @@ #include "p4mlir/Dialect/P4HIR/P4HIR_Ops.h" +#include "llvm/ADT/SmallString.h" #include "llvm/Support/LogicalResult.h" #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" @@ -642,6 +643,12 @@ LogicalResult P4HIR::StructOp::verify() { return success(); } +void P4HIR::StructOp::getAsmResultNames(function_ref setNameFn) { + llvm::SmallString<32> name("struct_"); + name += getType().getName(); + setNameFn(getResult(), name); +} + //===----------------------------------------------------------------------===// // StructExtractOp //===----------------------------------------------------------------------===// @@ -673,15 +680,45 @@ template static ParseResult parseExtractOp(OpAsmParser &parser, OperationState &result) { OpAsmParser::UnresolvedOperand operand; StringAttr fieldName; - Type declType; + AggregateType declType; + + if (parser.parseOperand(operand) || parser.parseLSquare() || parser.parseAttribute(fieldName) || + parser.parseRSquare() || parser.parseOptionalAttrDict(result.attributes) || + parser.parseColon() || parser.parseCustomTypeWithFallback(declType)) + return failure(); + + auto fieldIndex = declType.getFieldIndex(fieldName); + if (!fieldIndex) { + parser.emitError(parser.getNameLoc(), + "field name '" + fieldName.getValue() + "' not found in aggregate type"); + return failure(); + } + + auto indexAttr = IntegerAttr::get(IntegerType::get(parser.getContext(), 32), *fieldIndex); + result.addAttribute("fieldIndex", indexAttr); + Type resultType = declType.getElements()[*fieldIndex].type; + result.addTypes(resultType); + + if (parser.resolveOperand(operand, declType, result.operands)) return failure(); + return success(); +} + +template +static ParseResult parseExtractRefOp(OpAsmParser &parser, OperationState &result) { + OpAsmParser::UnresolvedOperand operand; + StringAttr fieldName; + P4HIR::ReferenceType declType; if (parser.parseOperand(operand) || parser.parseLSquare() || parser.parseAttribute(fieldName) || parser.parseRSquare() || parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(declType)) + parser.parseColon() || parser.parseCustomTypeWithFallback(declType)) return failure(); - auto aggType = mlir::dyn_cast(declType); - if (!aggType) return parser.emitError(parser.getNameLoc(), "invalid kind of type specified"); + auto aggType = mlir::dyn_cast(declType.getObjectType()); + if (!aggType) { + parser.emitError(parser.getNameLoc(), "expected reference to aggregate type"); + return failure(); + } auto fieldIndex = aggType.getFieldIndex(fieldName); if (!fieldIndex) { parser.emitError(parser.getNameLoc(), @@ -691,7 +728,7 @@ static ParseResult parseExtractOp(OpAsmParser &parser, OperationState &result) { auto indexAttr = IntegerAttr::get(IntegerType::get(parser.getContext(), 32), *fieldIndex); result.addAttribute("fieldIndex", indexAttr); - Type resultType = aggType.getElements()[*fieldIndex].type; + Type resultType = P4HIR::ReferenceType::get(aggType.getElements()[*fieldIndex].type); result.addTypes(resultType); if (parser.resolveOperand(operand, declType, result.operands)) return failure(); @@ -745,7 +782,7 @@ void P4HIR::StructExtractRefOp::getAsmResultNames(function_ref(parser, result); + return parseExtractRefOp(parser, result); } void P4HIR::StructExtractRefOp::print(OpAsmPrinter &printer) { printExtractOp(printer, *this); } diff --git a/test/Dialect/P4HIR/struct.mlir b/test/Dialect/P4HIR/struct.mlir new file mode 100644 index 0000000..a209e60 --- /dev/null +++ b/test/Dialect/P4HIR/struct.mlir @@ -0,0 +1,37 @@ +// RUN: p4mlir-opt %s | FileCheck %s + +!i32i = !p4hir.int<32> +!T = !p4hir.struct<"T", t1: !i32i, t2: !i32i> +!S = !p4hir.struct<"S", s1: !T, s2: !T> +!Empty = !p4hir.struct<"Empty"> +!b9i = !p4hir.bit<9> +!PortId_t = !p4hir.struct<"PortId_t", _v: !b9i> + +#int10_i32i = #p4hir.int<10> : !i32i +#int20_i32i = #p4hir.int<20> : !i32i +#int1_b9i = #p4hir.int<1> : !b9i + +// CHECK: module +module { + %e = p4hir.const ["e"] #p4hir.aggregate<[]> : !Empty + %t = p4hir.const ["t"] #p4hir.aggregate<[#int10_i32i, #int20_i32i]> : !T + + p4hir.func action @test2(%arg0: !p4hir.ref {p4hir.dir = #p4hir}) { + %_v = p4hir.struct_extract_ref %arg0["_v"] : + %val = p4hir.read %arg0 : + %_v_0 = p4hir.struct_extract %val["_v"] : !PortId_t + %c1_b9i = p4hir.const #int1_b9i + %add = p4hir.binop(add, %_v_0, %c1_b9i) : !b9i + p4hir.assign %add, %_v : + p4hir.return + } + + p4hir.func action @test() { + %vv = p4hir.variable ["vv"] : + %val = p4hir.read %vv : + %0 = p4hir.struct (%val) : !PortId_t + %p1 = p4hir.variable ["p1", init] : + p4hir.assign %0, %p1 : + p4hir.return + } +} diff --git a/test/Translate/Ops/struct.p4 b/test/Translate/Ops/struct.p4 new file mode 100644 index 0000000..d7a9893 --- /dev/null +++ b/test/Translate/Ops/struct.p4 @@ -0,0 +1,111 @@ +// RUN: p4mlir-translate --typeinference-only %s | FileCheck %s + +struct P { + bit<32> f1; + bit<32> f2; +} + +struct T { + int<32> t1; + int<32> t2; +} + +struct S { + T s1; + T s2; +} + +struct Empty {}; + +// CHECK: !Empty = !p4hir.struct<"Empty"> +// CHECK: !PortId_t = !p4hir.struct<"PortId_t", _v: !b9i> +// CHECK: !T = !p4hir.struct<"T", t1: !i32i, t2: !i32i> +// CHECK: !S = !p4hir.struct<"S", s1: !T, s2: !T> +// CHECK: !metadata_t = !p4hir.struct<"metadata_t", foo: !PortId_t> + +// CHECK-LABEL: module + +const T t = { 32s10, 32s20 }; +const S s = { { 32s15, 32s25}, t }; + +// CHECK: %t = p4hir.const ["t"] #p4hir.aggregate<[#int10_i32i, #int20_i32i]> : !T +// CHECK: %s = p4hir.const ["s"] #p4hir.aggregate<[#p4hir.aggregate<[#int15_i32i, #int25_i32i]> : !T, #p4hir.aggregate<[#int10_i32i, #int20_i32i]> : !T]> : !S + +const int<32> x = t.t1; +const int<32> y = s.s1.t2; + +const int<32> w = .t.t1; + +// CHECK: %x = p4hir.const ["x"] #int10_i32i +// CHECK: %y = p4hir.const ["y"] #int25_i32i +// CHECK: %w = p4hir.const ["w"] #int10_i32i + +const T tt1 = s.s1; +const Empty e = {}; + +// CHECK: %tt1 = p4hir.const ["tt1"] #p4hir.aggregate<[#int15_i32i, #int25_i32i]> : !T +// CHECK: %e = p4hir.const ["e"] #p4hir.aggregate<[]> : !Empty + +const T t1 = { 10, 20 }; +const S s1 = { { 15, 25 }, t1 }; + +const int<32> x1 = t1.t1; +const int<32> y1 = s1.s1.t2; + +const int<32> w1 = .t1.t1; + +const T t2 = s1.s1; + +struct PortId_t { bit<9> _v; } + +const PortId_t PSA_CPU_PORT = { _v = 9w192 }; + +struct metadata_t { + PortId_t foo; +} + +action test2(inout PortId_t port) { + port._v = port._v + 1; +} + +// CHECK-LABEL: p4hir.func action @test2(%arg0: !p4hir.ref {p4hir.dir = #p4hir}) { +// CHECK: %[[_V_REF:.*]] = p4hir.struct_extract_ref %arg0["_v"] : +// CHECK: %[[VAL:.*]] = p4hir.read %arg0 : +// CHECK: %[[_V_VAL:.*]] = p4hir.struct_extract %[[VAL]]["_v"] : !PortId_t +// CHECK: p4hir.assign %{{.*}}, %[[_V_REF]] +// CHECK: p4hir.return + +// CHECK-LABEL: p4hir.func action @test(%arg0: !p4hir.ref {p4hir.dir = #p4hir}) { +// Just few important bits here +action test(inout metadata_t meta) { + bit<9> vv; + + PortId_t p1 = { _v = vv }; + + // CHECK: %[[VV_VAR:.*]] = p4hir.variable ["vv"] : + // CHECK: %[[VV_VAL:.*]] = p4hir.read %[[VV_VAR]] : + // CHECK: %[[STRUCT:.*]] = p4hir.struct (%[[VV_VAL]]) : !PortId_t + // CHECK: %[[P_VAR:.*]] = p4hir.variable ["p1", init] : + // CHECK: p4hir.assign %[[STRUCT]], %[[P_VAR]] : + + PortId_t p; + bit<9> v; + v = p._v; + + v = meta.foo._v; + + meta.foo._v = 1; + + // CHECK: p4hir.scope { + // CHECK: p4hir.call @test2 + test2(meta.foo); + // CHECK: } + + // CHECK: %[[METADATA_VAL:.*]] = p4hir.read %arg0 : + // CHECK: %[[FOO:.*]] = p4hir.struct_extract %[[METADATA_VAL]]["foo"] : !metadata_t + // CHECK: %[[PSA_CPU_PORT:.*]] = p4hir.const ["PSA_CPU_PORT"] #p4hir.aggregate<[#int192_b9i]> : !PortId_t + // CHECK: %eq = p4hir.cmp(eq, %[[FOO]], %[[PSA_CPU_PORT]]) : !PortId_t, !p4hir.bool + if (meta.foo == PSA_CPU_PORT) { + meta.foo._v = meta.foo._v + 1; + } +}