diff --git a/include/p4mlir/Dialect/P4HIR/CMakeLists.txt b/include/p4mlir/Dialect/P4HIR/CMakeLists.txt index f76e432..d7ca937 100644 --- a/include/p4mlir/Dialect/P4HIR/CMakeLists.txt +++ b/include/p4mlir/Dialect/P4HIR/CMakeLists.txt @@ -14,3 +14,11 @@ mlir_tablegen(P4HIR_Attrs.cpp.inc -gen-attrdef-defs -attrdefs-dialect=p4hir) add_public_tablegen_target(P4MLIR_P4HIR_IncGen) add_dependencies(mlir-headers P4MLIR_P4HIR_IncGen) + +set(LLVM_TARGET_DEFINITIONS P4HIR_TypeInterfaces.td) +mlir_tablegen(P4HIR_TypeInterfaces.h.inc -gen-type-interface-decls) +mlir_tablegen(P4HIR_TypeInterfaces.cpp.inc -gen-type-interface-defs) + +add_public_tablegen_target(P4MLIR_P4HIR_TypeInterfacesIncGen) +add_dependencies(mlir-headers P4MLIR_P4HIR_TypeInterfacesIncGen) + diff --git a/include/p4mlir/Dialect/P4HIR/P4HIR.td b/include/p4mlir/Dialect/P4HIR/P4HIR.td index ea54a30..b4e45c1 100644 --- a/include/p4mlir/Dialect/P4HIR/P4HIR.td +++ b/include/p4mlir/Dialect/P4HIR/P4HIR.td @@ -4,5 +4,6 @@ include "p4mlir/Dialect/P4HIR/P4HIR_Dialect.td" include "p4mlir/Dialect/P4HIR/P4HIR_Ops.td" include "p4mlir/Dialect/P4HIR/P4HIR_Types.td" +include "p4mlir/Dialect/P4HIR/P4HIR_TypeInterfaces.td" #endif // P4MLIR_DIALECT_P4HIR_P4HIR_TD diff --git a/include/p4mlir/Dialect/P4HIR/P4HIR_Attrs.td b/include/p4mlir/Dialect/P4HIR/P4HIR_Attrs.td index a58eda7..b5c37f8 100644 --- a/include/p4mlir/Dialect/P4HIR/P4HIR_Attrs.td +++ b/include/p4mlir/Dialect/P4HIR/P4HIR_Attrs.td @@ -59,6 +59,35 @@ def P4HIR_IntAttr : P4HIR_Attr<"Int", "int", [TypedAttrInterface]> { let hasCustomAssemblyFormat = 1; } +//===----------------------------------------------------------------------===// +// AggAttr +//===----------------------------------------------------------------------===// + +def P4HIR_AggAttr : P4HIR_Attr<"Agg", "aggregate", [TypedAttrInterface]> { + let summary = "An Attribute containing an aggregate value"; + let description = [{ + An aggregate attribute is a literal attribute that represents an aggregate + value of the specified type. For nested aggregates, embedded arrays are + used. + }]; + let parameters = (ins AttributeSelfTypeParameter<"">:$type, + "mlir::ArrayAttr":$fields); + + let builders = [ + AttrBuilderWithInferredContext<(ins "mlir::Type":$type, + "mlir::ArrayAttr":$members), [{ + return $_get(type.getContext(), type, members); + }]> + ]; + // let genVerifyDecl = 1; + //let hasCustomAssemblyFormat = 1; + let assemblyFormat = [{ + `<` $fields `>` + }]; + +} + + //===----------------------------------------------------------------------===// // ParamDirAttr //===----------------------------------------------------------------------===// diff --git a/include/p4mlir/Dialect/P4HIR/P4HIR_Ops.td b/include/p4mlir/Dialect/P4HIR/P4HIR_Ops.td index 46f0200..e18f470 100644 --- a/include/p4mlir/Dialect/P4HIR/P4HIR_Ops.td +++ b/include/p4mlir/Dialect/P4HIR/P4HIR_Ops.td @@ -816,4 +816,98 @@ def CallOp : P4HIR_Op<"call", }]; } +def StructOp : P4HIR_Op<"struct", + [Pure, + DeclareOpInterfaceMethods]> { + let summary = "Create a struct from constituent parts."; + // FIXME: Better constraint type + let arguments = (ins Variadic:$input); + let results = (outs StructType:$result); + let hasCustomAssemblyFormat = 1; + let hasVerifier = 1; +} + +def StructExtractOp : P4HIR_Op<"struct_extract", + [Pure, + DeclareOpInterfaceMethods + ]> { + let summary = "Extract a named field from a struct."; + let description = [{ + ``` + %result = p4hir.struct_extract %input["field"] : !p4hir.struct + TODO: Support the nested extractions + %result2 = p4hir.struct_extract %input["field1", "field2"] : !p4hir.struct> + + ``` + }]; + + let arguments = (ins StructType:$input, I32Attr:$fieldIndex); + // FIXME: Better constraint type + let results = (outs AnyP4Type:$result); + let hasCustomAssemblyFormat = 1; + let hasVerifier = 1; + + let builders = [ + OpBuilder<(ins "mlir::Value":$input, "StructType::FieldInfo":$field)>, + OpBuilder<(ins "mlir::Value":$input, "mlir::StringAttr":$fieldName)>, + OpBuilder<(ins "mlir::Value":$input, "llvm::StringRef":$fieldName), [{ + build($_builder, $_state, input, $_builder.getStringAttr(fieldName)); + }]> + ]; + + let extraClassDeclaration = [{ + /// Return the name attribute of the accessed field. + mlir::StringAttr getFieldNameAttr() { + StructType type = getInput().getType(); + return type.getElements()[getFieldIndex()].name; + } + + /// Return the name of the accessed field. + llvm::StringRef getFieldName() { + return getFieldNameAttr().getValue(); + } + }]; +} + +def StructExtractRefOp : P4HIR_Op<"struct_extract_ref", + [Pure, + DeclareOpInterfaceMethods + ]> { + let summary = "Project from a struct reference to a reference to a named struct field"; + let description = [{ + ``` + %result = p4hir.struct_extract_ref %input["field"] : > + TODO: Support the nested extractions + %result2 = p4hir.struct_extract_ref %input["field1", "field2"] : >> + + ``` + }]; + + let arguments = (ins StructRefType:$input, I32Attr:$fieldIndex); + let results = (outs ReferenceType:$result); + let hasCustomAssemblyFormat = 1; + let hasVerifier = 1; + + let builders = [ + OpBuilder<(ins "mlir::Value":$input, "StructType::FieldInfo":$field)>, + OpBuilder<(ins "mlir::Value":$input, "mlir::StringAttr":$fieldName)>, + OpBuilder<(ins "mlir::Value":$input, "llvm::StringRef":$fieldName), [{ + build($_builder, $_state, input, $_builder.getStringAttr(fieldName)); + }]> + ]; + + let extraClassDeclaration = [{ + /// Return the name attribute of the accessed field. + mlir::StringAttr getFieldNameAttr() { + auto type = mlir::cast(mlir::cast(getInput().getType()).getObjectType()); + return type.getElements()[getFieldIndex()].name; + } + + /// Return the name of the accessed field. + llvm::StringRef getFieldName() { + return getFieldNameAttr().getValue(); + } + }]; +} + #endif // P4MLIR_DIALECT_P4HIR_P4HIR_OPS_TD diff --git a/include/p4mlir/Dialect/P4HIR/P4HIR_TypeInterfaces.h b/include/p4mlir/Dialect/P4HIR/P4HIR_TypeInterfaces.h new file mode 100644 index 0000000..459c6a4 --- /dev/null +++ b/include/p4mlir/Dialect/P4HIR/P4HIR_TypeInterfaces.h @@ -0,0 +1,27 @@ +#ifndef P4MLIR_DIALECT_P4HIR_P4HIR_TYPEINTERFACES_H +#define P4MLIR_DIALECT_P4HIR_P4HIR_TYPEINTERFACES_H + +#include "mlir/IR/Types.h" + +namespace P4::P4MLIR::P4HIR { +namespace FieldIdImpl { +unsigned getMaxFieldID(::mlir::Type); + +std::pair<::mlir::Type, unsigned> getSubTypeByFieldID(::mlir::Type, unsigned fieldID); + +::mlir::Type getFinalTypeByFieldID(::mlir::Type type, unsigned fieldID); + +std::pair projectToChildFieldID(::mlir::Type, unsigned fieldID, unsigned index); + +std::pair getIndexAndSubfieldID(::mlir::Type type, unsigned fieldID); + +unsigned getFieldID(::mlir::Type type, unsigned index); + +unsigned getIndexForFieldID(::mlir::Type type, unsigned fieldID); + +} // namespace FieldIdImpl +} // namespace P4::P4MLIR::P4HIR + +#include "p4mlir/Dialect/P4HIR/P4HIR_TypeInterfaces.h.inc" + +#endif // P4MLIR_DIALECT_P4HIR_P4HIR_TYPEINTERFACES_H diff --git a/include/p4mlir/Dialect/P4HIR/P4HIR_TypeInterfaces.td b/include/p4mlir/Dialect/P4HIR/P4HIR_TypeInterfaces.td new file mode 100644 index 0000000..d09e5c4 --- /dev/null +++ b/include/p4mlir/Dialect/P4HIR/P4HIR_TypeInterfaces.td @@ -0,0 +1,72 @@ +#ifndef P4MLIR_DIALECT_P4HIR_P4HIR_TYPEINTERFACES_TD +#define P4MLIR_DIALECT_P4HIR_P4HIR_TYPEINTERFACES_TD + +include "mlir/IR/OpBase.td" + +def FieldIDTypeInterface : TypeInterface<"FieldIDTypeInterface"> { + let description = [{ + Common methods for types which can be indexed by a FieldID. + FieldID is a depth-first numbering of the elements of a type. For example: + ``` + struct a /* 0 */ { + int b; /* 1 */ + struct c /* 2 */ { + int d; /* 3 */ + } + } + + int e; /* 0 */ + ``` + }]; + + let methods = [ + InterfaceMethod<"Get the maximum field ID for this type", + "unsigned", "getMaxFieldID">, + + InterfaceMethod<[{ + Get the sub-type of a type for a field ID, and the subfield's ID. Strip + off a single layer of this type and return the sub-type and a field ID + targeting the same field, but rebased on the sub-type. + + The resultant type *may* not be a FieldIDTypeInterface if the resulting + fieldID is zero. This means that leaf types may be ground without + implementing an interface. An empty aggregate will also appear as a + zero. + }], + "std::pair<::mlir::Type, unsigned>", "getSubTypeByFieldID", (ins "unsigned":$fieldID)>, + + InterfaceMethod<[{ + Returns the effective field id when treating the index field as the + root of the type. Essentially maps a fieldID to a fieldID after a + subfield op. Returns the new id and whether the id is in the given + child. + }], + "std::pair", "projectToChildFieldID", (ins "unsigned":$fieldID, "unsigned":$index)>, + + InterfaceMethod<[{ + Returns the index (e.g. struct or vector element) for a given FieldID. + This returns the containing index in the case that the fieldID points to a + child field of a field. + }], + "unsigned", "getIndexForFieldID", (ins "unsigned":$fieldID)>, + + InterfaceMethod<[{ + Return the fieldID of a given index (e.g. struct or vector element). + Field IDs start at 1, and are assigned + to each field in a recursive depth-first walk of all + elements. A field ID of 0 is used to reference the type itself. + }], + "unsigned", "getFieldID", (ins "unsigned":$index)>, + + InterfaceMethod<[{ + Find the index of the element that contains the given fieldID. + As well, rebase the fieldID to the element. + }], + "std::pair", "getIndexAndSubfieldID", (ins "unsigned":$fieldID)>, + + ]; + + let cppNamespace = "::P4::P4MLIR::P4HIR"; +} + +#endif // P4MLIR_DIALECT_P4HIR_P4HIR_TYPEINTERFACES_TD diff --git a/include/p4mlir/Dialect/P4HIR/P4HIR_Types.h b/include/p4mlir/Dialect/P4HIR/P4HIR_Types.h index 58928c3..d296033 100644 --- a/include/p4mlir/Dialect/P4HIR/P4HIR_Types.h +++ b/include/p4mlir/Dialect/P4HIR/P4HIR_Types.h @@ -6,9 +6,22 @@ #pragma GCC diagnostic ignored "-Wunused-parameter" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/Interfaces/MemorySlotInterfaces.h" #include "p4mlir/Dialect/P4HIR/P4HIR_OpsEnums.h" +#include "p4mlir/Dialect/P4HIR/P4HIR_TypeInterfaces.h" + +namespace P4::P4MLIR::P4HIR { + +namespace detail { +/// Struct defining a field. Used in structs. +struct FieldInfo { + mlir::StringAttr name; + mlir::Type type; +}; +} // namespace detail +} // namespace P4::P4MLIR::P4HIR #define GET_TYPEDEF_CLASSES #include "p4mlir/Dialect/P4HIR/P4HIR_Types.h.inc" -#endif // P4MLIR_DIALECT_P4HIR_P4HIR_TYPES_H +#endif // P4MLIR_DIALECT_P4HIR_P4HIR_TYPES_H diff --git a/include/p4mlir/Dialect/P4HIR/P4HIR_Types.td b/include/p4mlir/Dialect/P4HIR/P4HIR_Types.td index eebe5e1..281a322 100644 --- a/include/p4mlir/Dialect/P4HIR/P4HIR_Types.td +++ b/include/p4mlir/Dialect/P4HIR/P4HIR_Types.td @@ -3,8 +3,10 @@ include "mlir/IR/AttrTypeBase.td" include "mlir/IR/EnumAttr.td" +include "mlir/Interfaces/MemorySlotInterfaces.td" include "p4mlir/Dialect/P4HIR/P4HIR_Dialect.td" +include "p4mlir/Dialect/P4HIR/P4HIR_TypeInterfaces.td" //===----------------------------------------------------------------------===// // P4HIR type definitions. @@ -122,6 +124,7 @@ def VoidType : P4HIR_Type<"Void", "void"> { llvm::StringRef getAlias() const { return "void"; }; }]; } + //===----------------------------------------------------------------------===// // ReferenceType //===----------------------------------------------------------------------===// @@ -152,6 +155,10 @@ def ReferenceType : P4HIR_Type<"Reference", "ref"> { let skipDefaultBuilders = 1; } +//===----------------------------------------------------------------------===// +// FuncType +//===----------------------------------------------------------------------===// + def FuncType : P4HIR_Type<"Func", "func"> { let summary = "P4 function-like type (actions, methods, functions)"; let description = [{ @@ -212,14 +219,56 @@ def FuncType : P4HIR_Type<"Func", "func"> { }]; } +//===----------------------------------------------------------------------===// +// StructType +//===----------------------------------------------------------------------===// + +// A packed struct. Declares the P4HIR::StructType in C++. +def StructType : P4HIR_Type<"Struct", "struct", [ + DeclareTypeInterfaceMethods, + DeclareTypeInterfaceMethods +]> { + let summary = "struct type"; + let description = [{ + Represents a structure of name, value pairs. + !p4hir.struct<"name", fieldName1: Type1, fieldName2: Type2> + }]; + + let hasCustomAssemblyFormat = 1; + let genVerifyDecl = 1; + + let parameters = ( + ins StringRefParameter<"struct name">:$name, ArrayRefParameter< + "P4HIR::StructType::FieldInfo", "struct fields">:$elements + ); + + let extraClassDeclaration = [{ + using FieldInfo = P4HIR::detail::FieldInfo; + mlir::Type getFieldType(mlir::StringRef fieldName); + void getInnerTypes(mlir::SmallVectorImpl&); + std::optional getFieldIndex(mlir::StringRef fieldName); + std::optional getFieldIndex(mlir::StringAttr fieldName); + }]; +} + //===----------------------------------------------------------------------===// // P4HIR type constraints. //===----------------------------------------------------------------------===// -def AnyP4Type : AnyTypeOf<[BitsType, BooleanType, InfIntType, +def AnyP4Type : AnyTypeOf<[BitsType, BooleanType, InfIntType, StructType, DontcareType, ErrorType, UnknownType]> {} def CallResultP4Type : AnyTypeOf<[BitsType, BooleanType, InfIntType, VoidType]> {} -def LoadableP4Type : AnyTypeOf<[BitsType, BooleanType, InfIntType]> {} +def LoadableP4Type : AnyTypeOf<[BitsType, BooleanType, InfIntType, StructType]> {} + +/// A ref type with the specified constraints on the nested type. +class SpecificRefType : ConfinedType($_self).getObjectType()", + type.predicate>], + "ref of " # type.summary, "P4HIR::ReferenceType" +> { + Type objectType = type; +} +def StructRefType : SpecificRefType; #endif // P4MLIR_DIALECT_P4HIR_P4HIR_TYPES_TD diff --git a/lib/Dialect/P4HIR/CMakeLists.txt b/lib/Dialect/P4HIR/CMakeLists.txt index cb38ee5..a490454 100644 --- a/lib/Dialect/P4HIR/CMakeLists.txt +++ b/lib/Dialect/P4HIR/CMakeLists.txt @@ -2,6 +2,7 @@ add_mlir_dialect_library(P4MLIR_P4HIR P4HIR_Ops.cpp P4HIR_Types.cpp P4HIR_Attrs.cpp + P4HIR_TypeInterfaces.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/p4mlir/Dialect/P4HIR diff --git a/lib/Dialect/P4HIR/P4HIR_Ops.cpp b/lib/Dialect/P4HIR/P4HIR_Ops.cpp index 29a51f7..2241762 100644 --- a/lib/Dialect/P4HIR/P4HIR_Ops.cpp +++ b/lib/Dialect/P4HIR/P4HIR_Ops.cpp @@ -1,5 +1,6 @@ #include "p4mlir/Dialect/P4HIR/P4HIR_Ops.h" +#include "llvm/ADT/SmallString.h" #include "llvm/Support/LogicalResult.h" #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" @@ -35,6 +36,13 @@ static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType, if (mlir::isa(attrType)) return success(); + if (mlir::isa(attrType)) { + if (!mlir::isa(opType)) + return op->emitOpError("result type (") << opType << ") is not an aggregate type"; + + return success(); + } + assert(isa(attrType) && "expected typed attribute"); return op->emitOpError("constant with type ") << cast(attrType).getType() << " not supported"; @@ -590,6 +598,221 @@ LogicalResult P4HIR::CallOp::verifySymbolUses(SymbolTableCollection &symbolTable return success(); } +//===----------------------------------------------------------------------===// +// StructOp +//===----------------------------------------------------------------------===// + +ParseResult P4HIR::StructOp::parse(OpAsmParser &parser, OperationState &result) { + llvm::SMLoc inputOperandsLoc = parser.getCurrentLocation(); + llvm::SmallVector operands; + Type declType; + + if (parser.parseLParen() || parser.parseOperandList(operands) || parser.parseRParen() || + parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(declType)) + return failure(); + + auto structType = mlir::dyn_cast(declType); + if (!structType) return parser.emitError(parser.getNameLoc(), "expected !p4hir.struct type"); + + llvm::SmallVector structInnerTypes; + structType.getInnerTypes(structInnerTypes); + result.addTypes(structType); + + if (parser.resolveOperands(operands, structInnerTypes, inputOperandsLoc, result.operands)) + return failure(); + return success(); +} + +void P4HIR::StructOp::print(OpAsmPrinter &printer) { + printer << " ("; + printer.printOperands(getInput()); + printer << ")"; + printer.printOptionalAttrDict((*this)->getAttrs()); + printer << " : " << getType(); +} + +LogicalResult P4HIR::StructOp::verify() { + auto elements = mlir::cast(getType()).getElements(); + + if (elements.size() != getInput().size()) return emitOpError("struct field count mismatch"); + + for (const auto &[field, value] : llvm::zip(elements, getInput())) + if (field.type != value.getType()) + return emitOpError("struct field `") << field.name << "` type does not match"; + + return success(); +} + +void P4HIR::StructOp::getAsmResultNames(function_ref setNameFn) { + llvm::SmallString<32> name("struct_"); + name += getType().getName(); + setNameFn(getResult(), name); +} + +//===----------------------------------------------------------------------===// +// StructExtractOp +//===----------------------------------------------------------------------===// + +/// Ensure an aggregate op's field index is within the bounds of +/// the aggregate type and the accessed field is of 'elementType'. +template +static LogicalResult verifyAggregateFieldIndexAndType(AggregateOp &op, AggregateType aggType, + Type elementType) { + auto index = op.getFieldIndex(); + if (index >= aggType.getElements().size()) + return op.emitOpError() << "field index " << index + << " exceeds element count of aggregate type"; + + if (elementType != aggType.getElements()[index].type) + return op.emitOpError() << "type " << aggType.getElements()[index].type + << " of accessed field in aggregate at index " << index + << " does not match expected type " << elementType; + + return success(); +} + +LogicalResult P4HIR::StructExtractOp::verify() { + return verifyAggregateFieldIndexAndType( + *this, getInput().getType(), getType()); +} + +template +static ParseResult parseExtractOp(OpAsmParser &parser, OperationState &result) { + OpAsmParser::UnresolvedOperand operand; + StringAttr fieldName; + AggregateType declType; + + if (parser.parseOperand(operand) || parser.parseLSquare() || parser.parseAttribute(fieldName) || + parser.parseRSquare() || parser.parseOptionalAttrDict(result.attributes) || + parser.parseColon() || parser.parseCustomTypeWithFallback(declType)) + return failure(); + + auto fieldIndex = declType.getFieldIndex(fieldName); + if (!fieldIndex) { + parser.emitError(parser.getNameLoc(), + "field name '" + fieldName.getValue() + "' not found in aggregate type"); + return failure(); + } + + auto indexAttr = IntegerAttr::get(IntegerType::get(parser.getContext(), 32), *fieldIndex); + result.addAttribute("fieldIndex", indexAttr); + Type resultType = declType.getElements()[*fieldIndex].type; + result.addTypes(resultType); + + if (parser.resolveOperand(operand, declType, result.operands)) return failure(); + return success(); +} + +template +static ParseResult parseExtractRefOp(OpAsmParser &parser, OperationState &result) { + OpAsmParser::UnresolvedOperand operand; + StringAttr fieldName; + P4HIR::ReferenceType declType; + + if (parser.parseOperand(operand) || parser.parseLSquare() || parser.parseAttribute(fieldName) || + parser.parseRSquare() || parser.parseOptionalAttrDict(result.attributes) || + parser.parseColon() || parser.parseCustomTypeWithFallback(declType)) + return failure(); + + auto aggType = mlir::dyn_cast(declType.getObjectType()); + if (!aggType) { + parser.emitError(parser.getNameLoc(), "expected reference to aggregate type"); + return failure(); + } + auto fieldIndex = aggType.getFieldIndex(fieldName); + if (!fieldIndex) { + parser.emitError(parser.getNameLoc(), + "field name '" + fieldName.getValue() + "' not found in aggregate type"); + return failure(); + } + + auto indexAttr = IntegerAttr::get(IntegerType::get(parser.getContext(), 32), *fieldIndex); + result.addAttribute("fieldIndex", indexAttr); + Type resultType = P4HIR::ReferenceType::get(aggType.getElements()[*fieldIndex].type); + result.addTypes(resultType); + + if (parser.resolveOperand(operand, declType, result.operands)) return failure(); + return success(); +} + +/// Use the same printer for both struct_extract and struct_extract_ref since the +/// syntax is identical. +template +static void printExtractOp(OpAsmPrinter &printer, AggType op) { + printer << " "; + printer.printOperand(op.getInput()); + printer << "[\"" << op.getFieldName() << "\"]"; + printer.printOptionalAttrDict(op->getAttrs(), {"fieldIndex"}); + printer << " : "; + auto type = op.getInput().getType(); + if (auto validType = mlir::dyn_cast(type)) + printer.printStrippedAttrOrType(validType); + else + printer << type; +} + +ParseResult P4HIR::StructExtractOp::parse(OpAsmParser &parser, OperationState &result) { + return parseExtractOp(parser, result); +} + +void P4HIR::StructExtractOp::print(OpAsmPrinter &printer) { printExtractOp(printer, *this); } + +void P4HIR::StructExtractOp::build(OpBuilder &builder, OperationState &odsState, Value input, + StructType::FieldInfo field) { + auto fieldIndex = mlir::cast(input.getType()).getFieldIndex(field.name); + assert(fieldIndex.has_value() && "field name not found in aggregate type"); + build(builder, odsState, field.type, input, *fieldIndex); +} + +void P4HIR::StructExtractOp::build(OpBuilder &builder, OperationState &odsState, Value input, + StringAttr fieldName) { + auto structType = mlir::cast(input.getType()); + auto fieldIndex = structType.getFieldIndex(fieldName); + assert(fieldIndex.has_value() && "field name not found in aggregate type"); + auto resultType = structType.getElements()[*fieldIndex].type; + build(builder, odsState, resultType, input, *fieldIndex); +} + +void P4HIR::StructExtractOp::getAsmResultNames(function_ref setNameFn) { + setNameFn(getResult(), getFieldName()); +} + +void P4HIR::StructExtractRefOp::getAsmResultNames(function_ref setNameFn) { + setNameFn(getResult(), getFieldName()); +} + +ParseResult P4HIR::StructExtractRefOp::parse(OpAsmParser &parser, OperationState &result) { + return parseExtractRefOp(parser, result); +} + +void P4HIR::StructExtractRefOp::print(OpAsmPrinter &printer) { printExtractOp(printer, *this); } + +LogicalResult P4HIR::StructExtractRefOp::verify() { + auto type = + mlir::cast(mlir::cast(getInput().getType()).getObjectType()); + return verifyAggregateFieldIndexAndType( + *this, type, getType().getObjectType()); +} + +void P4HIR::StructExtractRefOp::build(OpBuilder &builder, OperationState &odsState, Value input, + StructType::FieldInfo field) { + auto structType = + mlir::cast(mlir::cast(input.getType()).getObjectType()); + auto fieldIndex = structType.getFieldIndex(field.name); + assert(fieldIndex.has_value() && "field name not found in aggregate type"); + build(builder, odsState, ReferenceType::get(field.type), input, *fieldIndex); +} + +void P4HIR::StructExtractRefOp::build(OpBuilder &builder, OperationState &odsState, Value input, + StringAttr fieldName) { + auto structType = + mlir::cast(mlir::cast(input.getType()).getObjectType()); + auto fieldIndex = structType.getFieldIndex(fieldName); + assert(fieldIndex.has_value() && "field name not found in aggregate type"); + auto resultType = ReferenceType::get(structType.getElements()[*fieldIndex].type); + build(builder, odsState, resultType, input, *fieldIndex); +} + namespace { struct P4HIROpAsmDialectInterface : public OpAsmDialectInterface { using OpAsmDialectInterface::OpAsmDialectInterface; @@ -610,6 +833,11 @@ struct P4HIROpAsmDialectInterface : public OpAsmDialectInterface { return AliasResult::OverridableAlias; } + if (auto structType = mlir::dyn_cast(type)) { + os << structType.getName(); + return AliasResult::OverridableAlias; + } + return AliasResult::NoAlias; } diff --git a/lib/Dialect/P4HIR/P4HIR_TypeInterfaces.cpp b/lib/Dialect/P4HIR/P4HIR_TypeInterfaces.cpp new file mode 100644 index 0000000..914a174 --- /dev/null +++ b/lib/Dialect/P4HIR/P4HIR_TypeInterfaces.cpp @@ -0,0 +1,55 @@ +#include "p4mlir/Dialect/P4HIR/P4HIR_TypeInterfaces.h" + +using namespace mlir; +using namespace P4::P4MLIR::P4HIR; +using namespace FieldIdImpl; + +mlir::Type FieldIdImpl::getFinalTypeByFieldID(mlir::Type type, unsigned fieldID) { + std::pair pair(type, fieldID); + while (pair.second) { + if (auto ftype = dyn_cast(pair.first)) + pair = ftype.getSubTypeByFieldID(pair.second); + else + llvm::report_fatal_error("fieldID indexing into a non-aggregate type"); + } + return pair.first; +} + +std::pair FieldIdImpl::getSubTypeByFieldID(mlir::Type type, unsigned fieldID) { + if (!fieldID) return {type, 0}; + if (auto ftype = dyn_cast(type)) + return ftype.getSubTypeByFieldID(fieldID); + + llvm::report_fatal_error("fieldID indexing into a non-aggregate type"); +} + +unsigned FieldIdImpl::getMaxFieldID(mlir::Type type) { + if (auto ftype = dyn_cast(type)) return ftype.getMaxFieldID(); + return 0; +} + +std::pair FieldIdImpl::projectToChildFieldID(mlir::Type type, unsigned fieldID, + unsigned index) { + if (auto ftype = dyn_cast(type)) + return ftype.projectToChildFieldID(fieldID, index); + return {0, fieldID == 0}; +} + +unsigned FieldIdImpl::getIndexForFieldID(mlir::Type type, unsigned fieldID) { + if (auto ftype = dyn_cast(type)) return ftype.getIndexForFieldID(fieldID); + return 0; +} + +unsigned FieldIdImpl::getFieldID(mlir::Type type, unsigned fieldID) { + if (auto ftype = dyn_cast(type)) return ftype.getFieldID(fieldID); + return 0; +} + +std::pair FieldIdImpl::getIndexAndSubfieldID(mlir::Type type, + unsigned fieldID) { + if (auto ftype = dyn_cast(type)) + return ftype.getIndexAndSubfieldID(fieldID); + return {0, fieldID == 0}; +} + +#include "p4mlir/Dialect/P4HIR/P4HIR_TypeInterfaces.cpp.inc" diff --git a/lib/Dialect/P4HIR/P4HIR_Types.cpp b/lib/Dialect/P4HIR/P4HIR_Types.cpp index 85b5b0e..c79a0b0 100644 --- a/lib/Dialect/P4HIR/P4HIR_Types.cpp +++ b/lib/Dialect/P4HIR/P4HIR_Types.cpp @@ -1,11 +1,16 @@ #include "p4mlir/Dialect/P4HIR/P4HIR_Types.h" +#include "llvm/ADT/StringSet.h" #include "llvm/ADT/TypeSwitch.h" #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" #include "p4mlir/Dialect/P4HIR/P4HIR_Dialect.h" #include "p4mlir/Dialect/P4HIR/P4HIR_OpsEnums.h" +using namespace mlir; +using namespace P4::P4MLIR::P4HIR; +using namespace P4::P4MLIR::P4HIR::detail; + static mlir::ParseResult parseFuncType(mlir::AsmParser &p, mlir::Type &optionalResultType, llvm::SmallVector ¶ms); @@ -15,9 +20,6 @@ static void printFuncType(mlir::AsmPrinter &p, mlir::Type optionalResultType, #define GET_TYPEDEF_CLASSES #include "p4mlir/Dialect/P4HIR/P4HIR_Types.cpp.inc" -using namespace mlir; -using namespace P4::P4MLIR::P4HIR; - void BitsType::print(mlir::AsmPrinter &printer) const { printer << (isSigned() ? "int" : "bit") << '<' << getWidth() << '>'; } @@ -137,6 +139,182 @@ bool FuncType::isVoid() const { return !rt; } +namespace P4::P4MLIR::P4HIR::detail { +bool operator==(const FieldInfo &a, const FieldInfo &b) { + return a.name == b.name && a.type == b.type; +} +llvm::hash_code hash_value(const FieldInfo &fi) { return llvm::hash_combine(fi.name, fi.type); } +} // namespace P4::P4MLIR::P4HIR::detail + +/// Parse a list of unique field names and types within <> plus name. E.g.: +/// +static ParseResult parseFields(AsmParser &p, std::string &name, + SmallVectorImpl ¶meters) { + llvm::StringSet<> nameSet; + bool hasDuplicateName = false; + bool parsedName = false; + auto parseResult = + p.parseCommaSeparatedList(mlir::AsmParser::Delimiter::LessGreater, [&]() -> ParseResult { + // First, try to parse name + if (!parsedName) { + if (p.parseKeywordOrString(&name)) return failure(); + parsedName = true; + return success(); + } + + // Parse fields + std::string fieldName; + Type fieldType; + + auto fieldLoc = p.getCurrentLocation(); + if (p.parseKeywordOrString(&fieldName) || p.parseColon() || p.parseType(fieldType)) + return failure(); + + if (!nameSet.insert(fieldName).second) { + p.emitError(fieldLoc, "duplicate field name \'" + name + "\'"); + // Continue parsing to print all duplicates, but make sure to error + // eventually + hasDuplicateName = true; + } + + parameters.push_back(FieldInfo{StringAttr::get(p.getContext(), fieldName), fieldType}); + return success(); + }); + + if (hasDuplicateName) return failure(); + return parseResult; +} + +/// Print out a list of named fields surrounded by <>. +static void printFields(AsmPrinter &p, StringRef name, ArrayRef fields) { + p << '<'; + p.printString(name); + if (!fields.empty()) p << ", "; + llvm::interleaveComma(fields, p, [&](const FieldInfo &field) { + p.printKeywordOrString(field.name.getValue()); + p << ": " << field.type; + }); + p << ">"; +} + +Type StructType::parse(AsmParser &p) { + llvm::SmallVector parameters; + std::string name; + if (parseFields(p, name, parameters)) return {}; + return get(p.getContext(), name, parameters); +} + +LogicalResult StructType::verify(function_ref emitError, StringRef, + ArrayRef elements) { + llvm::SmallDenseSet fieldNameSet; + LogicalResult result = success(); + fieldNameSet.reserve(elements.size()); + for (const auto &elt : elements) + if (!fieldNameSet.insert(elt.name).second) { + result = failure(); + emitError() << "duplicate field name '" << elt.name.getValue() + << "' in p4hir.struct type"; + } + return result; +} + +void StructType::print(AsmPrinter &p) const { printFields(p, getName(), getElements()); } + +Type StructType::getFieldType(mlir::StringRef fieldName) { + for (const auto &field : getElements()) + if (field.name == fieldName) return field.type; + return Type(); +} + +std::optional StructType::getFieldIndex(mlir::StringRef fieldName) { + ArrayRef elems = getElements(); + for (size_t idx = 0, numElems = elems.size(); idx < numElems; ++idx) + if (elems[idx].name == fieldName) return idx; + return {}; +} + +std::optional StructType::getFieldIndex(mlir::StringAttr fieldName) { + ArrayRef elems = getElements(); + for (size_t idx = 0, numElems = elems.size(); idx < numElems; ++idx) + if (elems[idx].name == fieldName) return idx; + return {}; +} + +static std::pair> getFieldIDsStruct(const StructType &st) { + unsigned fieldID = 0; + auto elements = st.getElements(); + SmallVector fieldIDs; + fieldIDs.reserve(elements.size()); + for (auto &element : elements) { + auto type = element.type; + fieldID += 1; + fieldIDs.push_back(fieldID); + // Increment the field ID for the next field by the number of subfields. + fieldID += FieldIdImpl::getMaxFieldID(type); + } + return {fieldID, fieldIDs}; +} + +std::pair StructType::getSubTypeByFieldID(unsigned fieldID) const { + if (fieldID == 0) return {*this, 0}; + auto [maxId, fieldIDs] = getFieldIDsStruct(*this); + auto *it = std::prev(llvm::upper_bound(fieldIDs, fieldID)); + auto subfieldIndex = std::distance(fieldIDs.begin(), it); + auto subfieldType = getElements()[subfieldIndex].type; + auto subfieldID = fieldID - fieldIDs[subfieldIndex]; + return {subfieldType, subfieldID}; +} + +Type StructType::getTypeAtIndex(Attribute index) const { + auto indexAttr = llvm::dyn_cast(index); + if (!indexAttr) return {}; + + return getSubTypeByFieldID(indexAttr.getInt()).first; +} + +unsigned StructType::getFieldID(unsigned index) const { + auto [maxId, fieldIDs] = getFieldIDsStruct(*this); + return fieldIDs[index]; +} + +unsigned StructType::getMaxFieldID() const { + unsigned fieldID = 0; + for (const auto &field : getElements()) fieldID += 1 + FieldIdImpl::getMaxFieldID(field.type); + return fieldID; +} + +unsigned StructType::getIndexForFieldID(unsigned fieldID) const { + assert(!getElements().empty() && "struct must have >0 fields"); + auto [maxId, fieldIDs] = getFieldIDsStruct(*this); + auto *it = std::prev(llvm::upper_bound(fieldIDs, fieldID)); + return std::distance(fieldIDs.begin(), it); +} + +std::pair StructType::getIndexAndSubfieldID(unsigned fieldID) const { + auto index = getIndexForFieldID(fieldID); + auto elementFieldID = getFieldID(index); + return {index, fieldID - elementFieldID}; +} + +std::optional> StructType::getSubelementIndexMap() const { + DenseMap destructured; + for (auto [i, field] : llvm::enumerate(getElements())) + destructured.try_emplace(IntegerAttr::get(IndexType::get(getContext()), i), field.type); + return destructured; +} + +std::pair StructType::projectToChildFieldID(unsigned fieldID, + unsigned index) const { + auto [maxId, fieldIDs] = getFieldIDsStruct(*this); + auto childRoot = fieldIDs[index]; + auto rangeEnd = index + 1 >= getElements().size() ? maxId : (fieldIDs[index + 1] - 1); + return std::make_pair(fieldID - childRoot, fieldID >= childRoot && fieldID <= rangeEnd); +} + +void StructType::getInnerTypes(SmallVectorImpl &types) { + for (const auto &field : getElements()) types.push_back(field.type); +} + void P4HIRDialect::registerTypes() { addTypes< #define GET_TYPEDEF_LIST diff --git a/test/Dialect/P4HIR/struct.mlir b/test/Dialect/P4HIR/struct.mlir new file mode 100644 index 0000000..a209e60 --- /dev/null +++ b/test/Dialect/P4HIR/struct.mlir @@ -0,0 +1,37 @@ +// RUN: p4mlir-opt %s | FileCheck %s + +!i32i = !p4hir.int<32> +!T = !p4hir.struct<"T", t1: !i32i, t2: !i32i> +!S = !p4hir.struct<"S", s1: !T, s2: !T> +!Empty = !p4hir.struct<"Empty"> +!b9i = !p4hir.bit<9> +!PortId_t = !p4hir.struct<"PortId_t", _v: !b9i> + +#int10_i32i = #p4hir.int<10> : !i32i +#int20_i32i = #p4hir.int<20> : !i32i +#int1_b9i = #p4hir.int<1> : !b9i + +// CHECK: module +module { + %e = p4hir.const ["e"] #p4hir.aggregate<[]> : !Empty + %t = p4hir.const ["t"] #p4hir.aggregate<[#int10_i32i, #int20_i32i]> : !T + + p4hir.func action @test2(%arg0: !p4hir.ref {p4hir.dir = #p4hir}) { + %_v = p4hir.struct_extract_ref %arg0["_v"] : + %val = p4hir.read %arg0 : + %_v_0 = p4hir.struct_extract %val["_v"] : !PortId_t + %c1_b9i = p4hir.const #int1_b9i + %add = p4hir.binop(add, %_v_0, %c1_b9i) : !b9i + p4hir.assign %add, %_v : + p4hir.return + } + + p4hir.func action @test() { + %vv = p4hir.variable ["vv"] : + %val = p4hir.read %vv : + %0 = p4hir.struct (%val) : !PortId_t + %p1 = p4hir.variable ["p1", init] : + p4hir.assign %0, %p1 : + p4hir.return + } +} diff --git a/test/Dialect/P4HIR/types.mlir b/test/Dialect/P4HIR/types.mlir index b71d2ed..ac45688 100644 --- a/test/Dialect/P4HIR/types.mlir +++ b/test/Dialect/P4HIR/types.mlir @@ -3,11 +3,16 @@ !unknown = !p4hir.unknown !error = !p4hir.error !dontcare = !p4hir.dontcare +!bit42 = !p4hir.bit<42> !ref = !p4hir.ref> +!void = !p4hir.void !action_noparams = !p4hir.func<()> !action_params = !p4hir.func<(!p4hir.int<42>, !ref, !p4hir.int<42>, !p4hir.bool)> +!struct = !p4hir.struct<"struct_name", boolfield : !p4hir.bool, bitfield : !bit42> +!nested_struct = !p4hir.struct<"another_name", neststructfield : !struct, bitfield : !bit42> + // No need to check stuff. If it parses, it's fine. // CHECK: module module { diff --git a/test/Translate/Ops/constants.p4 b/test/Translate/Ops/constants.p4 index 5c5b17e..366339e 100644 --- a/test/Translate/Ops/constants.p4 +++ b/test/Translate/Ops/constants.p4 @@ -20,8 +20,6 @@ const int<32> btwotwo = (int<32>)btwo; const bit<32> btwothree = (bit<32>)btwotwo; const bit<6> btwofour = (bit<6>)(bit<32>)(int<32>)btwo; -// TODO: Support constant structs -/* struct S { bit<32> a; bit<32> b; @@ -39,7 +37,6 @@ const T zz = { { 32w0, 32w1 }, { 32w2, 32w3 } }; -*/ const bit<32> x = 32w0; const bit<32> x1 = ~32w0; @@ -122,9 +119,10 @@ const int<1> szz3 = (int<1>) szz2[0:0]; // CHECK: #[[$ATTR_23:.+]] = #p4hir.int<2> : !b6i // CHECK: #[[$ATTR_24:.+]] = #p4hir.int<2> : !i32i // CHECK: #[[$ATTR_25:.+]] = #p4hir.int<2> : !i8i -// CHECK: #[[$ATTR_26:.+]] = #p4hir.int<5> : !b4i -// CHECK: #[[$ATTR_27:.+]] = #p4hir.int<5> : !b7i -// CHECK: #[[$ATTR_28:.+]] = #p4hir.int<5> : !infint +// CHECK: #[[$ATTR_26:.+]] = #p4hir.int<3> : !b32i +// CHECK: #[[$ATTR_27:.+]] = #p4hir.int<5> : !b4i +// CHECK: #[[$ATTR_28:.+]] = #p4hir.int<5> : !b7i +// CHECK: #[[$ATTR_29:.+]] = #p4hir.int<5> : !infint // CHECK-LABEL: module // CHECK-NEXT: %[[VAL_0:.*]] = p4hir.const ["bzero"] #[[$ATTR_14]] @@ -139,48 +137,51 @@ const int<1> szz3 = (int<1>) szz2[0:0]; // CHECK: %[[VAL_9:.*]] = p4hir.const ["btwotwo"] #[[$ATTR_24]] // CHECK: %[[VAL_10:.*]] = p4hir.const ["btwothree"] #[[$ATTR_22]] // CHECK: %[[VAL_11:.*]] = p4hir.const ["btwofour"] #[[$ATTR_23]] -// CHECK: %[[VAL_12:.*]] = p4hir.const ["x"] #[[$ATTR_14]] -// CHECK: %[[VAL_13:.*]] = p4hir.const ["x1"] #[[$ATTR_4]] -// CHECK: %[[VAL_14:.*]] = p4hir.const ["izero"] #[[$ATTR_17]] -// CHECK: %[[VAL_15:.*]] = p4hir.const ["fa"] #[[$ATTR_26]] -// CHECK: %[[VAL_16:.*]] = p4hir.const ["fb"] #[[$ATTR_28]] -// CHECK: %[[VAL_17:.*]] = p4hir.const ["fc"] #[[$ATTR_27]] -// CHECK: %[[VAL_18:.*]] = p4hir.const ["fd"] #[[$ATTR_7]] -// CHECK: %[[VAL_19:.*]] = p4hir.const ["fe"] #[[$ATTR_10]] -// CHECK: %[[VAL_20:.*]] = p4hir.const ["ff"] #[[$ATTR_5]] -// CHECK: %[[VAL_21:.*]] = p4hir.const ["fg"] #[[$ATTR_8]] -// CHECK: %[[VAL_22:.*]] = p4hir.const ["fh"] #[[$ATTR_10]] -// CHECK: %[[VAL_23:.*]] = p4hir.const ["sa"] #[[$ATTR_18]] -// CHECK: %[[VAL_24:.*]] = p4hir.const ["sb"] #[[$ATTR_9]] -// CHECK: %[[VAL_25:.*]] = p4hir.const ["sc"] #[[$ATTR_12]] -// CHECK: %[[VAL_26:.*]] = p4hir.const ["sd"] #[[$ATTR_2]] -// CHECK: %[[VAL_27:.*]] = p4hir.const ["se"] #[[$ATTR_3]] -// CHECK: %[[VAL_28:.*]] = p4hir.const ["sf"] #[[$ATTR_19]] -// CHECK: %[[VAL_29:.*]] = p4hir.const ["sg"] #[[$ATTR_21]] -// CHECK: %[[VAL_30:.*]] = p4hir.const ["sh"] #[[$ATTR_18]] -// CHECK: %[[VAL_31:.*]] = p4hir.const ["si"] #[[$ATTR_21]] -// CHECK: %[[VAL_32:.*]] = p4hir.const ["sj"] #[[$ATTR_25]] -// CHECK: %[[VAL_33:.*]] = p4hir.const ["sk"] #[[$ATTR_19]] -// CHECK: %[[VAL_34:.*]] = p4hir.const ["sl"] #[[$ATTR_3]] -// CHECK: %[[VAL_35:.*]] = p4hir.const ["sm"] #[[$ATTR_2]] -// CHECK: %[[VAL_36:.*]] = p4hir.const ["sn"] #[[$ATTR_9]] -// CHECK: %[[VAL_37:.*]] = p4hir.const ["so"] #[[$ATTR_18]] -// CHECK: %[[VAL_38:.*]] = p4hir.const ["sa0"] #[[$ATTR_18]] -// CHECK: %[[VAL_39:.*]] = p4hir.const ["sb0"] #[[$ATTR_9]] -// CHECK: %[[VAL_40:.*]] = p4hir.const ["sc0"] #[[$ATTR_12]] -// CHECK: %[[VAL_41:.*]] = p4hir.const ["sd0"] #[[$ATTR_2]] -// CHECK: %[[VAL_42:.*]] = p4hir.const ["se0"] #[[$ATTR_3]] -// CHECK: %[[VAL_43:.*]] = p4hir.const ["sf0"] #[[$ATTR_19]] -// CHECK: %[[VAL_44:.*]] = p4hir.const ["sg0"] #[[$ATTR_21]] -// CHECK: %[[VAL_45:.*]] = p4hir.const ["sh0"] #[[$ATTR_18]] -// CHECK: %[[VAL_46:.*]] = p4hir.const ["si0"] #[[$ATTR_21]] -// CHECK: %[[VAL_47:.*]] = p4hir.const ["sj0"] #[[$ATTR_25]] -// CHECK: %[[VAL_48:.*]] = p4hir.const ["sk0"] #[[$ATTR_19]] -// CHECK: %[[VAL_49:.*]] = p4hir.const ["sl0"] #[[$ATTR_3]] -// CHECK: %[[VAL_50:.*]] = p4hir.const ["sm0"] #[[$ATTR_2]] -// CHECK: %[[VAL_51:.*]] = p4hir.const ["sn0"] #[[$ATTR_9]] -// CHECK: %[[VAL_52:.*]] = p4hir.const ["so0"] #[[$ATTR_18]] -// CHECK: %[[VAL_53:.*]] = p4hir.const ["szz0"] #[[$ATTR_16]] -// CHECK: %[[VAL_54:.*]] = p4hir.const ["szz1"] #[[$ATTR_6]] -// CHECK: %[[VAL_55:.*]] = p4hir.const ["szz2"] #[[$ATTR_11]] -// CHECK: %[[VAL_56:.*]] = p4hir.const ["szz3"] #[[$ATTR_16]] +// CHECK: %[[VAL_12:.*]] = p4hir.const ["v"] #p4hir.aggregate<[#[[$ATTR_26]], #[[$ATTR_20]]]> : !S +// CHECK: %[[VAL_13:.*]] = p4hir.const ["zz"] #p4hir.aggregate<[#p4hir.aggregate<[#[[$ATTR_14]], #[[$ATTR_20]]]> : !S, #p4hir.aggregate<[#[[$ATTR_22]], #[[$ATTR_26]]]> : !S]> : !T +// CHECK: %[[VAL_14:.*]] = p4hir.const ["x"] #[[$ATTR_14]] +// CHECK: %[[VAL_15:.*]] = p4hir.const ["x1"] #[[$ATTR_4]] +// CHECK: %[[VAL_16:.*]] = p4hir.const ["izero"] #[[$ATTR_17]] +// CHECK: %[[VAL_17:.*]] = p4hir.const ["fa"] #[[$ATTR_27]] +// CHECK: %[[VAL_18:.*]] = p4hir.const ["fb"] #[[$ATTR_29]] +// CHECK: %[[VAL_19:.*]] = p4hir.const ["fc"] #[[$ATTR_28]] +// CHECK: %[[VAL_20:.*]] = p4hir.const ["fd"] #[[$ATTR_7]] +// CHECK: %[[VAL_21:.*]] = p4hir.const ["fe"] #[[$ATTR_10]] +// CHECK: %[[VAL_22:.*]] = p4hir.const ["ff"] #[[$ATTR_5]] +// CHECK: %[[VAL_23:.*]] = p4hir.const ["fg"] #[[$ATTR_8]] +// CHECK: %[[VAL_24:.*]] = p4hir.const ["fh"] #[[$ATTR_10]] +// CHECK: %[[VAL_25:.*]] = p4hir.const ["sa"] #[[$ATTR_18]] +// CHECK: %[[VAL_26:.*]] = p4hir.const ["sb"] #[[$ATTR_9]] +// CHECK: %[[VAL_27:.*]] = p4hir.const ["sc"] #[[$ATTR_12]] +// CHECK: %[[VAL_28:.*]] = p4hir.const ["sd"] #[[$ATTR_2]] +// CHECK: %[[VAL_29:.*]] = p4hir.const ["se"] #[[$ATTR_3]] +// CHECK: %[[VAL_30:.*]] = p4hir.const ["sf"] #[[$ATTR_19]] +// CHECK: %[[VAL_31:.*]] = p4hir.const ["sg"] #[[$ATTR_21]] +// CHECK: %[[VAL_32:.*]] = p4hir.const ["sh"] #[[$ATTR_18]] +// CHECK: %[[VAL_33:.*]] = p4hir.const ["si"] #[[$ATTR_21]] +// CHECK: %[[VAL_34:.*]] = p4hir.const ["sj"] #[[$ATTR_25]] +// CHECK: %[[VAL_35:.*]] = p4hir.const ["sk"] #[[$ATTR_19]] +// CHECK: %[[VAL_36:.*]] = p4hir.const ["sl"] #[[$ATTR_3]] +// CHECK: %[[VAL_37:.*]] = p4hir.const ["sm"] #[[$ATTR_2]] +// CHECK: %[[VAL_38:.*]] = p4hir.const ["sn"] #[[$ATTR_9]] +// CHECK: %[[VAL_39:.*]] = p4hir.const ["so"] #[[$ATTR_18]] +// CHECK: %[[VAL_40:.*]] = p4hir.const ["sa0"] #[[$ATTR_18]] +// CHECK: %[[VAL_41:.*]] = p4hir.const ["sb0"] #[[$ATTR_9]] +// CHECK: %[[VAL_42:.*]] = p4hir.const ["sc0"] #[[$ATTR_12]] +// CHECK: %[[VAL_43:.*]] = p4hir.const ["sd0"] #[[$ATTR_2]] +// CHECK: %[[VAL_44:.*]] = p4hir.const ["se0"] #[[$ATTR_3]] +// CHECK: %[[VAL_45:.*]] = p4hir.const ["sf0"] #[[$ATTR_19]] +// CHECK: %[[VAL_46:.*]] = p4hir.const ["sg0"] #[[$ATTR_21]] +// CHECK: %[[VAL_47:.*]] = p4hir.const ["sh0"] #[[$ATTR_18]] +// CHECK: %[[VAL_48:.*]] = p4hir.const ["si0"] #[[$ATTR_21]] +// CHECK: %[[VAL_49:.*]] = p4hir.const ["sj0"] #[[$ATTR_25]] +// CHECK: %[[VAL_50:.*]] = p4hir.const ["sk0"] #[[$ATTR_19]] +// CHECK: %[[VAL_51:.*]] = p4hir.const ["sl0"] #[[$ATTR_3]] +// CHECK: %[[VAL_52:.*]] = p4hir.const ["sm0"] #[[$ATTR_2]] +// CHECK: %[[VAL_53:.*]] = p4hir.const ["sn0"] #[[$ATTR_9]] +// CHECK: %[[VAL_54:.*]] = p4hir.const ["so0"] #[[$ATTR_18]] +// CHECK: %[[VAL_55:.*]] = p4hir.const ["szz0"] #[[$ATTR_16]] +// CHECK: %[[VAL_56:.*]] = p4hir.const ["szz1"] #[[$ATTR_6]] +// CHECK: %[[VAL_57:.*]] = p4hir.const ["szz2"] #[[$ATTR_11]] +// CHECK: %[[VAL_58:.*]] = p4hir.const ["szz3"] #[[$ATTR_16]] + diff --git a/test/Translate/Ops/struct.p4 b/test/Translate/Ops/struct.p4 new file mode 100644 index 0000000..d7a9893 --- /dev/null +++ b/test/Translate/Ops/struct.p4 @@ -0,0 +1,111 @@ +// RUN: p4mlir-translate --typeinference-only %s | FileCheck %s + +struct P { + bit<32> f1; + bit<32> f2; +} + +struct T { + int<32> t1; + int<32> t2; +} + +struct S { + T s1; + T s2; +} + +struct Empty {}; + +// CHECK: !Empty = !p4hir.struct<"Empty"> +// CHECK: !PortId_t = !p4hir.struct<"PortId_t", _v: !b9i> +// CHECK: !T = !p4hir.struct<"T", t1: !i32i, t2: !i32i> +// CHECK: !S = !p4hir.struct<"S", s1: !T, s2: !T> +// CHECK: !metadata_t = !p4hir.struct<"metadata_t", foo: !PortId_t> + +// CHECK-LABEL: module + +const T t = { 32s10, 32s20 }; +const S s = { { 32s15, 32s25}, t }; + +// CHECK: %t = p4hir.const ["t"] #p4hir.aggregate<[#int10_i32i, #int20_i32i]> : !T +// CHECK: %s = p4hir.const ["s"] #p4hir.aggregate<[#p4hir.aggregate<[#int15_i32i, #int25_i32i]> : !T, #p4hir.aggregate<[#int10_i32i, #int20_i32i]> : !T]> : !S + +const int<32> x = t.t1; +const int<32> y = s.s1.t2; + +const int<32> w = .t.t1; + +// CHECK: %x = p4hir.const ["x"] #int10_i32i +// CHECK: %y = p4hir.const ["y"] #int25_i32i +// CHECK: %w = p4hir.const ["w"] #int10_i32i + +const T tt1 = s.s1; +const Empty e = {}; + +// CHECK: %tt1 = p4hir.const ["tt1"] #p4hir.aggregate<[#int15_i32i, #int25_i32i]> : !T +// CHECK: %e = p4hir.const ["e"] #p4hir.aggregate<[]> : !Empty + +const T t1 = { 10, 20 }; +const S s1 = { { 15, 25 }, t1 }; + +const int<32> x1 = t1.t1; +const int<32> y1 = s1.s1.t2; + +const int<32> w1 = .t1.t1; + +const T t2 = s1.s1; + +struct PortId_t { bit<9> _v; } + +const PortId_t PSA_CPU_PORT = { _v = 9w192 }; + +struct metadata_t { + PortId_t foo; +} + +action test2(inout PortId_t port) { + port._v = port._v + 1; +} + +// CHECK-LABEL: p4hir.func action @test2(%arg0: !p4hir.ref {p4hir.dir = #p4hir}) { +// CHECK: %[[_V_REF:.*]] = p4hir.struct_extract_ref %arg0["_v"] : +// CHECK: %[[VAL:.*]] = p4hir.read %arg0 : +// CHECK: %[[_V_VAL:.*]] = p4hir.struct_extract %[[VAL]]["_v"] : !PortId_t +// CHECK: p4hir.assign %{{.*}}, %[[_V_REF]] +// CHECK: p4hir.return + +// CHECK-LABEL: p4hir.func action @test(%arg0: !p4hir.ref {p4hir.dir = #p4hir}) { +// Just few important bits here +action test(inout metadata_t meta) { + bit<9> vv; + + PortId_t p1 = { _v = vv }; + + // CHECK: %[[VV_VAR:.*]] = p4hir.variable ["vv"] : + // CHECK: %[[VV_VAL:.*]] = p4hir.read %[[VV_VAR]] : + // CHECK: %[[STRUCT:.*]] = p4hir.struct (%[[VV_VAL]]) : !PortId_t + // CHECK: %[[P_VAR:.*]] = p4hir.variable ["p1", init] : + // CHECK: p4hir.assign %[[STRUCT]], %[[P_VAR]] : + + PortId_t p; + bit<9> v; + v = p._v; + + v = meta.foo._v; + + meta.foo._v = 1; + + // CHECK: p4hir.scope { + // CHECK: p4hir.call @test2 + test2(meta.foo); + // CHECK: } + + // CHECK: %[[METADATA_VAL:.*]] = p4hir.read %arg0 : + // CHECK: %[[FOO:.*]] = p4hir.struct_extract %[[METADATA_VAL]]["foo"] : !metadata_t + // CHECK: %[[PSA_CPU_PORT:.*]] = p4hir.const ["PSA_CPU_PORT"] #p4hir.aggregate<[#int192_b9i]> : !PortId_t + // CHECK: %eq = p4hir.cmp(eq, %[[FOO]], %[[PSA_CPU_PORT]]) : !PortId_t, !p4hir.bool + if (meta.foo == PSA_CPU_PORT) { + meta.foo._v = meta.foo._v + 1; + } +} diff --git a/tools/p4mlir-translate/translate.cpp b/tools/p4mlir-translate/translate.cpp index e4452ed..efcd237 100644 --- a/tools/p4mlir-translate/translate.cpp +++ b/tools/p4mlir-translate/translate.cpp @@ -89,7 +89,8 @@ class ConversionTracer { public: ConversionTracer(const char *Kind, const P4::IR::Node *node) { // TODO: Add TimeTrace here - LOG4(P4::IndentCtl::indent << Kind << dbp(node)); + LOG4(P4::IndentCtl::indent << Kind << dbp(node) << (LOGGING(5) ? ":" : "")); + LOG5(node); } ~ConversionTracer() { LOG4_UNINDENT; } }; @@ -131,6 +132,7 @@ class P4TypeConverter : public P4::Inspector { bool preorder(const P4::IR::Type_Action *act) override; bool preorder(const P4::IR::Type_Method *m) override; bool preorder(const P4::IR::Type_Void *v) override; + bool preorder(const P4::IR::Type_Struct *v) override; mlir::Type getType() const { return type; } bool setType(const P4::IR::Type *type, mlir::Type mlirType); @@ -158,8 +160,7 @@ class P4HIRConverter : public P4::Inspector, public P4::ResolutionContext { llvm::DenseMap p4Symbols; mlir::TypedAttr resolveConstant(const P4::IR::CompileTimeValue *ctv); - mlir::TypedAttr resolveConstantExpr(const P4::IR::Expression *expr); - mlir::Value resolveReference(const P4::IR::Node *node); + mlir::Value resolveReference(const P4::IR::Node *node, bool unchecked = true); mlir::Value getBoolConstant(mlir::Location loc, bool value) { auto boolType = P4HIR::BoolType::get(context()); @@ -226,15 +227,7 @@ class P4HIRConverter : public P4::Inspector, public P4::ResolutionContext { } */ - mlir::TypedAttr getOrCreateConstantExpr(const P4::IR::Expression *expr) { - auto cst = p4Constants.lookup(expr); - if (cst) return cst; - - cst = resolveConstantExpr(expr); - - BUG_CHECK(cst, "expected %1% to be converted as constant", expr); - return cst; - } + mlir::TypedAttr getOrCreateConstantExpr(const P4::IR::Expression *expr); mlir::Value getValue(const P4::IR::Node *node) { // If this is a PathExpression, resolve it @@ -245,6 +238,15 @@ class P4HIRConverter : public P4::Inspector, public P4::ResolutionContext { auto val = p4Values.lookup(node); BUG_CHECK(val, "expected %1% (aka %2%) to be converted", node, dbp(node)); + // See, if node is a top-level constant. If yes, then clone the value + // into the present scope as other top-level things are + // IsolatedFromAbove. + // TODO: Save new constant value into scoped value tables when we will have one + if (auto constOp = val.getDefiningOp(); + constOp && mlir::isa_and_nonnull(constOp->getParentOp())) { + val = builder.clone(*constOp)->getResult(0); + } + if (mlir::isa(val.getType())) // Getting value out of variable involves a load. return builder.create(getLoc(builder, node), val); @@ -348,6 +350,7 @@ class P4HIRConverter : public P4::Inspector, public P4::ResolutionContext { HANDLE_IN_POSTORDER(Cast) HANDLE_IN_POSTORDER(Declaration_Variable) HANDLE_IN_POSTORDER(ReturnStatement) + HANDLE_IN_POSTORDER(Member) #undef HANDLE_IN_POSTORDER @@ -363,6 +366,7 @@ class P4HIRConverter : public P4::Inspector, public P4::ResolutionContext { } bool preorder(const P4::IR::MethodCallExpression *mce) override; + bool preorder(const P4::IR::StructExpression *str) override; mlir::Value emitUnOp(const P4::IR::Operation_Unary *unop, P4HIR::UnaryOpKind kind); mlir::Value emitBinOp(const P4::IR::Operation_Binary *binop, P4HIR::BinOpKind kind); @@ -404,7 +408,7 @@ bool P4TypeConverter::preorder(const P4::IR::Type_Unknown *type) { bool P4TypeConverter::preorder(const P4::IR::Type_Name *name) { if ((this->type = converter.findType(name))) return false; - ConversionTracer trace("TypeConverting ", name); + ConversionTracer trace("Resolving type by name ", name); const auto *type = converter.resolveType(name); CHECK_NULL(type); mlir::Type mlirType = convert(type); @@ -457,6 +461,20 @@ bool P4TypeConverter::preorder(const P4::IR::Type_Void *type) { return setType(type, mlirType); } +bool P4TypeConverter::preorder(const P4::IR::Type_Struct *type) { + if ((this->type = converter.findType(type))) return false; + + ConversionTracer trace("TypeConverting ", type); + llvm::SmallVector fields; + for (const auto *field : type->fields) { + fields.push_back({mlir::StringAttr::get(converter.context(), field->name.string_view()), + convert(field->type)}); + } + + auto mlirType = P4HIR::StructType::get(converter.context(), type->name.string_view(), fields); + return setType(type, mlirType); +} + bool P4TypeConverter::setType(const P4::IR::Type *type, mlir::Type mlirType) { this->type = mlirType; converter.setType(type, mlirType); @@ -470,16 +488,35 @@ mlir::Type P4TypeConverter::convert(const P4::IR::Type *type) { return getType(); } -mlir::Value P4HIRConverter::resolveReference(const P4::IR::Node *node) { - // If this is a PathExpression, resolve it +// Resolve an l-value-kind expression, building access operation for each "layer". +mlir::Value P4HIRConverter::resolveReference(const P4::IR::Node *node, bool unchecked) { + auto ref = p4Values.lookup(node); + if (ref) return ref; + + ConversionTracer trace("Resolving reference ", node); + // Check if this is a reference to a member of something we can recognize + if (const auto *m = node->to()) { + auto base = resolveReference(m->expr); + auto field = builder.create(getLoc(builder, m), base, + m->member.string_view()); + return setValue(m, field.getResult()); + } + + // If this is a PathExpression, resolve it to the actual declaration, usualy this + // is a "leaf" case. if (const auto *pe = node->to()) { node = resolvePath(pe->path, false)->checkedTo(); } - // The result is expected to be an l-value - auto ref = p4Values.lookup(node); + ref = p4Values.lookup(node); + if (!ref) { + visit(node); + ref = p4Values.lookup(node); + } + BUG_CHECK(ref, "expected %1% (aka %2%) to be converted", node, dbp(node)); - BUG_CHECK(mlir::isa(ref.getType()), + // The result is expected to be an l-value + BUG_CHECK(unchecked || mlir::isa(ref.getType()), "expected reference type for node %1%", node); return ref; @@ -489,8 +526,17 @@ mlir::TypedAttr P4HIRConverter::resolveConstant(const P4::IR::CompileTimeValue * BUG("cannot resolve this constant yet %1%", ctv); } -mlir::TypedAttr P4HIRConverter::resolveConstantExpr(const P4::IR::Expression *expr) { - LOG4("Resolving " << dbp(expr) << " as constant expression"); +mlir::TypedAttr P4HIRConverter::getOrCreateConstantExpr(const P4::IR::Expression *expr) { + if (auto cst = p4Constants.lookup(expr)) return cst; + + ConversionTracer trace("Resolving constant expression ", expr); + + // If this is a PathExpression, resolve it to the actual constant + // declaration initializer, usualy this is a "leaf" case. + if (const auto *pe = expr->to()) { + auto *cst = resolvePath(pe->path, false)->checkedTo(); + return getOrCreateConstantExpr(cst->initializer); + } if (const auto *cst = expr->to()) { auto type = getOrCreateType(cst->type); @@ -527,8 +573,34 @@ mlir::TypedAttr P4HIRConverter::resolveConstantExpr(const P4::IR::Expression *ex } } } + if (const auto *str = expr->to()) { + auto type = getOrCreateType(str->type); + llvm::SmallVector fields; + for (const auto *field : str->components) + fields.push_back(getOrCreateConstantExpr(field->expression)); + return setConstantExpr(expr, P4HIR::AggAttr::get(type, builder.getArrayAttr(fields))); + } + if (const auto *m = expr->to()) { + auto base = mlir::cast(getOrCreateConstantExpr(m->expr)); + auto structType = mlir::cast(base.getType()); + + if (auto maybeIdx = structType.getFieldIndex(m->member.string_view())) { + auto field = base.getFields()[*maybeIdx]; + auto fieldType = structType.getFieldType(m->member.string_view()); + + // TODO: We'd likely would want to convert this to some kind of interface, + if (mlir::isa(fieldType)) + return setConstantExpr(expr, mlir::cast(field)); + + if (mlir::isa(fieldType)) + return setConstantExpr(expr, mlir::cast(field)); + + return setConstantExpr(expr, mlir::cast(field)); + } else + BUG("invalid member reference %1%", m); + } - BUG("cannot resolve this constant expression yet %1%", expr); + BUG("cannot resolve this constant expression yet %1% (aka %2%)", expr, dbp(expr)); } mlir::Value P4HIRConverter::materializeConstantExpr(const P4::IR::Expression *expr) { @@ -559,8 +631,8 @@ void P4HIRConverter::postorder(const P4::IR::Declaration_Variable *decl) { auto type = getOrCreateType(decl); // TODO: Choose better insertion point for alloca (entry BB or so) - auto var = builder.create( - getLoc(builder, decl), type, mlir::StringAttr::get(context(), decl->name.string_view())); + auto var = builder.create(getLoc(builder, decl), type, + builder.getStringAttr(decl->name.string_view())); if (const auto *init = decl->initializer) { var.setInit(true); @@ -645,9 +717,8 @@ bool P4HIRConverter::preorder(const P4::IR::AssignmentStatement *assign) { ConversionTracer trace("Converting ", assign); // TODO: Handle slice on LHS here - visit(assign->left); - visit(assign->right); auto ref = resolveReference(assign->left); + visit(assign->right); builder.create(getLoc(builder, assign), getValue(assign->right), ref); return false; } @@ -736,7 +807,7 @@ bool P4HIRConverter::preorder(const P4::IR::IfStatement *ifs) { } static llvm::SmallVector convertParamDirections( - const P4::IR::ParameterList *params, mlir::MLIRContext *ctxt) { + const P4::IR::ParameterList *params, mlir::OpBuilder &b) { // Create attributes for directions llvm::SmallVector argAttrs; for (const auto *p : params->parameters) { @@ -756,11 +827,9 @@ static llvm::SmallVector convertParamDirections( break; }; - mlir::NamedAttribute dirAttr( - mlir::StringAttr::get(ctxt, P4HIR::FuncOp::getDirectionAttrName()), - P4HIR::ParamDirectionAttr::get(ctxt, dir)); - - argAttrs.emplace_back(mlir::DictionaryAttr::get(ctxt, dirAttr)); + argAttrs.emplace_back(b.getDictionaryAttr( + b.getNamedAttr(P4HIR::FuncOp::getDirectionAttrName(), + P4HIR::ParamDirectionAttr::get(b.getContext(), dir)))); } return argAttrs; @@ -772,7 +841,7 @@ bool P4HIRConverter::preorder(const P4::IR::Function *f) { auto funcType = mlir::cast(getOrCreateType(f->type)); const auto ¶ms = f->getParameters()->parameters; - auto argAttrs = convertParamDirections(f->getParameters(), context()); + auto argAttrs = convertParamDirections(f->getParameters(), builder); assert(funcType.getNumInputs() == argAttrs.size() && "invalid parameter conversion"); auto func = builder.create(getLoc(builder, f), f->name.string_view(), funcType, @@ -813,7 +882,7 @@ bool P4HIRConverter::preorder(const P4::IR::Method *m) { auto funcType = mlir::cast(getOrCreateType(m->type)); - auto argAttrs = convertParamDirections(m->getParameters(), context()); + auto argAttrs = convertParamDirections(m->getParameters(), builder); assert(funcType.getNumInputs() == argAttrs.size() && "invalid parameter conversion"); auto func = builder.create(getLoc(builder, m), m->name.string_view(), funcType, @@ -835,7 +904,7 @@ bool P4HIRConverter::preorder(const P4::IR::P4Action *act) { auto actType = mlir::cast(getOrCreateType(typeMap->getType(act, true))); const auto ¶ms = act->getParameters()->parameters; - auto argAttrs = convertParamDirections(act->getParameters(), context()); + auto argAttrs = convertParamDirections(act->getParameters(), builder); assert(actType.getNumInputs() == argAttrs.size() && "invalid parameter conversion"); auto action = @@ -901,28 +970,26 @@ bool P4HIRConverter::preorder(const P4::IR::MethodCallExpression *mce) { llvm::SmallVector operands; for (auto [idx, arg] : llvm::enumerate(*mce->arguments)) { ConversionTracer trace("Converting ", arg); - visit(arg->expression); mlir::Value argVal; switch (auto dir = params[idx]->direction) { case P4::IR::Direction::None: case P4::IR::Direction::In: // Nothing to do special, just pass things direct + visit(arg->expression); argVal = getValue(arg->expression); break; case P4::IR::Direction::Out: case P4::IR::Direction::InOut: { - // Just create temporary to hold the output value, initialize in case of inout - auto ref = P4HIR::ReferenceType::get(getOrCreateType(arg->expression)); + // Create temporary to hold the output value, initialize in case of inout + auto ref = resolveReference(arg->expression); auto copyIn = b.create( - loc, ref, - mlir::StringAttr::get( - context(), - llvm::Twine(params[idx]->name.string_view()) + - (dir == P4::IR::Direction::InOut ? "_inout_arg" : "_out_arg"))); + loc, ref.getType(), + b.getStringAttr(llvm::Twine(params[idx]->name.string_view()) + + (dir == P4::IR::Direction::InOut ? "_inout_arg" : "_out_arg"))); if (dir == P4::IR::Direction::InOut) { copyIn.setInit(true); - b.create(loc, getValue(arg->expression), copyIn); + b.create(loc, b.create(loc, ref), copyIn); } argVal = copyIn; break; @@ -988,6 +1055,35 @@ bool P4HIRConverter::preorder(const P4::IR::MethodCallExpression *mce) { return false; } +void P4HIRConverter::postorder(const P4::IR::Member *m) { + // Resolve member rvalue expression to something we can reason about + // TODO: Likely we can do similar things for the majority of struct-like + // types + auto parentType = getOrCreateType(m->expr); + if (auto structType = mlir::dyn_cast(parentType)) { + // We can access to parent using struct operations + auto parent = getValue(m->expr); + auto field = builder.create(getLoc(builder, m), parent, + m->member.string_view()); + setValue(m, field.getResult()); + } else { + BUG("cannot convert this member reference %1% (aka %2%) yet", m, dbp(m)); + } +} + +bool P4HIRConverter::preorder(const P4::IR::StructExpression *str) { + auto type = getOrCreateType(str->structType); + llvm::SmallVector fields; + + for (const auto *field : str->components) { + visit(field->expression); + fields.push_back(getValue(field->expression)); + } + + setValue(str, builder.create(getLoc(builder, str), type, fields).getResult()); + + return false; +} } // namespace