diff --git a/lib/SDFG/Dialect/Ops.cpp b/lib/SDFG/Dialect/Ops.cpp index bc796488..d18caa8d 100644 --- a/lib/SDFG/Dialect/Ops.cpp +++ b/lib/SDFG/Dialect/Ops.cpp @@ -1014,6 +1014,41 @@ LogicalResult EdgeOp::verifySymbolUses(SymbolTableCollection &symbolTable) { // AllocOp //===----------------------------------------------------------------------===// +/// Builds, creates and inserts an allocation operation using the provided +/// PatternRewriter. +AllocOp AllocOp::create(PatternRewriter &rewriter, Location loc, Type res, + StringRef name, bool transient) { + OpBuilder builder(loc->getContext()); + OperationState state(loc, getOperationName()); + StringAttr nameAttr = rewriter.getStringAttr(utils::generateName(name.str())); + build(builder, state, res, {}, nameAttr, transient); + return cast(rewriter.create(state)); +} + +/// Builds, creates and inserts an allocation operation using the provided +/// PatternRewriter. +AllocOp AllocOp::create(PatternRewriter &rewriter, Location loc, Type res, + bool transient) { + return create(rewriter, loc, res, "arr", transient); +} + +/// Builds, creates and inserts an allocation operation using Operation::create. +AllocOp AllocOp::create(Location loc, Type res, StringRef name, + bool transient) { + OpBuilder builder(loc->getContext()); + OperationState state(loc, getOperationName()); + StringAttr nameAttr = builder.getStringAttr(name); + + if (!res.isa()) { + SizedType sized = SizedType::get(res.getContext(), res, {}, {}, {}); + res = ArrayType::get(res.getContext(), sized); + } + + build(builder, state, res, {}, nameAttr, transient); + return cast(Operation::create(state)); +} + +/// Attempts to parse an allocation operation. ParseResult AllocOp::parse(OpAsmParser &parser, OperationState &result) { if (parser.parseOptionalAttrDict(result.attributes)) return failure(); @@ -1034,6 +1069,7 @@ ParseResult AllocOp::parse(OpAsmParser &parser, OperationState &result) { return success(); } +/// Prints an allocation operation in human-readable form. void AllocOp::print(OpAsmPrinter &p) { p.printOptionalAttrDict((*this)->getAttrs()); p << " ("; @@ -1042,6 +1078,7 @@ void AllocOp::print(OpAsmPrinter &p) { p << getOperation()->getResultTypes(); } +/// Verifies the correct structure of an allocation operation. LogicalResult AllocOp::verify() { SizedType result = utils::getSizedType(getRes().getType()); @@ -1056,49 +1093,25 @@ LogicalResult AllocOp::verify() { return success(); } -AllocOp AllocOp::create(PatternRewriter &rewriter, Location loc, Type res, - StringRef name, bool transient) { - OpBuilder builder(loc->getContext()); - OperationState state(loc, getOperationName()); - StringAttr nameAttr = rewriter.getStringAttr(utils::generateName(name.str())); - build(builder, state, res, {}, nameAttr, transient); - return cast(rewriter.create(state)); -} - -AllocOp AllocOp::create(PatternRewriter &rewriter, Location loc, Type res, - bool transient) { - return create(rewriter, loc, res, "arr", transient); -} - -AllocOp AllocOp::create(Location loc, Type res, StringRef name, - bool transient) { - OpBuilder builder(loc->getContext()); - OperationState state(loc, getOperationName()); - StringAttr nameAttr = builder.getStringAttr(name); - - if (!res.isa()) { - SizedType sized = SizedType::get(res.getContext(), res, {}, {}, {}); - res = ArrayType::get(res.getContext(), sized); - } - - build(builder, state, res, {}, nameAttr, transient); - return cast(Operation::create(state)); -} - +/// Returns the type of the elements in the allocated data container. Type AllocOp::getElementType() { return utils::getSizedType(getType()).getElementType(); } +/// Returns true if the allocated data container is a scalar. bool AllocOp::isScalar() { return utils::getSizedType(getType()).getShape().empty(); } +/// Returns true if the allocated data container is a stream. bool AllocOp::isStream() { return getType().isa(); } +/// Returns true if the allocation operation is inside a state. bool AllocOp::isInState() { return utils::getParentState(*this->getOperation()) != nullptr; } +/// Returns the name of the allocated data container. std::string AllocOp::getContainerName() { if ((*this)->hasAttr("name")) { Attribute nameAttr = (*this)->getAttr("name"); @@ -1116,11 +1129,8 @@ std::string AllocOp::getContainerName() { // LoadOp //===----------------------------------------------------------------------===// -LoadOp LoadOp::create(PatternRewriter &rewriter, Location loc, AllocOp alloc, - ValueRange indices) { - return create(rewriter, loc, alloc.getType(), alloc, indices); -} - +/// Builds, creates and inserts a load operation using the provided +/// PatternRewriter. LoadOp LoadOp::create(PatternRewriter &rewriter, Location loc, Type t, Value mem, ValueRange indices) { OpBuilder builder(loc->getContext()); @@ -1144,10 +1154,14 @@ LoadOp LoadOp::create(PatternRewriter &rewriter, Location loc, Type t, return cast(rewriter.create(state)); } -LoadOp LoadOp::create(Location loc, AllocOp alloc, ValueRange indices) { - return create(loc, alloc.getType(), alloc, indices); +/// Builds, creates and inserts a load operation using the provided +/// PatternRewriter. +LoadOp LoadOp::create(PatternRewriter &rewriter, Location loc, AllocOp alloc, + ValueRange indices) { + return create(rewriter, loc, alloc.getType(), alloc, indices); } +/// Builds, creates and inserts a load operation using Operation::create. LoadOp LoadOp::create(Location loc, Type t, Value mem, ValueRange indices) { OpBuilder builder(loc->getContext()); OperationState state(loc, getOperationName()); @@ -1168,6 +1182,12 @@ LoadOp LoadOp::create(Location loc, Type t, Value mem, ValueRange indices) { return cast(Operation::create(state)); } +/// Builds, creates and inserts a load operation using Operation::create. +LoadOp LoadOp::create(Location loc, AllocOp alloc, ValueRange indices) { + return create(loc, alloc.getType(), alloc, indices); +} + +/// Attempts to parse a load operation. ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) { if (parser.parseOptionalAttrDict(result.attributes)) return failure(); @@ -1203,6 +1223,7 @@ ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) { return success(); } +/// Prints a load operation in human-readable form. void LoadOp::print(OpAsmPrinter &p) { printOptionalAttrDictNoNumList(p, (*this)->getAttrs(), /*elidedAttrs*/ {"indices"}); @@ -1216,6 +1237,7 @@ void LoadOp::print(OpAsmPrinter &p) { p << ArrayRef(getRes().getType()); } +/// Verifies the correct structure of a load operation. LogicalResult LoadOp::verify() { size_t idx_size = getNumListSize(getOperation(), "indices"); size_t mem_size = utils::getSizedType(getArr().getType()).getRank(); @@ -1226,12 +1248,15 @@ LogicalResult LoadOp::verify() { return success(); } +/// Returns true if the load operation has non-constant indices. bool LoadOp::isIndirect() { return !getIndices().empty(); } //===----------------------------------------------------------------------===// // StoreOp //===----------------------------------------------------------------------===// +/// Builds, creates and inserts a store operation using the provided +/// PatternRewriter. StoreOp StoreOp::create(PatternRewriter &rewriter, Location loc, Value val, Value mem, ValueRange indices) { OpBuilder builder(loc->getContext()); @@ -1252,6 +1277,7 @@ StoreOp StoreOp::create(PatternRewriter &rewriter, Location loc, Value val, return cast(rewriter.create(state)); } +/// Builds, creates and inserts a store operation using Operation::create. StoreOp StoreOp::create(Location loc, Value val, Value mem, ValueRange indices) { OpBuilder builder(loc->getContext()); @@ -1272,6 +1298,7 @@ StoreOp StoreOp::create(Location loc, Value val, Value mem, return cast(Operation::create(state)); } +/// Builds, creates and inserts a store operation using Operation::create. StoreOp StoreOp::create(Location loc, Value val, Value mem, ArrayRef indices) { OpBuilder builder(loc->getContext()); @@ -1295,6 +1322,7 @@ StoreOp StoreOp::create(Location loc, Value val, Value mem, return cast(Operation::create(state)); } +/// Builds, creates and inserts a store operation using Operation::create. StoreOp StoreOp::create(Location loc, Value val, Value mem) { OpBuilder builder(loc->getContext()); OperationState state(loc, getOperationName()); @@ -1314,6 +1342,7 @@ StoreOp StoreOp::create(Location loc, Value val, Value mem) { return cast(Operation::create(state)); } +/// Attempts to parse a store operation. ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) { if (parser.parseOptionalAttrDict(result.attributes)) return failure(); @@ -1357,6 +1386,7 @@ ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) { return success(); } +/// Prints a store operation in human-readable form. void StoreOp::print(OpAsmPrinter &p) { printOptionalAttrDictNoNumList(p, (*this)->getAttrs(), /*elidedAttrs=*/{"indices"}); @@ -1370,6 +1400,7 @@ void StoreOp::print(OpAsmPrinter &p) { p << ArrayRef(getArr().getType()); } +/// Verifies the correct structure of a store operation. LogicalResult StoreOp::verify() { size_t idx_size = getNumListSize(getOperation(), "indices"); size_t mem_size = utils::getSizedType(getArr().getType()).getRank(); @@ -1380,12 +1411,15 @@ LogicalResult StoreOp::verify() { return success(); } +/// Returns true if the store operation has non-constant indices. bool StoreOp::isIndirect() { return !getIndices().empty(); } //===----------------------------------------------------------------------===// // CopyOp //===----------------------------------------------------------------------===// +/// Builds, creates and inserts a copy operation using the provided +/// PatternRewriter. CopyOp CopyOp::create(PatternRewriter &rewriter, Location loc, Value src, Value dst) { OpBuilder builder(loc->getContext()); @@ -1398,6 +1432,7 @@ CopyOp CopyOp::create(PatternRewriter &rewriter, Location loc, Value src, return cast(rewriter.create(state)); } +/// Attempts to parse a copy operation. ParseResult CopyOp::parse(OpAsmParser &parser, OperationState &result) { if (parser.parseOptionalAttrDict(result.attributes)) return failure(); @@ -1426,6 +1461,7 @@ ParseResult CopyOp::parse(OpAsmParser &parser, OperationState &result) { return success(); } +/// Prints a copy operation in human-readable form. void CopyOp::print(OpAsmPrinter &p) { p.printOptionalAttrDict((*this)->getAttrs()); p << ' ' << getSrc() << " -> " << getDest(); @@ -1433,12 +1469,15 @@ void CopyOp::print(OpAsmPrinter &p) { p << ArrayRef(getSrc().getType()); } +/// Verifies the correct structure of a copy operation. LogicalResult CopyOp::verify() { return success(); } //===----------------------------------------------------------------------===// // ViewCastOp //===----------------------------------------------------------------------===// +/// Builds, creates and inserts a viewcast operation using the provided +/// PatternRewriter. ViewCastOp ViewCastOp::create(PatternRewriter &rewriter, Location loc, Value array, Type type) { OpBuilder builder(loc->getContext()); @@ -1447,6 +1486,7 @@ ViewCastOp ViewCastOp::create(PatternRewriter &rewriter, Location loc, return cast(rewriter.create(state)); } +/// Attempts to parse a viewcast operation. ParseResult ViewCastOp::parse(OpAsmParser &parser, OperationState &result) { if (parser.parseOptionalAttrDict(result.attributes)) return failure(); @@ -1473,6 +1513,7 @@ ParseResult ViewCastOp::parse(OpAsmParser &parser, OperationState &result) { return success(); } +/// Prints a viewcast operation in human-readable form. void ViewCastOp::print(OpAsmPrinter &p) { p.printOptionalAttrDict((*this)->getAttrs()); p << ' ' << getSrc(); @@ -1482,6 +1523,7 @@ void ViewCastOp::print(OpAsmPrinter &p) { p << getOperation()->getResultTypes(); } +/// Verifies the correct structure of a viewcast operation. LogicalResult ViewCastOp::verify() { size_t src_size = utils::getSizedType(getSrc().getType()).getRank(); size_t res_size = utils::getSizedType(getRes().getType()).getRank(); @@ -1496,6 +1538,8 @@ LogicalResult ViewCastOp::verify() { // SubviewOp //===----------------------------------------------------------------------===// +/// Builds, creates and inserts a subview operation using the provided +/// PatternRewriter. SubviewOp SubviewOp::create(PatternRewriter &rewriter, Location loc, Type res, Value src, ArrayAttr offsets, ArrayAttr sizes, ArrayAttr strides) { @@ -1512,6 +1556,7 @@ SubviewOp SubviewOp::create(PatternRewriter &rewriter, Location loc, Type res, return cast(rewriter.create(state)); } +/// Attempts to parse a subview operation. ParseResult SubviewOp::parse(OpAsmParser &parser, OperationState &result) { if (parser.parseOptionalAttrDict(result.attributes)) return failure(); @@ -1550,6 +1595,7 @@ ParseResult SubviewOp::parse(OpAsmParser &parser, OperationState &result) { return success(); } +/// Prints a subview operation in human-readable form. void SubviewOp::print(OpAsmPrinter &p) { printOptionalAttrDictNoNumList(p, (*this)->getAttrs(), {"offsets", "sizes", "strides"}); @@ -1566,12 +1612,14 @@ void SubviewOp::print(OpAsmPrinter &p) { p << getOperation()->getResultTypes(); } +/// Verifies the correct structure of a subview operation. LogicalResult SubviewOp::verify() { return success(); } //===----------------------------------------------------------------------===// // StreamPopOp //===----------------------------------------------------------------------===// +/// Attempts to parse a stream pop operation. ParseResult StreamPopOp::parse(OpAsmParser &parser, OperationState &result) { if (parser.parseOptionalAttrDict(result.attributes)) return failure(); @@ -1598,6 +1646,7 @@ ParseResult StreamPopOp::parse(OpAsmParser &parser, OperationState &result) { return success(); } +/// Prints a stream pop operation in human-readable form. void StreamPopOp::print(OpAsmPrinter &p) { p.printOptionalAttrDict((*this)->getAttrs()); p << ' ' << getStr(); @@ -1607,12 +1656,14 @@ void StreamPopOp::print(OpAsmPrinter &p) { p << ArrayRef(getRes().getType()); } +/// Verifies the correct structure of a stream pop operation. LogicalResult StreamPopOp::verify() { return success(); } //===----------------------------------------------------------------------===// // StreamPushOp //===----------------------------------------------------------------------===// +/// Attempts to parse a stream push operation. ParseResult StreamPushOp::parse(OpAsmParser &parser, OperationState &result) { if (parser.parseOptionalAttrDict(result.attributes)) return failure(); @@ -1647,6 +1698,7 @@ ParseResult StreamPushOp::parse(OpAsmParser &parser, OperationState &result) { return success(); } +/// Prints a stream push operation in human-readable form. void StreamPushOp::print(OpAsmPrinter &p) { p.printOptionalAttrDict((*this)->getAttrs()); p << ' ' << getVal() << ", " << getStr(); @@ -1656,12 +1708,14 @@ void StreamPushOp::print(OpAsmPrinter &p) { p << ArrayRef(getStr().getType()); } +/// Verifies the correct structure of a stream push operation. LogicalResult StreamPushOp::verify() { return success(); } //===----------------------------------------------------------------------===// // StreamLengthOp //===----------------------------------------------------------------------===// +/// Attempts to parse a stream length operation. ParseResult StreamLengthOp::parse(OpAsmParser &parser, OperationState &result) { if (parser.parseOptionalAttrDict(result.attributes)) return failure(); @@ -1688,6 +1742,7 @@ ParseResult StreamLengthOp::parse(OpAsmParser &parser, OperationState &result) { return success(); } +/// Prints a stream length operation in human-readable form. void StreamLengthOp::print(OpAsmPrinter &p) { p.printOptionalAttrDict((*this)->getAttrs()); p << ' ' << getStr(); @@ -1697,6 +1752,7 @@ void StreamLengthOp::print(OpAsmPrinter &p) { p << getOperation()->getResultTypes(); } +/// Verifies the correct structure of a stream length operation. LogicalResult StreamLengthOp::verify() { Operation *parent = (*this)->getParentOp(); if (parent == nullptr) @@ -1714,6 +1770,25 @@ LogicalResult StreamLengthOp::verify() { // ReturnOp //===----------------------------------------------------------------------===// +/// Builds, creates and inserts a return operation using the provided +/// PatternRewriter. +sdfg::ReturnOp sdfg::ReturnOp::create(PatternRewriter &rewriter, Location loc, + mlir::ValueRange input) { + OpBuilder builder(loc->getContext()); + OperationState state(loc, getOperationName()); + build(builder, state, input); + return cast(rewriter.create(state)); +} + +/// Builds, creates and inserts a return operation using Operation::create. +sdfg::ReturnOp sdfg::ReturnOp::create(Location loc, mlir::ValueRange input) { + OpBuilder builder(loc->getContext()); + OperationState state(loc, getOperationName()); + build(builder, state, input); + return cast(Operation::create(state)); +} + +/// Attempts to parse a return operation. ParseResult sdfg::ReturnOp::parse(OpAsmParser &parser, OperationState &result) { if (parser.parseOptionalAttrDict(result.attributes)) return failure(); @@ -1733,12 +1808,14 @@ ParseResult sdfg::ReturnOp::parse(OpAsmParser &parser, OperationState &result) { return success(); } +/// Prints a return operation in human-readable form. void sdfg::ReturnOp::print(OpAsmPrinter &p) { p.printOptionalAttrDict((*this)->getAttrs()); if (getNumOperands() > 0) p << ' ' << getInput() << " : " << getInput().getTypes(); } +/// Verifies the correct structure of a return operation. LogicalResult sdfg::ReturnOp::verify() { TaskletNode task = dyn_cast((*this)->getParentOp()); @@ -1748,21 +1825,6 @@ LogicalResult sdfg::ReturnOp::verify() { return success(); } -sdfg::ReturnOp sdfg::ReturnOp::create(PatternRewriter &rewriter, Location loc, - mlir::ValueRange input) { - OpBuilder builder(loc->getContext()); - OperationState state(loc, getOperationName()); - build(builder, state, input); - return cast(rewriter.create(state)); -} - -sdfg::ReturnOp sdfg::ReturnOp::create(Location loc, mlir::ValueRange input) { - OpBuilder builder(loc->getContext()); - OperationState state(loc, getOperationName()); - build(builder, state, input); - return cast(Operation::create(state)); -} - //===----------------------------------------------------------------------===// // LibCallOp //===----------------------------------------------------------------------===//