diff --git a/CMakeLists.txt b/CMakeLists.txt index 2a1c55015..1d51cd805 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,3 +1,5 @@ +# Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + cmake_minimum_required(VERSION 3.20.0) project(sdfg-opt LANGUAGES CXX C) diff --git a/LICENSE b/LICENSE index 76d4b9da4..8c811386e 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ BSD 3-Clause License -Copyright (c) 2022, SPCL +Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich All rights reserved. Redistribution and use in source and binary forms, with or without diff --git a/include/CMakeLists.txt b/include/CMakeLists.txt index c4feb2a26..a26945b68 100644 --- a/include/CMakeLists.txt +++ b/include/CMakeLists.txt @@ -1 +1,3 @@ +# Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + add_subdirectory(SDFG) diff --git a/include/SDFG/CMakeLists.txt b/include/SDFG/CMakeLists.txt index 60af241d1..c7f2167ba 100644 --- a/include/SDFG/CMakeLists.txt +++ b/include/SDFG/CMakeLists.txt @@ -1,3 +1,5 @@ +# Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + add_subdirectory(Dialect) add_subdirectory(Conversion) add_subdirectory(Translate) diff --git a/include/SDFG/Conversion/CMakeLists.txt b/include/SDFG/Conversion/CMakeLists.txt index 3ad79fafd..ef6fd1fef 100644 --- a/include/SDFG/Conversion/CMakeLists.txt +++ b/include/SDFG/Conversion/CMakeLists.txt @@ -1,2 +1,6 @@ +# Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + add_subdirectory(GenericToSDFG) add_subdirectory(LinalgToSDFG) + +add_subdirectory(SDFGToGeneric) diff --git a/include/SDFG/Conversion/GenericToSDFG/CMakeLists.txt b/include/SDFG/Conversion/GenericToSDFG/CMakeLists.txt index 4aadd3fcd..e7e107e9d 100644 --- a/include/SDFG/Conversion/GenericToSDFG/CMakeLists.txt +++ b/include/SDFG/Conversion/GenericToSDFG/CMakeLists.txt @@ -1,3 +1,5 @@ +# Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + set(LLVM_TARGET_DEFINITIONS Passes.td) mlir_tablegen(Passes.h.inc -gen-pass-decls -name GenericToSDFG) add_public_tablegen_target(MLIRGenericToSDFGPassIncGen) diff --git a/include/SDFG/Conversion/GenericToSDFG/PassDetail.h b/include/SDFG/Conversion/GenericToSDFG/PassDetail.h index 9fcdf5ad7..87d669bf5 100644 --- a/include/SDFG/Conversion/GenericToSDFG/PassDetail.h +++ b/include/SDFG/Conversion/GenericToSDFG/PassDetail.h @@ -1,3 +1,7 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// Header for Generic to SDFG conversion pass details. + #ifndef SDFG_Conversion_GenericToSDFG_PassDetail_H #define SDFG_Conversion_GenericToSDFG_PassDetail_H @@ -8,6 +12,7 @@ namespace mlir { namespace sdfg { namespace conversion { +/// Generate the code for base classes. #define GEN_PASS_CLASSES #include "SDFG/Conversion/GenericToSDFG/Passes.h.inc" diff --git a/include/SDFG/Conversion/GenericToSDFG/Passes.h b/include/SDFG/Conversion/GenericToSDFG/Passes.h index 33957445c..a2858136e 100644 --- a/include/SDFG/Conversion/GenericToSDFG/Passes.h +++ b/include/SDFG/Conversion/GenericToSDFG/Passes.h @@ -1,3 +1,7 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// Header for Generic to SDFG conversion passes. + #ifndef SDFG_Conversion_GenericToSDFG_H #define SDFG_Conversion_GenericToSDFG_H diff --git a/include/SDFG/Conversion/GenericToSDFG/Passes.td b/include/SDFG/Conversion/GenericToSDFG/Passes.td index 4800fb37d..f2290ddd0 100644 --- a/include/SDFG/Conversion/GenericToSDFG/Passes.td +++ b/include/SDFG/Conversion/GenericToSDFG/Passes.td @@ -1,11 +1,16 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// Table-driven file for Generic to SDFG conversion passes. + #ifndef SDFG_Conversion_GenericToSDFG #define SDFG_Conversion_GenericToSDFG include "mlir/Pass/PassBase.td" include "SDFG/Dialect/Dialect.td" +/// Define generic to SDFG pass. def GenericToSDFGPass : Pass<"convert-to-sdfg", "ModuleOp"> { - let summary = "Convert SCF, Arith and Memref dialect to SDFG dialect"; + let summary = "Convert SCF, Arith, Math and Memref dialect to SDFG dialect"; let constructor = "mlir::sdfg::conversion::createGenericToSDFGPass()"; let dependentDialects = ["mlir::sdfg::SDFGDialect"]; let options = [ diff --git a/include/SDFG/Conversion/LinalgToSDFG/CMakeLists.txt b/include/SDFG/Conversion/LinalgToSDFG/CMakeLists.txt index 70891efe7..148f06a74 100644 --- a/include/SDFG/Conversion/LinalgToSDFG/CMakeLists.txt +++ b/include/SDFG/Conversion/LinalgToSDFG/CMakeLists.txt @@ -1,3 +1,5 @@ +# Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + set(LLVM_TARGET_DEFINITIONS Passes.td) mlir_tablegen(Passes.h.inc -gen-pass-decls -name LinalgToSDFG) add_public_tablegen_target(MLIRLinalgToSDFGPassIncGen) diff --git a/include/SDFG/Conversion/LinalgToSDFG/PassDetail.h b/include/SDFG/Conversion/LinalgToSDFG/PassDetail.h index 296fd5c46..4c4e4011f 100644 --- a/include/SDFG/Conversion/LinalgToSDFG/PassDetail.h +++ b/include/SDFG/Conversion/LinalgToSDFG/PassDetail.h @@ -1,3 +1,7 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// Header for Linalg to SDFG conversion pass details. + #ifndef SDFG_Conversion_LinalgToSDFG_PassDetail_H #define SDFG_Conversion_LinalgToSDFG_PassDetail_H @@ -8,6 +12,7 @@ namespace mlir { namespace sdfg { namespace conversion { +/// Generate the code for base classes. #define GEN_PASS_CLASSES #include "SDFG/Conversion/LinalgToSDFG/Passes.h.inc" diff --git a/include/SDFG/Conversion/LinalgToSDFG/Passes.h b/include/SDFG/Conversion/LinalgToSDFG/Passes.h index f691dc1f0..6d0573abd 100644 --- a/include/SDFG/Conversion/LinalgToSDFG/Passes.h +++ b/include/SDFG/Conversion/LinalgToSDFG/Passes.h @@ -1,3 +1,7 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// Header for Linalg to SDFG conversion passes. + #ifndef SDFG_Conversion_LinalgToSDFG_H #define SDFG_Conversion_LinalgToSDFG_H diff --git a/include/SDFG/Conversion/LinalgToSDFG/Passes.td b/include/SDFG/Conversion/LinalgToSDFG/Passes.td index 5df9beeba..2a263271f 100644 --- a/include/SDFG/Conversion/LinalgToSDFG/Passes.td +++ b/include/SDFG/Conversion/LinalgToSDFG/Passes.td @@ -1,9 +1,14 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// Table-driven file for Linalg to SDFG conversion passes. + #ifndef SDFG_Conversion_LinalgToSDFG #define SDFG_Conversion_LinalgToSDFG include "mlir/Pass/PassBase.td" include "SDFG/Dialect/Dialect.td" +/// Define Linalg to SDFG Pass. def LinalgToSDFGPass : Pass<"linalg-to-sdfg", "ModuleOp"> { let summary = "Convert Linalg dialect to SDFG dialect"; let constructor = "mlir::sdfg::conversion::createLinalgToSDFGPass()"; diff --git a/include/SDFG/Conversion/SDFGToGeneric/CMakeLists.txt b/include/SDFG/Conversion/SDFGToGeneric/CMakeLists.txt new file mode 100644 index 000000000..55e061ed8 --- /dev/null +++ b/include/SDFG/Conversion/SDFGToGeneric/CMakeLists.txt @@ -0,0 +1,8 @@ +# Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name SDFGToGeneric) +add_public_tablegen_target(MLIRSDFGToGenericPassIncGen) + +target_sources(SOURCE_FILES_H PRIVATE PassDetail.h Passes.h SymbolicParser.h + OpCreators.h) diff --git a/include/SDFG/Conversion/SDFGToGeneric/OpCreators.h b/include/SDFG/Conversion/SDFGToGeneric/OpCreators.h new file mode 100644 index 000000000..4f24ff1e1 --- /dev/null +++ b/include/SDFG/Conversion/SDFGToGeneric/OpCreators.h @@ -0,0 +1,137 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// Header for convenience functions, creating various operations. + +#ifndef SDFG_Conversion_SDFGToGeneric_Op_Creators_H +#define SDFG_Conversion_SDFGToGeneric_Op_Creators_H + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Transforms/DialectConversion.h" + +using namespace mlir; + +namespace mlir::sdfg::conversion { + +/// Builds, creates and inserts a func::FuncOp. +func::FuncOp createFunc(PatternRewriter &rewriter, Location loc, StringRef name, + TypeRange inputTypes, TypeRange resultTypes, + StringRef visibility); + +/// Builds, creates and inserts a func::CallOp. +func::CallOp createCall(PatternRewriter &rewriter, Location loc, + TypeRange resultTypes, StringRef callee, + ValueRange operands); + +/// Builds, creates and inserts a func::ReturnOp. +func::ReturnOp createReturn(PatternRewriter &rewriter, Location loc, + ValueRange operands); + +/// Builds, creates and inserts a cf::BranchOp. +cf::BranchOp createBranch(PatternRewriter &rewriter, Location loc, + ValueRange operands, Block *dest); + +/// Builds, creates and inserts a cf::CondBranchOp. +cf::CondBranchOp createCondBranch(PatternRewriter &rewriter, Location loc, + Value condition, Block *trueDest, + Block *falseDest); + +/// Builds, creates and inserts a memref::AllocOp. +memref::AllocOp createAlloc(PatternRewriter &rewriter, Location loc, + MemRefType memrefType, ValueRange dynamicSizes); + +/// Builds, creates and inserts a memref::LoadOp. +memref::LoadOp createLoad(PatternRewriter &rewriter, Location loc, Value memref, + ValueRange indices); + +/// Builds, creates and inserts a memref::StoreOp. +memref::StoreOp createStore(PatternRewriter &rewriter, Location loc, + Value value, Value memref, ValueRange indices); + +/// Builds, creates and inserts a memref::CopyOp. +memref::CopyOp createCopy(PatternRewriter &rewriter, Location loc, Value source, + Value target); + +/// Allocates a symbol as a memref if it's not already allocated and +/// populates the symbol map. +void allocSymbol(PatternRewriter &rewriter, Location loc, StringRef symName, + llvm::StringMap &symbolMap); + +/// Builds, creates and inserts an arith::ConstantIntOp. +arith::ConstantIntOp createConstantInt(PatternRewriter &rewriter, Location loc, + int val, int width); + +/// Builds, creates and inserts an arith::AddIOp. +arith::AddIOp createAddI(PatternRewriter &rewriter, Location loc, Value a, + Value b); + +/// Builds, creates and inserts an arith::SubIOp. +arith::SubIOp createSubI(PatternRewriter &rewriter, Location loc, Value a, + Value b); + +/// Builds, creates and inserts an arith::MulIOp. +arith::MulIOp createMulI(PatternRewriter &rewriter, Location loc, Value a, + Value b); + +/// Builds, creates and inserts an arith::DivSIOp. +arith::DivSIOp createDivSI(PatternRewriter &rewriter, Location loc, Value a, + Value b); + +/// Builds, creates and inserts an arith::FloorDivSIOp. +arith::FloorDivSIOp createFloorDivSI(PatternRewriter &rewriter, Location loc, + Value a, Value b); + +/// Builds, creates and inserts an arith::RemSIOp. +arith::RemSIOp createRemSI(PatternRewriter &rewriter, Location loc, Value a, + Value b); + +/// Builds, creates and inserts an arith::OrIOp. +arith::OrIOp createOrI(PatternRewriter &rewriter, Location loc, Value a, + Value b); + +/// Builds, creates and inserts an arith::AndIOp. +arith::AndIOp createAndI(PatternRewriter &rewriter, Location loc, Value a, + Value b); + +/// Builds, creates and inserts an arith::XOrIOp. +arith::XOrIOp createXOrI(PatternRewriter &rewriter, Location loc, Value a, + Value b); + +/// Builds, creates and inserts an arith::ShLIOp. +arith::ShLIOp createShLI(PatternRewriter &rewriter, Location loc, Value a, + Value b); + +/// Builds, creates and inserts an arith::ShRSIOp. +arith::ShRSIOp createShRSI(PatternRewriter &rewriter, Location loc, Value a, + Value b); + +/// Builds, creates and inserts an arith::CmpIOp. +arith::CmpIOp createCmpI(PatternRewriter &rewriter, Location loc, + arith::CmpIPredicate predicate, Value lhs, Value rhs); + +/// Builds, creates and inserts an arith::ExtSIOp. +arith::ExtSIOp createExtSI(PatternRewriter &rewriter, Location loc, Type out, + Value in); + +/// Builds, creates and inserts an arith::TruncIOp. +arith::TruncIOp createTruncI(PatternRewriter &rewriter, Location loc, Type out, + Value in); + +/// Builds, creates and inserts an arith::IndexCastOp. +arith::IndexCastOp createIndexCast(PatternRewriter &rewriter, Location loc, + Type out, Value in); + +/// Builds, creates and inserts a scf::ParallelOp. +scf::ParallelOp createParallel(PatternRewriter &rewriter, Location loc, + ValueRange lowerBounds, ValueRange upperBounds, + ValueRange steps); + +/// Builds, creates and inserts a scf::YieldOp. +scf::YieldOp createYield(PatternRewriter &rewriter, Location loc); + +} // namespace mlir::sdfg::conversion + +#endif // SDFG_Conversion_SDFGToGeneric_Op_Creators_H diff --git a/include/SDFG/Conversion/SDFGToGeneric/PassDetail.h b/include/SDFG/Conversion/SDFGToGeneric/PassDetail.h new file mode 100644 index 000000000..587812669 --- /dev/null +++ b/include/SDFG/Conversion/SDFGToGeneric/PassDetail.h @@ -0,0 +1,23 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// Header for SDFG to Generic conversion pass details. + +#ifndef SDFG_Conversion_SDFGToGeneric_PassDetail_H +#define SDFG_Conversion_SDFGToGeneric_PassDetail_H + +#include "SDFG/Dialect/Dialect.h" +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace sdfg { +namespace conversion { + +/// Generate the code for base classes. +#define GEN_PASS_CLASSES +#include "SDFG/Conversion/SDFGToGeneric/Passes.h.inc" + +} // namespace conversion +} // namespace sdfg +} // end namespace mlir + +#endif // SDFG_Conversion_SDFGToGeneric_PassDetail_H diff --git a/include/SDFG/Conversion/SDFGToGeneric/Passes.h b/include/SDFG/Conversion/SDFGToGeneric/Passes.h new file mode 100644 index 000000000..a7b808a5b --- /dev/null +++ b/include/SDFG/Conversion/SDFGToGeneric/Passes.h @@ -0,0 +1,25 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// Header for SDFG to Generic conversion passes. + +#ifndef SDFG_Conversion_SDFGToGeneric_H +#define SDFG_Conversion_SDFGToGeneric_H + +#include "mlir/Pass/Pass.h" + +namespace mlir::sdfg::conversion { + +/// Creates a sdfg to generic converting pass +std::unique_ptr createSDFGToGenericPass(); + +//===----------------------------------------------------------------------===// +// Registration +//===----------------------------------------------------------------------===// + +/// Generate the code for registering passes. +#define GEN_PASS_REGISTRATION +#include "SDFG/Conversion/SDFGToGeneric/Passes.h.inc" + +} // namespace mlir::sdfg::conversion + +#endif // SDFG_Conversion_SDFGToGeneric_H diff --git a/include/SDFG/Conversion/SDFGToGeneric/Passes.td b/include/SDFG/Conversion/SDFGToGeneric/Passes.td new file mode 100644 index 000000000..8c968475a --- /dev/null +++ b/include/SDFG/Conversion/SDFGToGeneric/Passes.td @@ -0,0 +1,23 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// Table-driven file for SDFG to Generic conversion passes. + +#ifndef SDFG_Conversion_SDFGToGeneric +#define SDFG_Conversion_SDFGToGeneric + +include "mlir/Pass/PassBase.td" +include "SDFG/Dialect/Dialect.td" + +/// Define SDFG to generic pass. +def SDFGToGenericPass : Pass<"lower-sdfg", "ModuleOp"> { + let summary = "Convert SDFG dialect to Func, CF, Memref and SCF dialects"; + let constructor = "mlir::sdfg::conversion::createSDFGToGenericPass()"; + let dependentDialects = [ + "mlir::func::FuncDialect", + "mlir::cf::ControlFlowDialect", + "mlir::memref::MemRefDialect", + "mlir::scf::SCFDialect" + ]; +} + +#endif // SDFG_Conversion_SDFGToGeneric diff --git a/include/SDFG/Conversion/SDFGToGeneric/SymbolicParser.h b/include/SDFG/Conversion/SDFGToGeneric/SymbolicParser.h new file mode 100644 index 000000000..957144215 --- /dev/null +++ b/include/SDFG/Conversion/SDFGToGeneric/SymbolicParser.h @@ -0,0 +1,269 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// Header for a simple LL(1) parser, parsing symbolic expressions. + +#ifndef SDFG_Conversion_SDFGToGeneric_Symbolic_Parser_H +#define SDFG_Conversion_SDFGToGeneric_Symbolic_Parser_H + +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir::sdfg::conversion { + +/// Parent class representing any AST node. +class ASTNode { +public: + virtual ~ASTNode() = default; + + /// Converts the node into MLIR code. SymbolMap is used for permanent mapping + /// of symbols to values. RefMap is a temporary mapping overriding SymbolMap. + virtual Value codegen(PatternRewriter &rewriter, Location loc, + llvm::StringMap &symbolMap, + llvm::StringMap &refMap) = 0; +}; + +/// Integer AST node representing an integer constant. +class IntNode : public ASTNode { +public: + int value; + + IntNode(int value) : value(value){}; + + /// Converts the integer node into MLIR code. SymbolMap is used for permanent + /// mapping of symbols to values. RefMap is a temporary mapping overriding + /// SymbolMap. + Value codegen(PatternRewriter &rewriter, Location loc, + llvm::StringMap &symbolMap, + llvm::StringMap &refMap) override; +}; + +/// Boolean AST node representing a boolean constant. +class BoolNode : public ASTNode { +public: + bool value; + + BoolNode(bool value) : value(value) {} + + /// Converts the boolean node into MLIR code. SymbolMap is used for permanent + /// mapping of symbols to values. RefMap is a temporary mapping overriding + /// SymbolMap. + Value codegen(PatternRewriter &rewriter, Location loc, + llvm::StringMap &symbolMap, + llvm::StringMap &refMap) override; +}; + +/// Variable AST node representing a symbol. +class VarNode : public ASTNode { +public: + std::string name; + + VarNode(std::string name) : name(name) {} + + /// Converts the variable node into MLIR code. SymbolMap is used for permanent + /// mapping of symbols to values. RefMap is a temporary mapping overriding + /// SymbolMap. + Value codegen(PatternRewriter &rewriter, Location loc, + llvm::StringMap &symbolMap, + llvm::StringMap &refMap) override; +}; + +/// Assignment AST node representing the assignment of an expression to a +/// variable. +class AssignNode : public ASTNode { +public: + std::unique_ptr variable; + std::unique_ptr expr; + + AssignNode(std::unique_ptr variable, std::unique_ptr expr) + : variable(std::move(variable)), expr(std::move(expr)) {} + + /// Converts the assignment node into MLIR code. SymbolMap is used for + /// permanent mapping of symbols to values. RefMap is a temporary mapping + /// overriding SymbolMap. + Value codegen(PatternRewriter &rewriter, Location loc, + llvm::StringMap &symbolMap, + llvm::StringMap &refMap) override; +}; + +/// Unary Operation AST node representing an unary operation performed on an +/// expression. +class UnOpNode : public ASTNode { +public: + /// Enum representing all possible unary operations. + enum UnOp { ADD, SUB, LOG_NOT, BIT_NOT }; + + UnOp op; + std::unique_ptr expr; + + UnOpNode(UnOp op, std::unique_ptr expr) + : op(op), expr(std::move(expr)) {} + + /// Converts the unary operation node into MLIR code. SymbolMap is used for + /// permanent mapping of symbols to values. RefMap is a temporary mapping + /// overriding SymbolMap. + Value codegen(PatternRewriter &rewriter, Location loc, + llvm::StringMap &symbolMap, + llvm::StringMap &refMap) override; +}; + +/// Binary Operation AST node representing a binary operation performed on an +/// expression. +class BinOpNode : public ASTNode { +public: + /// Enum representing all possible binary operations. + enum BinOp { + ADD, + SUB, + MUL, + DIV, + FLOORDIV, + MOD, + EXP, + BIT_OR, + BIT_XOR, + BIT_AND, + LSHIFT, + RSHIFT, + LOG_OR, + LOG_AND, + EQ, + NE, + LT, + LE, + GT, + GE + }; + + std::unique_ptr left; + BinOp op; + std::unique_ptr right; + + BinOpNode(std::unique_ptr left, BinOp op, + std::unique_ptr right) + : left(std::move(left)), op(op), right(std::move(right)) {} + + /// Converts the binary operation node into MLIR code. SymbolMap is used for + /// permanent mapping of symbols to values. RefMap is a temporary mapping + /// overriding SymbolMap. + Value codegen(PatternRewriter &rewriter, Location loc, + llvm::StringMap &symbolMap, + llvm::StringMap &refMap) override; +}; + +/// Enum representing all accepted token types. +enum TokenType { + EQ, + NE, + LT, + LE, + GT, + GE, + ASSIGN, + LOG_OR, + LOG_AND, + LOG_NOT, + ADD, + SUB, + MUL, + DIV, + FLOORDIV, + MOD, + EXP, + TRUE, + FALSE, + BIT_OR, + BIT_XOR, + BIT_AND, + BIT_NOT, + LSHIFT, + RSHIFT, + LPAREN, + RPAREN, + INT_CONST, + IDENT, + WS +}; + +/// Struct to assign a token type to each parsed token. +struct Token { + TokenType type; + std::string value; +}; + +/// This parser parses symbolic expressions. The private functions attempt to +/// parse one specific grammar rule in descending precedence order. +class SymbolicParser { +public: + /// Parses a symbolic expression provided as a string to an AST. + std::unique_ptr parse(StringRef input); + +private: + unsigned pos; + SmallVector tokens; + + /// Converts the symbolic expression to individual tokens. + Optional> tokenize(StringRef input); + + /// Attempts to parse a statement: + /// stmt ::= assignment | log_or_expr + std::unique_ptr stmt(); + /// Attempts to parse an assignment: + /// assignment ::= IDENT ASSIGN log_or_expr + std::unique_ptr assignment(); + + /// Attempts to parse a logical OR expression: + /// log_or_expr ::= log_and_expr ( LOG_OR log_and_expr )* + std::unique_ptr log_or_expr(); + /// Attempts to parse a logical AND expression: + /// log_and_expr ::= eq_expr ( LOG_AND eq_expr )* + std::unique_ptr log_and_expr(); + + /// Attempts to parse an equality expression: + /// eq_expr ::= rel_expr ( ( EQ | NE ) rel_expr )* + std::unique_ptr eq_expr(); + /// Attempts to parse an inequality expression: + /// rel_expr ::= shift_expr ( ( LT | LE | GT | GE ) shift_expr )* + std::unique_ptr rel_expr(); + /// Attempts to parse a shift expression: + /// shift_expr ::= bit_or_expr ( (LSHIFT | RSHIFT ) bit_or_expr )* + std::unique_ptr shift_expr(); + + /// Attempts to parse a bitwise OR expression: + /// bit_or_expr ::= bit_xor_expr ( BIT_OR bit_xor_expr )* + std::unique_ptr bit_or_expr(); + /// Attempts to parse a bitwise XOR expression: + /// bit_xor_expr ::= bit_and_expr ( BIT_XOR bit_and_expr )* + std::unique_ptr bit_xor_expr(); + /// Attempts to parse a bitwise AND expression: + /// bit_and_expr ::= add_expr ( BIT_AND add_expr )* + std::unique_ptr bit_and_expr(); + + /// Attempts to parse an arithmetic addition / subtraction expression: + /// add_expr ::= mul_expr ( ( ADD | SUB ) mul_expr )* + std::unique_ptr add_expr(); + /// Attempts to parse an arithmetic multiplication / division / floor / modulo + /// expression: + /// mul_expr ::= exp_expr ( ( MUL | DIV | FLOORDIV | MOD ) exp_expr )* + std::unique_ptr mul_expr(); + /// Attempts to parse an arithmetic exponential expression: + /// exp_expr ::= unary_expr ( EXP unary_expr )* + std::unique_ptr exp_expr(); + /// Attempts to parse an unary positive / negative / logical and bitwise NOT + /// expression: + /// unary_expr ::= ( ADD | SUB | LOG_NOT | BIT_NOT )? factor + std::unique_ptr unary_expr(); + /// Attempts to parse a single factor: + /// factor ::= LPAREN log_or_expr RPAREN | const_expr | IDENT + std::unique_ptr factor(); + + /// Attempts to parse a constant expression: + /// const_expr ::= bool_const | INT_CONST + std::unique_ptr const_expr(); + /// Attempts to parse a constant boolean expression: + /// bool_const ::= TRUE | FALSE + std::unique_ptr bool_const(); +}; + +} // namespace mlir::sdfg::conversion + +#endif // SDFG_Conversion_SDFGToGeneric_Symbolic_Parser_H diff --git a/include/SDFG/Dialect/CMakeLists.txt b/include/SDFG/Dialect/CMakeLists.txt index 2956367cd..7736c38ee 100644 --- a/include/SDFG/Dialect/CMakeLists.txt +++ b/include/SDFG/Dialect/CMakeLists.txt @@ -1,3 +1,5 @@ +# Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + add_mlir_dialect(Ops sdfg) add_mlir_doc(Dialect Dialect SDFG/ -gen-dialect-doc) add_mlir_doc(Ops Ops SDFG/ -gen-op-doc) diff --git a/include/SDFG/Dialect/Dialect.h b/include/SDFG/Dialect/Dialect.h index 0717ffd24..29544fb6a 100644 --- a/include/SDFG/Dialect/Dialect.h +++ b/include/SDFG/Dialect/Dialect.h @@ -1,3 +1,7 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// Header for SDFG dialect. + #ifndef SDFG_DIALECT_DIALECT_H #define SDFG_DIALECT_DIALECT_H @@ -9,9 +13,11 @@ #include "SDFG/Dialect/OpsDialect.h.inc" +/// Generate the code for type definitions. #define GET_TYPEDEF_CLASSES #include "SDFG/Dialect/OpsTypes.h.inc" +/// Generate the code for operation definitions. #define GET_OP_CLASSES #include "SDFG/Dialect/Ops.h.inc" diff --git a/include/SDFG/Dialect/Dialect.td b/include/SDFG/Dialect/Dialect.td index ecd0d52df..d3cc5295a 100644 --- a/include/SDFG/Dialect/Dialect.td +++ b/include/SDFG/Dialect/Dialect.td @@ -1,3 +1,7 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// Table-driven file for SDFG dialect. + #ifndef SDFG_Dialect #define SDFG_Dialect @@ -11,6 +15,7 @@ include "mlir/Interfaces/LoopLikeInterface.td" // SDFG Dialect //===----------------------------------------------------------------------===// +/// Defining the SDFG dialect. def SDFG_Dialect : Dialect{ let name = "sdfg"; let summary = "A high-level dialect for representing SDFGs."; @@ -24,6 +29,7 @@ def SDFG_Dialect : Dialect{ // SDFG Types //===----------------------------------------------------------------------===// +/// Defining the SDFG base type. class SDFG_Type traits = []> : TypeDef{} @@ -31,6 +37,7 @@ class SDFG_Type traits = []> : // SizedType //===----------------------------------------------------------------------===// +/// Defining the SDFG sizes type. def SDFG_SizedType : SDFG_Type<"Sized">{ let parameters = (ins "Type":$elementType, @@ -44,7 +51,7 @@ def SDFG_SizedType : SDFG_Type<"Sized">{ size_t getUndefRank(){ size_t undefSize = 0; - for(int64_t dim : getIntegers()) if(dim == -1) undefSize++; + for(int64_t dim : getIntegers()) if(dim < 0) undefSize++; return undefSize; } @@ -59,6 +66,7 @@ def SDFG_SizedType : SDFG_Type<"Sized">{ // ArrayType //===----------------------------------------------------------------------===// +/// Defining the SDFG array type. def SDFG_ArrayType : SDFG_Type<"Array">{ let mnemonic = "array"; let summary = "A array type"; @@ -68,12 +76,20 @@ def SDFG_ArrayType : SDFG_Type<"Array">{ let parameters = (ins SDFG_SizedType:$dimensions); let hasCustomAssemblyFormat = 1; + + let extraClassDeclaration = [{ + Type getElementType(); + ArrayRef getSymbols(); + ArrayRef getIntegers(); + ArrayRef getShape(); + }]; } //===----------------------------------------------------------------------===// // StreamType //===----------------------------------------------------------------------===// +/// Defining the SDFG stream type. def SDFG_StreamType : SDFG_Type<"Stream">{ let mnemonic = "stream"; let summary = "A stream type"; diff --git a/include/SDFG/Dialect/Ops.td b/include/SDFG/Dialect/Ops.td index 1d7841382..718f3b092 100644 --- a/include/SDFG/Dialect/Ops.td +++ b/include/SDFG/Dialect/Ops.td @@ -1,3 +1,7 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// Table-driven file for SDFG operations. + #ifndef SDFG_OPS #define SDFG_OPS diff --git a/include/SDFG/Dialect/nodes/consume.td b/include/SDFG/Dialect/nodes/consume.td index 7d49036bc..92540807d 100644 --- a/include/SDFG/Dialect/nodes/consume.td +++ b/include/SDFG/Dialect/nodes/consume.td @@ -1,6 +1,11 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// Table-driven file for consume nodes. + #ifndef SDFG_ConsumeNode #define SDFG_ConsumeNode +/// Defining the SDFG consume scope. def SDFG_ConsumeNode : SDFG_Op<"consume", [ ParentOneOf<["StateNode", "MapNode", "ConsumeNode"]>, SingleBlock, @@ -17,7 +22,6 @@ def SDFG_ConsumeNode : SDFG_Op<"consume", [ ```mlir sdfg.consume{num_pes=5} (%a : !sdfg.stream) -> (pe: %p, elem: %e) { %c = sdfg.call @add_one(%a) : i32 -> i32 - sdfg.store(wcr="add") %c, %C[] : i32 -> !sdfg.memlet ... } ``` diff --git a/include/SDFG/Dialect/nodes/map.td b/include/SDFG/Dialect/nodes/map.td index 106d4c984..6c9ddda25 100644 --- a/include/SDFG/Dialect/nodes/map.td +++ b/include/SDFG/Dialect/nodes/map.td @@ -1,6 +1,11 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// Table-driven file for map nodes. + #ifndef SDFG_MapNode #define SDFG_MapNode +/// Defining the SDFG map scope. def SDFG_MapNode : SDFG_Op<"map", [ ParentOneOf<["StateNode", "MapNode", "ConsumeNode"]>, SingleBlock, @@ -15,7 +20,7 @@ def SDFG_MapNode : SDFG_Op<"map", [ ```mlir sdfg.map (%i, %j) = (0, 0) to (2, 2) step (1, 1) { ... - %a = sdfg.load %A[%i, %j] : !sdfg.memlet<12x34xi32> + %a = sdfg.load %A[%i, %j] : !sdfg.array<12x34xi32> ... } ``` @@ -24,7 +29,7 @@ def SDFG_MapNode : SDFG_Op<"map", [ let arguments = (ins I32Attr:$entryID, I32Attr:$exitID, - Variadic:$ranges, + Variadic:$ranges, // FIXME: This seems unused ArrayAttr:$lowerBounds, ArrayAttr:$upperBounds, ArrayAttr:$steps diff --git a/include/SDFG/Dialect/nodes/sdfg.td b/include/SDFG/Dialect/nodes/sdfg.td index f5c2898e8..ad21a961f 100644 --- a/include/SDFG/Dialect/nodes/sdfg.td +++ b/include/SDFG/Dialect/nodes/sdfg.td @@ -1,6 +1,11 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// Table-driven file for SDFG nodes. + #ifndef SDFG_SDFGNode #define SDFG_SDFGNode +/// Defining the SDFG top-level SDFG. def SDFG_SDFGNode : SDFG_Op<"sdfg", [ ParentOneOf<["ModuleOp"]>, NoTerminator, @@ -34,9 +39,15 @@ def SDFG_SDFGNode : SDFG_Op<"sdfg", [ static SDFGNode create(PatternRewriter &rewriter, Location location); StateNode getStateBySymRef(StringRef symRef); StateNode getFirstState(); + StateNode getEntryState(); + Block::BlockArgListType getArgs(); + TypeRange getArgTypes(); + Block::BlockArgListType getResults(); + TypeRange getResultTypes(); }]; } +/// Defining the SDFG nested SDFG scopes. def SDFG_NestedSDFGNode : SDFG_Op<"nested_sdfg", [ ParentOneOf<["StateNode","MapNode","ConsumeNode"]>, NoTerminator, @@ -71,6 +82,9 @@ def SDFG_NestedSDFGNode : SDFG_Op<"nested_sdfg", [ static NestedSDFGNode create(PatternRewriter &rewriter, Location location); StateNode getStateBySymRef(StringRef symRef); StateNode getFirstState(); + StateNode getEntryState(); + ValueRange getArgs(); + ValueRange getResults(); }]; } diff --git a/include/SDFG/Dialect/nodes/state.td b/include/SDFG/Dialect/nodes/state.td index 9f564f87c..66080450a 100644 --- a/include/SDFG/Dialect/nodes/state.td +++ b/include/SDFG/Dialect/nodes/state.td @@ -1,6 +1,11 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// Table-driven file for state nodes. + #ifndef SDFG_StateNode #define SDFG_StateNode +/// Defining the SDFG state scope. def SDFG_StateNode : SDFG_Op<"state", [ ParentOneOf<["SDFGNode","NestedSDFGNode"]>, SingleBlock, diff --git a/include/SDFG/Dialect/nodes/tasklet.td b/include/SDFG/Dialect/nodes/tasklet.td index ddde2102a..4a53c0c65 100644 --- a/include/SDFG/Dialect/nodes/tasklet.td +++ b/include/SDFG/Dialect/nodes/tasklet.td @@ -1,6 +1,11 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// Table-driven file for tasklet nodes. + #ifndef SDFG_TaskletNode #define SDFG_TaskletNode +/// Defining the SDFG tasklet scope. def SDFG_TaskletNode : SDFG_Op<"tasklet", [ ParentOneOf<["StateNode","MapNode","ConsumeNode"]>, AffineScope, diff --git a/include/SDFG/Dialect/ops/edge.td b/include/SDFG/Dialect/ops/edge.td index d691af799..3b6b8d1fd 100644 --- a/include/SDFG/Dialect/ops/edge.td +++ b/include/SDFG/Dialect/ops/edge.td @@ -1,6 +1,11 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// Table-driven file for edge operations. + #ifndef SDFG_EdgeOp #define SDFG_EdgeOp +/// Defining the SDFG edges. def SDFG_EdgeOp : SDFG_Op<"edge", [ ParentOneOf<["SDFGNode", "NestedSDFGNode"]>, DeclareOpInterfaceMethods diff --git a/include/SDFG/Dialect/ops/memlet.td b/include/SDFG/Dialect/ops/memlet.td index 08a0b89c4..ecc91831d 100644 --- a/include/SDFG/Dialect/ops/memlet.td +++ b/include/SDFG/Dialect/ops/memlet.td @@ -1,3 +1,7 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// Table-driven file for memlet operations. + #ifndef SDFG_MemletOps #define SDFG_MemletOps @@ -5,6 +9,7 @@ // AllocOp //===----------------------------------------------------------------------===// +/// Defining the SDFG containers. def SDFG_AllocOp : SDFG_Op<"alloc", [ ParentOneOf<["SDFGNode", "NestedSDFGNode", "StateNode"]> ]> { @@ -42,6 +47,7 @@ def SDFG_AllocOp : SDFG_Op<"alloc", [ // LoadOp //===----------------------------------------------------------------------===// +/// Defining part of the SDFG memlet. def SDFG_LoadOp : SDFG_Op<"load", [ ParentOneOf<["StateNode", "MapNode", "ConsumeNode"]>, TypesMatchWith<"result type matches element type of 'memlet'", "arr", "res", @@ -79,6 +85,7 @@ def SDFG_LoadOp : SDFG_Op<"load", [ // StoreOp //===----------------------------------------------------------------------===// +/// Defining part of the SDFG memlet. def SDFG_StoreOp : SDFG_Op<"store", [ ParentOneOf<["StateNode", "MapNode", "ConsumeNode"]>, TypesMatchWith<"value type matches element type of 'memlet'", "arr", "val", @@ -119,6 +126,7 @@ def SDFG_StoreOp : SDFG_Op<"store", [ // CopyOp //===----------------------------------------------------------------------===// +/// Combines load and store operations. def SDFG_CopyOp : SDFG_Op<"copy", [ ParentOneOf<["StateNode", "MapNode", "ConsumeNode"]>, SameTypeOperands @@ -148,6 +156,7 @@ def SDFG_CopyOp : SDFG_Op<"copy", [ // ViewCastOp //===----------------------------------------------------------------------===// +/// Defining the SDFG viewcast operation. def SDFG_ViewCastOp : SDFG_Op<"view_cast", [ ParentOneOf<["SDFGNode", "NestedSDFGNode", "StateNode", "MapNode", "ConsumeNode"]> ]> { @@ -177,6 +186,7 @@ def SDFG_ViewCastOp : SDFG_Op<"view_cast", [ // SubviewOp //===----------------------------------------------------------------------===// +/// Defining the SDFG subview operation. def SDFG_SubviewOp : SDFG_Op<"subview", [ ParentOneOf<["SDFGNode", "NestedSDFGNode", "StateNode", "MapNode", "ConsumeNode"]> ]> { diff --git a/include/SDFG/Dialect/ops/stream.td b/include/SDFG/Dialect/ops/stream.td index 6f875f98f..22e6a7569 100644 --- a/include/SDFG/Dialect/ops/stream.td +++ b/include/SDFG/Dialect/ops/stream.td @@ -1,3 +1,7 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// Table-driven file for stream operations. + #ifndef SDFG_StreamOps #define SDFG_StreamOps @@ -5,6 +9,7 @@ // StreamPopOp //===----------------------------------------------------------------------===// +/// Defining part of the SDFG stream memlet. def SDFG_StreamPopOp : SDFG_Op<"stream_pop", [ ParentOneOf<["StateNode", "MapNode"]>, TypesMatchWith<"result type matches element type of 'stream'", "str", "res", @@ -32,6 +37,7 @@ def SDFG_StreamPopOp : SDFG_Op<"stream_pop", [ // StreamPushOp //===----------------------------------------------------------------------===// +/// Defining part of the SDFG stream memlet. def SDFG_StreamPushOp : SDFG_Op<"stream_push", [ ParentOneOf<["StateNode", "MapNode"]>, TypesMatchWith<"value type matches element type of 'stream'", "str", "val", @@ -59,6 +65,7 @@ def SDFG_StreamPushOp : SDFG_Op<"stream_push", [ // StreamLengthOp //===----------------------------------------------------------------------===// +/// Defining the SDFG stream length operation. def SDFG_StreamLengthOp : SDFG_Op<"stream_length"> { let summary = "Stream length operation"; let description = [{ diff --git a/include/SDFG/Dialect/ops/symbol.td b/include/SDFG/Dialect/ops/symbol.td index fd8a03a5d..a14ba1c2a 100644 --- a/include/SDFG/Dialect/ops/symbol.td +++ b/include/SDFG/Dialect/ops/symbol.td @@ -1,3 +1,7 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// Table-driven file for symbol operations. + #ifndef SDFG_SymbolOps #define SDFG_SymbolOps @@ -5,6 +9,7 @@ // AllocSymbolOp //===----------------------------------------------------------------------===// +/// Defining the SDFG symbol definition operation. def SDFG_AllocSymbolOp : SDFG_Op<"alloc_symbol", [ ParentOneOf<["SDFGNode", "NestedSDFGNode", "StateNode", "MapNode", "ConsumeNode"]> ]> { @@ -34,6 +39,7 @@ def SDFG_AllocSymbolOp : SDFG_Op<"alloc_symbol", [ // SymOp //===----------------------------------------------------------------------===// +/// Defining the SDFG symbolic expression. def SDFG_SymOp : SDFG_Op<"sym",[ ParentOneOf<["StateNode", "MapNode", "ConsumeNode"]> ]>{ diff --git a/include/SDFG/Dialect/ops/tasklet.td b/include/SDFG/Dialect/ops/tasklet.td index 788c0b097..738fb43f5 100644 --- a/include/SDFG/Dialect/ops/tasklet.td +++ b/include/SDFG/Dialect/ops/tasklet.td @@ -1,3 +1,7 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// Table-driven file for tasklet operations. + #ifndef SDFG_TaskletOps #define SDFG_TaskletOps @@ -5,6 +9,7 @@ // ReturnOp //===----------------------------------------------------------------------===// +/// Defining the SDFG tasklet terminator. def SDFG_ReturnOp : SDFG_Op<"return", [ HasParent<"TaskletNode">, Terminator @@ -37,6 +42,7 @@ def SDFG_ReturnOp : SDFG_Op<"return", [ // LibCallOp //===----------------------------------------------------------------------===// +/// Defining the SDFG library call. def SDFG_LibCallOp : SDFG_Op<"libcall", [ ParentOneOf<["StateNode", "MapNode", "ConsumeNode"]>, CallOpInterface diff --git a/include/SDFG/Translate/CMakeLists.txt b/include/SDFG/Translate/CMakeLists.txt index ec33fd925..9fae173bc 100644 --- a/include/SDFG/Translate/CMakeLists.txt +++ b/include/SDFG/Translate/CMakeLists.txt @@ -1,2 +1,4 @@ +# Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + target_sources(SOURCE_FILES_H PRIVATE JsonEmitter.h liftToPython.h Node.h Translation.h) diff --git a/include/SDFG/Translate/JsonEmitter.h b/include/SDFG/Translate/JsonEmitter.h index b8037116b..f09ab6366 100644 --- a/include/SDFG/Translate/JsonEmitter.h +++ b/include/SDFG/Translate/JsonEmitter.h @@ -1,3 +1,7 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// Header for JSON emitter in SDFG translation. + #ifndef SDFG_JsonEmitter_H #define SDFG_JsonEmitter_H @@ -11,54 +15,81 @@ namespace mlir::sdfg::emitter { struct JsonEmitter { explicit JsonEmitter(raw_ostream &os); - // Avoid writing directly to the output stream if possible. + /// Returns a reference to the output stream. Avoid writing directly to the + /// output stream if possible. raw_ostream &ostream() { return os; }; + /// Returns the current indentation level. unsigned getIndentation() { return indentation; }; - // Checks for errors (open objects/lists) and adds trailing newline + /// Checks for errors (open objects/lists) and adds trailing newline. Returns + /// a LogicalResult indicating success or failure. LogicalResult finish(); + /// Increases the indentation level. void indent(); + /// Decreases the indentation level. void unindent(); + /// Starts a new line in the output stream. void newLine(); + /// Prints a literal string to the output stream. void printLiteral(StringRef str); + /// Prints a string to the output stream, surrounding it with quotation marks. void printString(StringRef str); + /// Prints an integer to the output stream, surrounding it with quotation + /// marks. void printInt(int i); + /// Starts a new JSON object. void startObject(); + /// Starts a new named (keyed) JSON object. void startNamedObject(StringRef name); + /// Ends the current JSON object. void endObject(); + /// Starts a new named JSON list. void startNamedList(StringRef name); + /// Ends the current JSON list. void endList(); + /// Starts a new entry in the current JSON object or list. void startEntry(); + /// Prints a key-value pair to the output stream. If desired, turns the value + /// into string. void printKVPair(StringRef key, StringRef val, bool stringify = true); + /// Prints a key-value pair to the output stream. If desired, turns the value + /// into string. void printKVPair(StringRef key, int val, bool stringify = true); + /// Prints a key-value pair to the output stream. If desired, turns the value + /// into string. void printKVPair(StringRef key, Attribute val, bool stringify = true); + /// Prints a list of NamedAttributes as key-value pairs. void printAttributes(ArrayRef arr, ArrayRef elidedAttrs = {}); private: - // output stream + /// The output stream. raw_ostream &os; + /// The current indentation level. unsigned indentation; - // Avoids printing commas for first entries (objects or lists) + /// Flag indicating whether the current entry is the first in its parent + /// object or list. bool firstEntry; - // Stores if the current line is empty or not + /// Flag indicating whether the current line is empty. bool emptyLine; - // Used to check for proper closing of opened objects/lists + /// Enum class to represent the type of the current JSON symbol. enum class SYM { - // "{" or "}" + /// "{" or "}" BRACE, - // "[" or "]" + /// "[" or "]" SQUARE }; + /// Stack to keep track of the opened JSON symbols. SmallVector symStack; + /// Tries to pop a symbol from the symStack, checking for matching symbols. void tryPop(SYM sym); - // Tracks if there was an erronous printing + /// Flag indicating whether there was an error during printing. bool error; }; diff --git a/include/SDFG/Translate/Node.h b/include/SDFG/Translate/Node.h index b028786a7..32b40a4f4 100644 --- a/include/SDFG/Translate/Node.h +++ b/include/SDFG/Translate/Node.h @@ -1,3 +1,7 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// Header for the internal IR of the translator. + #ifndef SDFG_Translation_Node_H #define SDFG_Translation_Node_H @@ -7,6 +11,7 @@ #include "mlir/IR/Location.h" namespace mlir::sdfg::translation { +// Forward declarations. class Emittable; class Attribute; @@ -60,6 +65,7 @@ class ConsumeExitImpl; // Interfaces //===----------------------------------------------------------------------===// +/// All classes that can be printed (emitted) implement this interface. class Emittable { public: virtual void emit(emitter::JsonEmitter &jemit) = 0; @@ -69,16 +75,33 @@ class Emittable { // DataClasses //===----------------------------------------------------------------------===// -enum class DType { null, boolean, int8, int32, int64, float32, float64 }; +/// DaCe Datatypes. +enum class DType { + null, + boolean, + int8, + int16, + int32, + int64, + float16, + float32, + float64 +}; + +/// Node Types. enum class NType { SDFG, State, Access, MapEntry, ConsumeEntry, Other }; -enum class CodeLanguage { Python, MLIR }; +/// Programming languages. +enum class CodeLanguage { Python, CPP, MLIR }; + +/// Stores an attribute for a node. class Attribute { public: std::string name; // Store attribute or string? }; +/// Stores a symbol with symbol name and data type. class Symbol { public: std::string name; @@ -86,6 +109,7 @@ class Symbol { Symbol(StringRef name, DType type) : name(name), type(type) {} }; +/// Represents a condition for an edge. class Condition { public: std::string condition; @@ -93,6 +117,7 @@ class Condition { Condition(StringRef condition) : condition(condition) {} }; +/// Represents an assignment for an edge. class Assignment { public: std::string key; @@ -101,6 +126,7 @@ class Assignment { Assignment(StringRef key, StringRef value) : key(key), value(value) {} }; +/// Stores code for tasklets with the associated programming language. class Code { public: std::string data; @@ -111,6 +137,7 @@ class Code { : data(data), language(language) {} }; +/// Represents a DaCe data container. class Array : public Emittable { public: std::string name; @@ -125,9 +152,11 @@ class Array : public Emittable { Array(StringRef name, bool transient, bool stream, SizedType shape) : name(name), transient(transient), stream(stream), shape(shape) {} + /// Emits this array to the output stream. void emit(emitter::JsonEmitter &jemit) override; }; +/// Represents a range for memlets. class Range : public Emittable { public: std::string start; @@ -138,6 +167,7 @@ class Range : public Emittable { Range(StringRef start, StringRef end, StringRef step, StringRef tile) : start(start), end(end), step(step), tile(tile) {} + /// Emits this range to the output stream. void emit(emitter::JsonEmitter &jemit) override; }; @@ -145,9 +175,12 @@ class Range : public Emittable { // Node //===----------------------------------------------------------------------===// +/// Base class for all SDFG nodes. class Node : public Emittable { protected: + /// Pointer to the implementation (Pimpl idiom). std::shared_ptr ptr; + /// Stores the type of this node. NType type; public: @@ -155,52 +188,79 @@ class Node : public Emittable { bool operator==(const Node other) const { return other.ptr == ptr; } + /// Sets the ID of the node. void setID(unsigned id); + /// Returns the ID of the node. unsigned getID(); + /// Returns the source code location. Location getLocation(); + /// Returns the type of the node. NType getType(); + /// Sets the name of the node. void setName(StringRef name); + /// Returns the name of the node. StringRef getName(); + /// Sets the parent of the node. void setParent(Node parent); + /// Returns the parent of the node. Node getParent(); + /// Return true if this node has a parent node. bool hasParent(); + /// Returns the top-level SDFG. virtual SDFG getSDFG(); + /// Returns the surrounding state. virtual State getState(); + /// Adds an attribute to this node, replaces existing attributes with the same + /// name. void addAttribute(Attribute attribute); + /// Emits this node to the output stream. void emit(emitter::JsonEmitter &jemit) override; }; +/// Implementation of the base node class. class NodeImpl : public Emittable { protected: + /// Unique node ID. unsigned id; + /// Source code location. Location location; + /// Name of this node. std::string name; + /// An array of associated attributes. std::vector attributes; + /// Pointer to the parent node. Node parent; public: NodeImpl(Location location) : id(0), location(location), parent(nullptr) {} + /// ID setter. void setID(unsigned id); + /// ID getter. unsigned getID(); + /// Source code location getter. Location getLocation(); + /// Name setter. void setName(StringRef name); + /// Name getter. StringRef getName(); + /// Parent node setter. void setParent(Node parent); + /// Parent node getter. Node getParent(); - // check for existing attribtues - // Replace or add to list + /// Adds an attribute to this node, replaces existing attributes with the same + /// name. void addAttribute(Attribute attribute); - + /// Emits this node to the output stream. void emit(emitter::JsonEmitter &jemit) override; }; @@ -208,8 +268,10 @@ class NodeImpl : public Emittable { // ConnectorNode //===----------------------------------------------------------------------===// +/// Special type of node capable of connecting to other nodes (memlets). class ConnectorNode : public Node { protected: + /// Pointer to the implementation (Pimpl idiom). std::shared_ptr ptr; public: @@ -219,37 +281,55 @@ class ConnectorNode : public Node { ConnectorNode(Node n) : Node(n), ptr(std::static_pointer_cast(Node::ptr)) {} + /// Adds an incoming connector. void addInConnector(Connector connector); + /// Adds an outgoing connector. void addOutConnector(Connector connector); + /// Returns to number of incoming connectors. unsigned getInConnectorCount(); + /// Returns to number of outgoing connectors. unsigned getOutConnectorCount(); + /// Emits the connectors to the output stream. void emit(emitter::JsonEmitter &jemit) override; }; +/// Implementation of the connector node class. class ConnectorNodeImpl : public NodeImpl { protected: + /// Array of incoming connectors. std::vector inConnectors; + /// Array of outgoing connectors. std::vector outConnectors; public: ConnectorNodeImpl(Location location) : NodeImpl(location) {} + /// Adds an incoming connector. void addInConnector(Connector connector); + /// Adds an outgoing connector. void addOutConnector(Connector connector); + /// Returns to number of incoming connectors. unsigned getInConnectorCount(); + /// Returns to number of outgoing connectors. unsigned getOutConnectorCount(); - // Emits connectors + /// Emits the connectors to the output stream. void emit(emitter::JsonEmitter &jemit) override; }; +/// Represents a single connector. class Connector { public: + /// The connector node this connector belongs to. ConnectorNode parent; + /// The name of the connector. std::string name; + /// Null type connector for unnamed connectors (e.g. access nodes). bool isNull; + /// The ranges of the moved data (memlet). std::vector ranges; + /// The name of the data being moved. std::string data; // IDEA: Add DType? @@ -262,8 +342,11 @@ class Connector { return other.parent == parent && other.name == name; } + /// Adds a data range to the connector. void addRange(Range range) { ranges.push_back(range); } + /// Sets the data ranges of the connector. void setRanges(std::vector ranges) { this->ranges = ranges; } + /// Sets the name of the data being moved. void setData(StringRef data) { this->data = data.str(); } }; @@ -272,16 +355,21 @@ class Connector { //===----------------------------------------------------------------------===// // IDEA: Rewrite to use PImpl? +/// Represents an edge moving data between multiple connectors (memlet). class MultiEdge : public Emittable { private: + /// Source code location. Location location; + /// Source connector. Connector source; + /// Destination connector. Connector destination; public: MultiEdge(Location location, Connector source, Connector destination) : location(location), source(source), destination(destination) {} + /// Emits this edge to the output stream. void emit(emitter::JsonEmitter &jemit) override; }; @@ -289,8 +377,10 @@ class MultiEdge : public Emittable { // ScopeNode //===----------------------------------------------------------------------===// +/// Special type of connector node containing a scope. class ScopeNode : public ConnectorNode { protected: + /// Pointer to the implementation (Pimpl idiom). std::shared_ptr ptr; public: @@ -302,31 +392,46 @@ class ScopeNode : public ConnectorNode { : ConnectorNode(n), ptr(std::static_pointer_cast(Node::ptr)) {} + /// Adds a connector node to the scope. virtual void addNode(ConnectorNode node); + /// Adds a multiedge from the source to the destination connector. virtual void routeWrite(Connector from, Connector to); + /// Adds an edge to the scope. virtual void addEdge(MultiEdge edge); + /// Maps the MLIR value to the specified connector. virtual void mapConnector(Value value, Connector connector); + /// Returns the connector associated with a MLIR value. virtual Connector lookup(Value value); + /// Emits all nodes and edges to the output stream. void emit(emitter::JsonEmitter &jemit) override; }; +/// Implementation of the scoped node class. class ScopeNodeImpl : public ConnectorNodeImpl { protected: + /// Lookup table for Value-Connector mapping. std::map lut; + /// Array of all nodes in the scope. std::vector nodes; + /// Array of all edges in the scope. std::vector edges; public: ScopeNodeImpl(Location location) : ConnectorNodeImpl(location) {} + /// Adds a connector node to the scope. virtual void addNode(ConnectorNode node); + /// Adds a multiedge from the source to the destination connector. virtual void routeWrite(Connector from, Connector to); + /// Adds an edge to the scope. virtual void addEdge(MultiEdge edge); + /// Maps the MLIR value to the specified connector. virtual void mapConnector(Value value, Connector connector); + /// Returns the connector associated with a MLIR value. virtual Connector lookup(Value value); - // Emits nodes & edges + /// Emits all nodes and edges to the output stream. void emit(emitter::JsonEmitter &jemit) override; }; @@ -334,8 +439,10 @@ class ScopeNodeImpl : public ConnectorNodeImpl { // State //===----------------------------------------------------------------------===// +/// Represents a SDFG state. class State : public ScopeNode { private: + /// Pointer to the implementation (Pimpl idiom). std::shared_ptr ptr; public: @@ -351,15 +458,24 @@ class State : public ScopeNode { type = NType::State; } + /// Modified lookup function creates access nodes if the value could not be + /// found. Connector lookup(Value value) override; + + /// Emits the state node to the output stream. void emit(emitter::JsonEmitter &jemit) override; }; +/// Implementation of the state node class. class StateImpl : public ScopeNodeImpl { public: StateImpl(Location location) : ScopeNodeImpl(location) {} + /// Modified lookup function creates access nodes if the value could not be + /// found. Connector lookup(Value value) override; + + /// Emits the state node to the output stream. void emit(emitter::JsonEmitter &jemit) override; }; @@ -367,8 +483,10 @@ class StateImpl : public ScopeNodeImpl { // SDFG //===----------------------------------------------------------------------===// +/// Represents the top-level SDFG. class SDFG : public Node { private: + /// Pointer to the implementation (Pimpl idiom). std::shared_ptr ptr; public: @@ -384,30 +502,50 @@ class SDFG : public Node { type = NType::SDFG; } + /// Returns the state associated with the provided name. State lookup(StringRef name); + /// Adds a state to the SDFG. void addState(State state); + /// Adds a state to the SDFG and marks it as the entry state. void setStartState(State state); + /// Adds an interstate edge to the SDFG, connecting two states. void addEdge(InterstateEdge edge); + /// Adds an array (data container) to the SDFG. void addArray(Array array); + /// Adds an array (data container) to the SDFG and marks it as an argument. void addArg(Array arg); + /// Adds a symbol to the SDFG. void addSymbol(Symbol symbol); + /// Returns an array of all symbols in the SDFG. std::vector getSymbols(); + /// Emits the SDFG to the output stream. void emit(emitter::JsonEmitter &jemit) override; + /// Emits the SDFG as a nested SDFG to the output stream. void emitNested(emitter::JsonEmitter &jemit); }; +/// Implementation of the SDFG node class. class SDFGImpl : public NodeImpl { private: + /// Lookup table mapping names to states. std::map lut; + /// Array of states in the SDFG. std::vector states; + /// Array of interstate edges in the SDFG. std::vector edges; + /// Array of arrays (data containers) in the SDFG. std::vector arrays; + /// Array of argument arrays (data containers) in the SDFG. std::vector args; + /// Array of symbols in the SDFG. std::vector symbols; + /// The entry state of the SDFG State startState; + /// Global counter for the ID of SDFGs. static unsigned list_id; + /// Emits the body of the SDFG to the output stream. void emitBody(emitter::JsonEmitter &jemit); public: @@ -415,16 +553,26 @@ class SDFGImpl : public NodeImpl { id = SDFGImpl::list_id++; } + /// Returns the state associated with the provided name. State lookup(StringRef name); + /// Adds a state to the SDFG. void addState(State state); + /// Adds a state to the SDFG and marks it as the entry state. void setStartState(State state); + /// Adds an interstate edge to the SDFG, connecting two states. void addEdge(InterstateEdge edge); + /// Adds an array (data container) to the SDFG. void addArray(Array array); + /// Adds an array (data container) to the SDFG and marks it as an argument. void addArg(Array arg); + /// Adds a symbol to the SDFG. void addSymbol(Symbol symbol); + /// Returns an array of all symbols in the SDFG. std::vector getSymbols(); + /// Emits the SDFG to the output stream. void emit(emitter::JsonEmitter &jemit) override; + /// Emits the SDFG as a nested SDFG to the output stream. void emitNested(emitter::JsonEmitter &jemit); }; @@ -432,8 +580,10 @@ class SDFGImpl : public NodeImpl { // NestedSDFG //===----------------------------------------------------------------------===// +/// Represents a nested SDFG. class NestedSDFG : public ConnectorNode { private: + /// Pointer to the implementation (Pimpl idiom). std::shared_ptr ptr; public: @@ -442,17 +592,21 @@ class NestedSDFG : public ConnectorNode { std::make_shared(location, sdfg))), ptr(std::static_pointer_cast(ConnectorNode::ptr)) {} + /// Emits the nested SDFG to the output stream. void emit(emitter::JsonEmitter &jemit) override; }; +/// Implementation of the nested SDFG node class. class NestedSDFGImpl : public ConnectorNodeImpl { private: + /// The contained SDFG. SDFG sdfg; public: NestedSDFGImpl(Location location, SDFG sdfg) : ConnectorNodeImpl(location), sdfg(sdfg) {} + /// Emits the nested SDFG to the output stream. void emit(emitter::JsonEmitter &jemit) override; }; @@ -460,8 +614,10 @@ class NestedSDFGImpl : public ConnectorNodeImpl { // InterstateEdge //===----------------------------------------------------------------------===// +/// Represents an edge connecting muliple states. class InterstateEdge : public Emittable { private: + /// Pointer to the implementation (Pimpl idiom). std::shared_ptr ptr; public: @@ -469,19 +625,28 @@ class InterstateEdge : public Emittable { : ptr(std::make_shared(location, source, destination)) {} + /// Sets the condition of the interstate edge. void setCondition(Condition condition); + /// Adds an assignment to the interstate edge. void addAssignment(Assignment assignment); + /// Emits the interstate edge to the output stream. void emit(emitter::JsonEmitter &jemit) override; }; +/// Implementation of the interstate edge class. class InterstateEdgeImpl : public Emittable { private: + /// Source code location. Location location; + /// The source state of this edge. State source; + /// The destination state of this edge. State destination; + /// The condition of this edge. Condition condition; + /// Array of assignments on this edge. std::vector assignments; public: @@ -489,10 +654,12 @@ class InterstateEdgeImpl : public Emittable { : location(location), source(source), destination(destination), condition("1") {} + /// Sets the condition of the interstate edge. void setCondition(Condition condition); - // Check for duplicates + /// Adds an assignment to the interstate edge. void addAssignment(Assignment assignment); + /// Emits the interstate edge to the output stream. void emit(emitter::JsonEmitter &jemit) override; }; @@ -500,8 +667,10 @@ class InterstateEdgeImpl : public Emittable { // Tasklet //===----------------------------------------------------------------------===// +/// Represents a SDFG tasklet. class Tasklet : public ConnectorNode { private: + /// Pointer to the implementation (Pimpl idiom). std::shared_ptr ptr; public: @@ -510,18 +679,39 @@ class Tasklet : public ConnectorNode { std::make_shared(location))), ptr(std::static_pointer_cast(ConnectorNode::ptr)) {} + /// Sets the code of the tasklet. void setCode(Code code); + /// Sets the global code of the tasklet. + void setGlobalCode(Code code_global); + /// Sets the side effect flag of the tasklet. + void setHasSideEffect(bool hasSideEffect); + + /// Emits the tasklet to the output stream. void emit(emitter::JsonEmitter &jemit) override; }; +/// Implementation of the tasklet node class. class TaskletImpl : public ConnectorNodeImpl { private: + /// The code in the tasklet. Code code; + /// The global code required to run the tasklet. + Code code_global; + /// Flag indicating if this tasklet has side effects. + bool hasSideEffect; public: - TaskletImpl(Location location) : ConnectorNodeImpl(location) {} + TaskletImpl(Location location) + : ConnectorNodeImpl(location), hasSideEffect(false) {} + /// Sets the code of the tasklet. void setCode(Code code); + /// Sets the global code of the tasklet. + void setGlobalCode(Code code_global); + /// Sets the side effect flag of the tasklet. + void setHasSideEffect(bool hasSideEffect); + + /// Emits the tasklet to the output stream. void emit(emitter::JsonEmitter &jemit) override; }; @@ -529,8 +719,10 @@ class TaskletImpl : public ConnectorNodeImpl { // Library //===----------------------------------------------------------------------===// +/// Represents a SDFG libary node. class Library : public ConnectorNode { private: + /// Pointer to the implementation (Pimpl idiom). std::shared_ptr ptr; public: @@ -539,18 +731,24 @@ class Library : public ConnectorNode { std::make_shared(location))), ptr(std::static_pointer_cast(ConnectorNode::ptr)) {} + /// Sets the library code path. void setClasspath(StringRef classpath); + /// Emits the library node to the output stream. void emit(emitter::JsonEmitter &jemit) override; }; +/// Implementation of the library node class. class LibraryImpl : public ConnectorNodeImpl { private: + /// The path to the library code. std::string classpath; public: LibraryImpl(Location location) : ConnectorNodeImpl(location) {} + /// Sets the library code path. void setClasspath(StringRef classpath); + /// Emits the library node to the output stream. void emit(emitter::JsonEmitter &jemit) override; }; @@ -558,8 +756,10 @@ class LibraryImpl : public ConnectorNodeImpl { // Access //===----------------------------------------------------------------------===// +/// Represents an access node in the SDFG. class Access : public ConnectorNode { private: + /// Pointer to the implementation (Pimpl idiom). std::shared_ptr ptr; public: @@ -570,14 +770,17 @@ class Access : public ConnectorNode { type = NType::Access; } + /// Emits the access node to the output stream. void emit(emitter::JsonEmitter &jemit) override; }; +/// Implementation of the access node class. class AccessImpl : public ConnectorNodeImpl { private: public: AccessImpl(Location location) : ConnectorNodeImpl(location) {} + /// Emits the access node to the output stream. void emit(emitter::JsonEmitter &jemit) override; }; @@ -585,8 +788,10 @@ class AccessImpl : public ConnectorNodeImpl { // Map //===----------------------------------------------------------------------===// +/// Represents a map entry node in the SDFG. class MapEntry : public ScopeNode { private: + /// Pointer to the implementation (Pimpl idiom). std::shared_ptr ptr; public: @@ -602,20 +807,34 @@ class MapEntry : public ScopeNode { MapEntry() : ScopeNode(nullptr) {} + /// Adds a parameter to the map entry. void addParam(StringRef param); + /// Adds a range for a parameter. void addRange(Range range); + /// Sets the map exit this map entry belongs to. void setExit(MapExit exit); + /// Returns the matching map exit. MapExit getExit(); + /// Adds a connector node to the scope. void addNode(ConnectorNode node) override; + /// Adds a multiedge from the source to the destination connector. void routeWrite(Connector from, Connector to) override; + /// Adds an edge to the scope. void addEdge(MultiEdge edge) override; + /// Maps the MLIR value to the specified connector. void mapConnector(Value value, Connector connector) override; + /// Returns the connector associated with a MLIR value, inserting map + /// connectors when needed. Connector lookup(Value value) override; + + /// Emits the map entry to the output stream. void emit(emitter::JsonEmitter &jemit) override; }; +/// Represents a map exit node in the SDFG. class MapExit : public ConnectorNode { private: + /// Pointer to the implementation (Pimpl idiom). std::shared_ptr ptr; public: @@ -626,39 +845,63 @@ class MapExit : public ConnectorNode { MapExit() : ConnectorNode(nullptr) {} + /// Sets the map entry this map exit belongs to. void setEntry(MapEntry entry); + + /// Emits the map exit to the output stream. void emit(emitter::JsonEmitter &jemit) override; }; +/// Implementation of the map entry node class. class MapEntryImpl : public ScopeNodeImpl { private: + /// The matching map exit. MapExit exit; + /// Array of parameters. std::vector params; + /// Array of ranges for the parameters. std::vector ranges; public: MapEntryImpl(Location location) : ScopeNodeImpl(location) {} - void setExit(MapExit exit); - MapExit getExit(); + /// Adds a parameter to the map entry. void addParam(StringRef param); + /// Adds a range for a parameter. void addRange(Range range); + /// Sets the map exit this map entry belongs to. + void setExit(MapExit exit); + /// Returns the matching map exit. + MapExit getExit(); + /// Adds a connector node to the scope. void addNode(ConnectorNode node) override; + /// Adds a multiedge from the source to the destination connector. void routeWrite(Connector from, Connector to) override; + /// Adds an edge to the scope. void addEdge(MultiEdge edge) override; + /// Maps the MLIR value to the specified connector. void mapConnector(Value value, Connector connector) override; + /// Returns the connector associated with a MLIR value, inserting map + /// connectors when needed. Connector lookup(Value value, MapEntry mapEntry); + + /// Emits the map entry to the output stream. void emit(emitter::JsonEmitter &jemit) override; }; +/// Implementation of the map exit node class. class MapExitImpl : public ConnectorNodeImpl { private: + /// The matching map entry. MapEntry entry; public: MapExitImpl(Location location) : ConnectorNodeImpl(location) {} + /// Sets the map entry this map exit belongs to. void setEntry(MapEntry entry); + + /// Emits the map exit to the output stream. void emit(emitter::JsonEmitter &jemit) override; }; @@ -666,8 +909,10 @@ class MapExitImpl : public ConnectorNodeImpl { // Consume //===----------------------------------------------------------------------===// +/// Represents a consume entry node in the SDFG. class ConsumeEntry : public ScopeNode { private: + /// Pointer to the implementation (Pimpl idiom). std::shared_ptr ptr; public: @@ -684,24 +929,38 @@ class ConsumeEntry : public ScopeNode { ConsumeEntry() : ScopeNode(nullptr) {} + /// Sets the consume exit this consume entry belongs to. void setExit(ConsumeExit exit); + /// Returns the matching consume exit. ConsumeExit getExit(); + /// Adds a connector node to the scope. void addNode(ConnectorNode node) override; + /// Adds a multiedge from the source to the destination connector. void routeWrite(Connector from, Connector to) override; + /// Adds an edge to the scope. void addEdge(MultiEdge edge) override; + /// Maps the MLIR value to the specified connector. void mapConnector(Value value, Connector connector) override; + /// Returns the connector associated with a MLIR value, inserting consume + /// connectors when needed. Connector lookup(Value value) override; + /// Sets the number of processing elements. void setNumPes(StringRef pes); + /// Sets the name of the processing element index. void setPeIndex(StringRef pe); + /// Sets the condition to continue stream consumption. void setCondition(Code condition); + /// Emits the consume entry to the output stream. void emit(emitter::JsonEmitter &jemit) override; }; +/// Represents a consume exit node in the SDFG. class ConsumeExit : public ConnectorNode { private: + /// Pointer to the implementation (Pimpl idiom). std::shared_ptr ptr; public: @@ -712,44 +971,69 @@ class ConsumeExit : public ConnectorNode { ConsumeExit() : ConnectorNode(nullptr) {} + /// Sets the consume entry this consume exit belongs to. void setEntry(ConsumeEntry entry); + + /// Emits the consume exit to the output stream. void emit(emitter::JsonEmitter &jemit) override; }; +/// Implementation of the consume entry node class. class ConsumeEntryImpl : public ScopeNodeImpl { private: + /// The matching consume exit. ConsumeExit exit; + /// The number of processing elements. std::string num_pes; + /// The name of the processing element index. std::string pe_index; + /// The condition to continue stream consumption. Code condition; public: ConsumeEntryImpl(Location location) : ScopeNodeImpl(location) {} + /// Sets the consume exit this consume entry belongs to. void setExit(ConsumeExit exit); + /// Returns the matching consume exit. ConsumeExit getExit(); + /// Adds a connector node to the scope. void addNode(ConnectorNode node) override; + /// Adds a multiedge from the source to the destination connector. void routeWrite(Connector from, Connector to) override; + /// Adds an edge to the scope. void addEdge(MultiEdge edge) override; + /// Maps the MLIR value to the specified connector. void mapConnector(Value value, Connector connector) override; + /// Returns the connector associated with a MLIR value, inserting consume + /// connectors when needed. Connector lookup(Value value, ConsumeEntry entry); + /// Sets the number of processing elements. void setNumPes(StringRef pes); + /// Sets the name of the processing element index. void setPeIndex(StringRef pe); + /// Sets the condition to continue stream consumption. void setCondition(Code condition); + /// Emits the consume entry to the output stream. void emit(emitter::JsonEmitter &jemit) override; }; +/// Implementation of the consume exit node class. class ConsumeExitImpl : public ConnectorNodeImpl { private: + /// The matching consume entry. ConsumeEntry entry; public: ConsumeExitImpl(Location location) : ConnectorNodeImpl(location) {} + /// Sets the consume entry this consume exit belongs to. void setEntry(ConsumeEntry entry); + + /// Emits the consume exit to the output stream. void emit(emitter::JsonEmitter &jemit) override; }; diff --git a/include/SDFG/Translate/Translation.h b/include/SDFG/Translate/Translation.h index 8bb7351df..a62bfc63d 100644 --- a/include/SDFG/Translate/Translation.h +++ b/include/SDFG/Translate/Translation.h @@ -1,3 +1,7 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// Header for the SDFG Dialect to SDFG IR translator. + #ifndef SDFG_Translation_H #define SDFG_Translation_H @@ -8,32 +12,51 @@ using namespace mlir::sdfg::emitter; namespace mlir::sdfg::translation { +/// Registers SDFG to SDFG IR translation. void registerToSDFGTranslation(); +/// Translates a module containing SDFG dialect to SDFG IR, outputs the result +/// to the provided output stream. LogicalResult translateToSDFG(ModuleOp &op, JsonEmitter &jemit); +/// Collects state node information in a top-level SDFG. LogicalResult collect(StateNode &op, SDFG &sdfg); +/// Collects edge information in a top-level SDFG. LogicalResult collect(EdgeOp &op, SDFG &sdfg); +/// Collects array/stream allocation information in a top-level SDFG. LogicalResult collect(AllocOp &op, SDFG &sdfg); +/// Collects symbol allocation information in a top-level SDFG. LogicalResult collect(AllocSymbolOp &op, SDFG &sdfg); +/// Collects array/stream allocation information in a scope. LogicalResult collect(AllocOp &op, ScopeNode &scope); +/// Collects tasklet information in a scope. LogicalResult collect(TaskletNode &op, ScopeNode &scope); -LogicalResult collect(NestedSDFGNode &op, ScopeNode &scope); +/// Collects library call information in a scope. LogicalResult collect(LibCallOp &op, ScopeNode &scope); +/// Collects nested SDFG node information in a scope. +LogicalResult collect(NestedSDFGNode &op, ScopeNode &scope); +/// Collects map node information in a scope. LogicalResult collect(MapNode &op, ScopeNode &scope); +/// Collects consume node information in a scope. LogicalResult collect(ConsumeNode &op, ScopeNode &scope); +/// Collects copy operation information in a scope. LogicalResult collect(CopyOp &op, ScopeNode &scope); +/// Collects store operation information in a scope. LogicalResult collect(StoreOp &op, ScopeNode &scope); +/// Collects load operation information in a scope. LogicalResult collect(LoadOp &op, ScopeNode &scope); +/// Collects symbol allocation information in a scope. LogicalResult collect(AllocSymbolOp &op, ScopeNode &scope); +/// Collects symbolic expression information in a scope. LogicalResult collect(SymOp &op, ScopeNode &scope); +/// Collects stream push operation information in a scope. LogicalResult collect(StreamPushOp &op, ScopeNode &scope); +/// Collects stream pop operation information in a scope. LogicalResult collect(StreamPopOp &op, ScopeNode &scope); - } // namespace mlir::sdfg::translation #endif // SDFG_Translation_H diff --git a/include/SDFG/Translate/liftToPython.h b/include/SDFG/Translate/liftToPython.h index 51d1f7499..992f5ed4e 100644 --- a/include/SDFG/Translate/liftToPython.h +++ b/include/SDFG/Translate/liftToPython.h @@ -1,3 +1,7 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// Header for lifting operations to Python. + #ifndef SDFG_Translation_LiftToPython_H #define SDFG_Translation_LiftToPython_H @@ -5,7 +9,10 @@ namespace mlir::sdfg::translation { +/// Converts the operations in the first region of op to Python code. If +/// successful, returns Python code as a string. Optional liftToPython(Operation &op); +/// Provides a name for the tasklet. std::string getTaskletName(Operation &op); } // namespace mlir::sdfg::translation diff --git a/include/SDFG/Utils/AttributeToString.h b/include/SDFG/Utils/AttributeToString.h index f66404b66..25435144b 100644 --- a/include/SDFG/Utils/AttributeToString.h +++ b/include/SDFG/Utils/AttributeToString.h @@ -1,3 +1,7 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// Header for the attribute to string utility functions. + #ifndef SDFG_Utils_AttributeToString_H #define SDFG_Utils_AttributeToString_H @@ -7,6 +11,7 @@ namespace mlir::sdfg::utils { +/// Prints an attribute to a string. std::string attributeToString(Attribute attribute, Operation &op); } // namespace mlir::sdfg::utils diff --git a/include/SDFG/Utils/CMakeLists.txt b/include/SDFG/Utils/CMakeLists.txt index aa8eb7810..d2a6b3f77 100644 --- a/include/SDFG/Utils/CMakeLists.txt +++ b/include/SDFG/Utils/CMakeLists.txt @@ -1,3 +1,5 @@ +# Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + target_sources( SOURCE_FILES_H PRIVATE AttributeToString.h diff --git a/include/SDFG/Utils/GetParents.h b/include/SDFG/Utils/GetParents.h index 666151ddc..2451d7955 100644 --- a/include/SDFG/Utils/GetParents.h +++ b/include/SDFG/Utils/GetParents.h @@ -1,3 +1,7 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// Header for the parent utility functions. + #ifndef SDFG_Utils_GetParents_H #define SDFG_Utils_GetParents_H @@ -5,8 +9,13 @@ namespace mlir::sdfg::utils { +/// Returns the parent SDFG node, NestedSDFG node or nullptr if a parent does +/// not exist. Operation *getParentSDFG(Operation &op); +/// Returns the parent State node or nullptr if a parent does not exist. StateNode getParentState(Operation &op, bool ignoreSDFGs = false); +/// Returns top-level module operation or nullptr if a parent does not exist. +ModuleOp getTopModuleOp(Operation *op); } // namespace mlir::sdfg::utils diff --git a/include/SDFG/Utils/GetSizedType.h b/include/SDFG/Utils/GetSizedType.h index c2f2122a0..855ee6fd6 100644 --- a/include/SDFG/Utils/GetSizedType.h +++ b/include/SDFG/Utils/GetSizedType.h @@ -1,3 +1,7 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// Header for the sized type utility functions. + #ifndef SDFG_Utils_GetSizedType_H #define SDFG_Utils_GetSizedType_H @@ -5,7 +9,9 @@ namespace mlir::sdfg::utils { +/// Extracts the sized type from an array or stream type. SizedType getSizedType(Type t); +/// Returns true if the provided type is a sized type. bool isSizedType(Type t); } // namespace mlir::sdfg::utils diff --git a/include/SDFG/Utils/IDGenerator.h b/include/SDFG/Utils/IDGenerator.h index 08440d97d..e5b5da5c7 100644 --- a/include/SDFG/Utils/IDGenerator.h +++ b/include/SDFG/Utils/IDGenerator.h @@ -1,9 +1,15 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// Header for the ID generator utility functions. + #ifndef SDFG_Utils_IDGenerator_H #define SDFG_Utils_IDGenerator_H namespace mlir::sdfg::utils { +/// Returns a globally unique ID. unsigned generateID(); +/// Resets the ID generator. void resetIDGenerator(); } // namespace mlir::sdfg::utils diff --git a/include/SDFG/Utils/NameGenerator.h b/include/SDFG/Utils/NameGenerator.h index 06f704850..1403e800c 100644 --- a/include/SDFG/Utils/NameGenerator.h +++ b/include/SDFG/Utils/NameGenerator.h @@ -1,3 +1,7 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// Header for the name generator utility functions. + #ifndef SDFG_Utils_NameGenerator_H #define SDFG_Utils_NameGenerator_H @@ -5,8 +9,9 @@ namespace mlir::sdfg::utils { +/// Converts the provided string to a globally unique one. std::string generateName(std::string base); -} +} // namespace mlir::sdfg::utils #endif // SDFG_Utils_NameGenerator_H diff --git a/include/SDFG/Utils/OperationToString.h b/include/SDFG/Utils/OperationToString.h index 1244a670e..955fecc0b 100644 --- a/include/SDFG/Utils/OperationToString.h +++ b/include/SDFG/Utils/OperationToString.h @@ -1,3 +1,7 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// Header for the operation to string utility functions. + #ifndef SDFG_Utils_OperationToString_H #define SDFG_Utils_OperationToString_H @@ -6,6 +10,7 @@ namespace mlir::sdfg::utils { +/// Prints an operation to a string. std::string operationToString(Operation &op); } // namespace mlir::sdfg::utils diff --git a/include/SDFG/Utils/Sanitizer.h b/include/SDFG/Utils/Sanitizer.h index d6f02463e..dbba790d3 100644 --- a/include/SDFG/Utils/Sanitizer.h +++ b/include/SDFG/Utils/Sanitizer.h @@ -1,3 +1,7 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// Header for the sanitizer utility functions. + #ifndef SDFG_Utils_Sanitizer_H #define SDFG_Utils_Sanitizer_H @@ -5,8 +9,10 @@ namespace mlir::sdfg::utils { +/// Sanitizes the provided string to only include alphanumericals and +/// underscores. void sanitizeName(std::string &name); -} +} // namespace mlir::sdfg::utils #endif // SDFG_Utils_Sanitizer_H diff --git a/include/SDFG/Utils/Utils.h b/include/SDFG/Utils/Utils.h index c4aa0c58e..d53196d52 100644 --- a/include/SDFG/Utils/Utils.h +++ b/include/SDFG/Utils/Utils.h @@ -1,3 +1,7 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// Header for all utility functions. + #ifndef SDFG_Utils_H #define SDFG_Utils_H diff --git a/include/SDFG/Utils/ValueToString.h b/include/SDFG/Utils/ValueToString.h index 3becc567a..19c7b9be7 100644 --- a/include/SDFG/Utils/ValueToString.h +++ b/include/SDFG/Utils/ValueToString.h @@ -1,3 +1,7 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// Header for the value to string utility functions. + #ifndef SDFG_Utils_ValueToString_H #define SDFG_Utils_ValueToString_H @@ -6,7 +10,9 @@ namespace mlir::sdfg::utils { +/// Prints a value to a string. Optionally takes a context operation. std::string valueToString(Value value); +/// Prints a value to a string. Optionally takes a context operation. std::string valueToString(Value value, Operation &op); } // namespace mlir::sdfg::utils diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index c4feb2a26..a26945b68 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -1 +1,3 @@ +# Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + add_subdirectory(SDFG) diff --git a/lib/SDFG/CMakeLists.txt b/lib/SDFG/CMakeLists.txt index 06b67873b..c7e53ecad 100644 --- a/lib/SDFG/CMakeLists.txt +++ b/lib/SDFG/CMakeLists.txt @@ -1,3 +1,5 @@ +# Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + add_subdirectory(Utils) add_subdirectory(Dialect) add_subdirectory(Translate) diff --git a/lib/SDFG/Conversion/CMakeLists.txt b/lib/SDFG/Conversion/CMakeLists.txt index 3ad79fafd..ef6fd1fef 100644 --- a/lib/SDFG/Conversion/CMakeLists.txt +++ b/lib/SDFG/Conversion/CMakeLists.txt @@ -1,2 +1,6 @@ +# Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + add_subdirectory(GenericToSDFG) add_subdirectory(LinalgToSDFG) + +add_subdirectory(SDFGToGeneric) diff --git a/lib/SDFG/Conversion/GenericToSDFG/CMakeLists.txt b/lib/SDFG/Conversion/GenericToSDFG/CMakeLists.txt index 09b749a99..51bef9fad 100644 --- a/lib/SDFG/Conversion/GenericToSDFG/CMakeLists.txt +++ b/lib/SDFG/Conversion/GenericToSDFG/CMakeLists.txt @@ -1,3 +1,5 @@ +# Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + add_mlir_dialect_library( GenericToSDFG ConvertGenericToSDFG.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/SDFG/Conversion/GenericToSDFG DEPENDS diff --git a/lib/SDFG/Conversion/GenericToSDFG/ConvertGenericToSDFG.cpp b/lib/SDFG/Conversion/GenericToSDFG/ConvertGenericToSDFG.cpp index c7ca1aab3..9cfc2379f 100644 --- a/lib/SDFG/Conversion/GenericToSDFG/ConvertGenericToSDFG.cpp +++ b/lib/SDFG/Conversion/GenericToSDFG/ConvertGenericToSDFG.cpp @@ -1,14 +1,17 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// This file defines a converter from builtin dialects to the SDFG dialect. + #include "SDFG/Conversion/GenericToSDFG/PassDetail.h" #include "SDFG/Conversion/GenericToSDFG/Passes.h" #include "SDFG/Dialect/Dialect.h" #include "SDFG/Utils/Utils.h" -#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" -#include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/AsmState.h" #include "mlir/IR/IRMapping.h" +#include "mlir/Transforms/DialectConversion.h" using namespace mlir; using namespace sdfg; @@ -18,6 +21,7 @@ using namespace conversion; // Target & Type Converter //===----------------------------------------------------------------------===// +/// Defines the target to convert to. struct SDFGTarget : public ConversionTarget { SDFGTarget(MLIRContext &ctx) : ConversionTarget(ctx) { // Every Op in the SDFG Dialect is legal @@ -30,26 +34,29 @@ struct SDFGTarget : public ConversionTarget { } }; -class MemrefToMemletConverter : public TypeConverter { +/// Defines a type converter, converting input types to array types. +class ToArrayConverter : public TypeConverter { public: - MemrefToMemletConverter() { + ToArrayConverter() { addConversion([](Type type) { return type; }); addConversion(convertMemrefTypes); addConversion(convertLLVMPtrTypes); } + /// Attempts to convert MemRef types to array types. static Optional convertMemrefTypes(Type type) { if (MemRefType mem = type.dyn_cast()) { SmallVector ints; SmallVector symbols; SmallVector shape; for (int64_t dim : mem.getShape()) { - if (dim <= 0) { + if (dim < 0) { StringAttr sym = StringAttr::get(mem.getContext(), sdfg::utils::generateName("s")); symbols.push_back(sym); shape.push_back(false); } else { + dim = dim == 0 ? 1 : dim; ints.push_back(dim); shape.push_back(true); } @@ -62,6 +69,7 @@ class MemrefToMemletConverter : public TypeConverter { return std::nullopt; } + /// Attempts to convert LLVM Ptr types to array types. static Optional convertLLVMPtrTypes(Type type) { if (mlir::LLVM::LLVMPointerType ptrType = type.dyn_cast()) { @@ -93,8 +101,10 @@ class MemrefToMemletConverter : public TypeConverter { // Helpers //===----------------------------------------------------------------------===// -SmallVector createLoads(PatternRewriter &rewriter, Location loc, - ArrayRef vals) { +/// Checks for each value if it originates from an allocation operation. If so, +/// inserts load operations to access the stored value. +static SmallVector createLoads(PatternRewriter &rewriter, Location loc, + ArrayRef vals) { SmallVector loadedOps; for (Value operand : vals) { if (operand.getDefiningOp() != nullptr && @@ -116,12 +126,15 @@ SmallVector createLoads(PatternRewriter &rewriter, Location loc, return loadedOps; } -Value createLoad(PatternRewriter &rewriter, Location loc, Value val) { +/// Wrapper function for the above helper function. Instead of a list of values +/// this function instead just takes a single value. +static Value createLoad(PatternRewriter &rewriter, Location loc, Value val) { SmallVector loadedOps = {val}; return createLoads(rewriter, loc, loadedOps)[0]; } -Operation *getParentSDFG(Operation *op) { +/// Returns the closest SDFG node or nested SDFG node. +static Operation *getParentSDFG(Operation *op) { Operation *parent = op->getParentOp(); if (isa(parent)) @@ -133,7 +146,8 @@ Operation *getParentSDFG(Operation *op) { return getParentSDFG(parent); } -uint32_t getSDFGNumArgs(Operation *op) { +/// Returns the number of arguments for a SDFG node or nested SDFG node. +static uint32_t getSDFGNumArgs(Operation *op) { if (SDFGNode sdfg = dyn_cast(op)) { return sdfg.getNumArgs(); } @@ -145,7 +159,8 @@ uint32_t getSDFGNumArgs(Operation *op) { return -1; } -StateNode getFirstState(Operation *op) { +/// Returns the first state of a SDFG node or nested SDFG node. +static StateNode getFirstState(Operation *op) { if (SDFGNode sdfg = dyn_cast(op)) { return sdfg.getFirstState(); } @@ -157,17 +172,11 @@ StateNode getFirstState(Operation *op) { return nullptr; } -ModuleOp getTopModuleOp(Operation *op) { - Operation *parent = op->getParentOp(); - - if (isa(parent)) - return cast(parent); - - return getTopModuleOp(parent); -} - -void linkToLastState(PatternRewriter &rewriter, Location loc, - StateNode &state) { +/// Searches for the state appearing before the provided state and inserts an +/// edge without assignemnts or conditions from the previous state to the +/// current state. +static void linkToLastState(PatternRewriter &rewriter, Location loc, + StateNode &state) { Operation *sdfg = getParentSDFG(state); OpBuilder::InsertPoint ip = rewriter.saveInsertionPoint(); rewriter.setInsertionPointToEnd(&sdfg->getRegion(0).getBlocks().front()); @@ -184,8 +193,11 @@ void linkToLastState(PatternRewriter &rewriter, Location loc, rewriter.restoreInsertionPoint(ip); } -void linkToNextState(PatternRewriter &rewriter, Location loc, - StateNode &state) { +/// Searches for the state appearing after the provided state and inserts an +/// edge without assignments or condition from the current state to the next +/// one. +static void linkToNextState(PatternRewriter &rewriter, Location loc, + StateNode &state) { Operation *sdfg = getParentSDFG(state); OpBuilder::InsertPoint ip = rewriter.saveInsertionPoint(); rewriter.setInsertionPointToEnd(&sdfg->getRegion(0).getBlocks().front()); @@ -203,17 +215,23 @@ void linkToNextState(PatternRewriter &rewriter, Location loc, rewriter.restoreInsertionPoint(ip); } -void markToLink(Operation &op) { +/// Marks (using a boolean attribute) an operation (i.e. state) to be connected +/// to the next state with an edge. +static void markToLink(Operation &op) { BoolAttr boolAttr = BoolAttr::get(op.getContext(), true); op.setAttr("linkToNext", boolAttr); } -bool markedToLink(Operation &op) { +/// Checks if a operation (i.e. state) is marked (using a boolean attribute) to +/// be connected to the next state with an edge. +static bool markedToLink(Operation &op) { return op.hasAttr("linkToNext") && op.getAttr("linkToNext").cast().getValue(); } -Value getTransientValue(Value val) { +/// Returns the allocation operation from a value originating from a load +/// operation. +static Value getTransientValue(Value val) { if (val.getDefiningOp() != nullptr && isa(val.getDefiningOp())) { LoadOp load = cast(val.getDefiningOp()); AllocOp alloc = cast(load.getArr().getDefiningOp()); @@ -227,6 +245,7 @@ Value getTransientValue(Value val) { // Func Patterns //===----------------------------------------------------------------------===// +/// Converts a func::FuncOp to a SDFG node. class FuncToSDFG : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -280,14 +299,14 @@ class FuncToSDFG : public OpConversionPattern { SmallVector args = {}; for (unsigned i = 0; i < op.getNumArguments(); ++i) { - MemrefToMemletConverter memo; - Type nt = memo.convertType(op.getArgumentTypes()[i]); + ToArrayConverter tac; + Type nt = tac.convertType(op.getArgumentTypes()[i]); args.push_back(nt); } for (unsigned i = 0; i < op.getNumResults(); ++i) { - MemrefToMemletConverter memo; - Type nt = memo.convertType(op.getResultTypes()[i]); + ToArrayConverter tac; + Type nt = tac.convertType(op.getResultTypes()[i]); if (!nt.isa()) { SizedType sized = SizedType::get(nt.getContext(), nt, {}, {}, {}); @@ -326,6 +345,8 @@ class FuncToSDFG : public OpConversionPattern { } }; +/// Converts a func::CallOp to a nested SDFG node or to a tasklet in special +/// cases. class CallToSDFG : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -337,7 +358,7 @@ class CallToSDFG : public OpConversionPattern { // TODO: Support external calls // TODO: Support return values std::string callee = op.getCallee().str(); - ModuleOp mod = getTopModuleOp(op); + ModuleOp mod = sdfg::utils::getTopModuleOp(op); func::FuncOp funcOp = dyn_cast(mod.lookupSymbol(callee)); // HACK: The function call got replaced at `FuncToSDFG` (PolybenchC) @@ -349,7 +370,7 @@ class CallToSDFG : public OpConversionPattern { // HACK: Removes special function calls (cbrt, exit) and creates tasklet // with annotation (LULESH) - if (callee == "cbrt" || callee == "exit") { + if (callee == "cbrt" || callee == "exit" || funcOp.isExternal()) { StateNode state = StateNode::create(rewriter, op->getLoc(), callee); SmallVector operands = adaptor.getOperands(); @@ -369,8 +390,8 @@ class CallToSDFG : public OpConversionPattern { rewriter.setInsertionPointAfter(task); if (task.getNumResults() == 1) { - MemrefToMemletConverter memo; - Type nt = memo.convertType(op->getResultTypes()[0]); + ToArrayConverter tac; + Type nt = tac.convertType(op->getResultTypes()[0]); SizedType sized = SizedType::get(op->getLoc().getContext(), nt, {}, {}, {}); nt = ArrayType::get(op->getLoc().getContext(), sized); @@ -461,6 +482,7 @@ class CallToSDFG : public OpConversionPattern { } }; +/// Converts a func::ReturnOp to copy and store operations. class ReturnToSDFG : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -507,6 +529,7 @@ class ReturnToSDFG : public OpConversionPattern { // Arith & Math Patterns //===----------------------------------------------------------------------===// +/// Wraps any arith and math operation into a tasklet. class OpToTasklet : public ConversionPattern { public: OpToTasklet(TypeConverter &converter, MLIRContext *context) @@ -535,8 +558,8 @@ class OpToTasklet : public ConversionPattern { SmallVector allocs; for (Type opType : op->getResultTypes()) { - MemrefToMemletConverter memo; - Type newType = memo.convertType(opType); + ToArrayConverter tac; + Type newType = tac.convertType(opType); SizedType sizedType = SizedType::get(op->getLoc().getContext(), newType, {}, {}, {}); newType = ArrayType::get(op->getLoc().getContext(), sizedType); @@ -595,6 +618,7 @@ class OpToTasklet : public ConversionPattern { // Memref Patterns //===----------------------------------------------------------------------===// +/// Converts a memref::LoadOp to a sdfg::LoadOp. class MemrefLoadToSDFG : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -636,6 +660,7 @@ class MemrefLoadToSDFG : public OpConversionPattern { } }; +/// Converts a memref::StoreOp to a sdfg::StoreOp. class MemrefStoreToSDFG : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -662,6 +687,30 @@ class MemrefStoreToSDFG : public OpConversionPattern { } }; +/// Converts a memref::CopyOp to a sdfg::CopyOp. +class MemrefCopyToSDFG : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + StateNode state = StateNode::create(rewriter, op->getLoc(), "copy"); + + Value source = createLoad(rewriter, op.getLoc(), adaptor.getSource()); + Value target = createLoad(rewriter, op.getLoc(), adaptor.getTarget()); + CopyOp::create(rewriter, op.getLoc(), source, target); + + linkToLastState(rewriter, op->getLoc(), state); + if (markedToLink(*op)) + linkToNextState(rewriter, op->getLoc(), state); + + rewriter.eraseOp(op); + return success(); + } +}; + +/// Erases memref::GlobalOp. class MemrefGlobalToSDFG : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -674,6 +723,7 @@ class MemrefGlobalToSDFG : public OpConversionPattern { } }; +/// Converts a memref::GetGlobalOp to an allocation operation. class MemrefGetGlobalToSDFG : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -699,6 +749,7 @@ class MemrefGetGlobalToSDFG : public OpConversionPattern { } }; +/// Converts a memref::AllocOp to a sdfg::AllocOp. class MemrefAllocToSDFG : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -746,6 +797,7 @@ class MemrefAllocToSDFG : public OpConversionPattern { } }; +/// Converts a memref::AllocaOp to a sdfg::AllocOp. class MemrefAllocaToSDFG : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -793,6 +845,7 @@ class MemrefAllocaToSDFG : public OpConversionPattern { } }; +/// Erases memref::DeallocOp. class MemrefDeallocToSDFG : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -805,6 +858,7 @@ class MemrefDeallocToSDFG : public OpConversionPattern { } }; +/// Converts a memref::CastOp to an allocation and copy operation. class MemrefCastToSDFG : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -851,6 +905,8 @@ class MemrefCastToSDFG : public OpConversionPattern { // SCF Patterns //===----------------------------------------------------------------------===// +/// Converts a scf::ForOp to multiple states, modeling a for loop with +/// assignment and conditional edges. class SCFForToSDFG : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -872,8 +928,8 @@ class SCFForToSDFG : public OpConversionPattern { for (unsigned i = 0; i < op.getNumIterOperands(); ++i) { Value iterOp = op.getIterOperands()[i]; - MemrefToMemletConverter memo; - Type newType = memo.convertType(iterOp.getType()); + ToArrayConverter tac; + Type newType = tac.convertType(iterOp.getType()); SizedType sizedType = SizedType::get(op->getLoc().getContext(), newType, {}, {}, {}); newType = ArrayType::get(op->getLoc().getContext(), sizedType); @@ -1018,6 +1074,8 @@ class SCFForToSDFG : public OpConversionPattern { } }; +/// Converts a scf::WhileOp to multiple states, modeling a while loop with +/// assignment and conditional edges. class SCFWhileToSDFG : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -1041,7 +1099,7 @@ class SCFWhileToSDFG : public OpConversionPattern { // Itervars for (BlockArgument arg : op.getBeforeArguments()) { - MemrefToMemletConverter converter; + ToArrayConverter converter; Type newType = converter.convertType(arg.getType()); SizedType sizedType = SizedType::get(context, newType, {}, {}, {}); ArrayType arrayType = ArrayType::get(context, sizedType); @@ -1054,7 +1112,7 @@ class SCFWhileToSDFG : public OpConversionPattern { } // Condition - MemrefToMemletConverter converter; + ToArrayConverter converter; Type newType = converter.convertType(conditionOp.getCondition().getType()); SizedType sizedType = SizedType::get(context, newType, {}, {}, {}); ArrayType arrayType = ArrayType::get(context, sizedType); @@ -1064,7 +1122,7 @@ class SCFWhileToSDFG : public OpConversionPattern { // Condition Arguments for (Value arg : conditionOp.getArgs()) { - MemrefToMemletConverter converter; + ToArrayConverter converter; Type newType = converter.convertType(arg.getType()); SizedType sizedType = SizedType::get(context, newType, {}, {}, {}); ArrayType arrayType = ArrayType::get(context, sizedType); @@ -1178,6 +1236,8 @@ class SCFWhileToSDFG : public OpConversionPattern { } }; +/// Converts a scf::ConditionOp to store operations, storing the condition and +/// arguments. class SCFConditionToSDFG : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -1213,6 +1273,8 @@ class SCFConditionToSDFG : public OpConversionPattern { } }; +/// Converts a scf::IfOp to multiple states, modeling an if-clause with +/// conditional edges. class SCFIfToSDFG : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -1228,9 +1290,9 @@ class SCFIfToSDFG : public OpConversionPattern { AllocSymbolOp::create(rewriter, op.getLoc(), condName); for (unsigned i = 0; i < op.getNumResults(); ++i) { - MemrefToMemletConverter memo; + ToArrayConverter tac; - Type nt = memo.convertType(op.getResultTypes()[0]); + Type nt = tac.convertType(op.getResultTypes()[0]); SizedType sized = SizedType::get(op->getLoc().getContext(), nt, {}, {}, {}); nt = ArrayType::get(op->getLoc().getContext(), sized); @@ -1324,6 +1386,8 @@ class SCFIfToSDFG : public OpConversionPattern { } }; +/// Converts a scf::YieldOp to store operations, storing the values being +/// yielded in for/while/if clauses. class SCFYieldToSDFG : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -1366,6 +1430,7 @@ class SCFYieldToSDFG : public OpConversionPattern { // LLVM Patterns //===----------------------------------------------------------------------===// +/// Converts a LLVM::AllocaOp to sdfg::AllocOp. class LLVMAllocaToSDFG : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -1413,6 +1478,7 @@ class LLVMAllocaToSDFG : public OpConversionPattern { } }; +/// Converts a LLVM::BitcastOp to sdfg::ViewCastOp. class LLVMBitcastToSDFG : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -1432,6 +1498,7 @@ class LLVMBitcastToSDFG : public OpConversionPattern { } }; +/// Converts a LLVM::GEPOp to an index computation. class LLVMGEPToSDFG : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -1463,6 +1530,7 @@ class LLVMGEPToSDFG : public OpConversionPattern { } }; +/// Converts a LLVM::LoadOp to sdfg::LoadOp. class LLVMLoadToSDFG : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -1510,6 +1578,7 @@ class LLVMLoadToSDFG : public OpConversionPattern { } }; +/// Converts a LLVM::StoreOp to sdfg::StoreOp. class LLVMStoreToSDFG : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -1539,6 +1608,7 @@ class LLVMStoreToSDFG : public OpConversionPattern { } }; +/// Erases LLVM::GlobalOp. class LLVMGlobalToSDFG : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -1552,6 +1622,7 @@ class LLVMGlobalToSDFG : public OpConversionPattern { } }; +/// Erases LLVM::LLVMFuncOp. class LLVMFuncToSDFG : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -1565,6 +1636,7 @@ class LLVMFuncToSDFG : public OpConversionPattern { } }; +/// Wraps LLVM::UndefOp into tasklets. class LLVMUndefToSDFG : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -1575,8 +1647,8 @@ class LLVMUndefToSDFG : public OpConversionPattern { std::string name = sdfg::utils::operationToString(*op); StateNode state = StateNode::create(rewriter, op->getLoc(), name); - MemrefToMemletConverter memo; - Type nt = memo.convertType(op->getResultTypes()[0]); + ToArrayConverter tac; + Type nt = tac.convertType(op->getResultTypes()[0]); SizedType sized = SizedType::get(op->getLoc().getContext(), nt, {}, {}, {}); nt = ArrayType::get(op->getLoc().getContext(), sized); @@ -1587,8 +1659,8 @@ class LLVMUndefToSDFG : public OpConversionPattern { SmallVector allocs; for (Type opType : op->getResultTypes()) { - MemrefToMemletConverter memo; - Type newType = memo.convertType(opType); + ToArrayConverter tac; + Type newType = tac.convertType(opType); SizedType sizedType = SizedType::get(op->getLoc().getContext(), newType, {}, {}, {}); newType = ArrayType::get(op->getLoc().getContext(), sizedType); @@ -1640,6 +1712,7 @@ class LLVMUndefToSDFG : public OpConversionPattern { // Pass //===----------------------------------------------------------------------===// +/// Registers all the patterns above in a RewritePatternSet. void populateGenericToSDFGConversionPatterns(RewritePatternSet &patterns, TypeConverter &converter) { MLIRContext *ctxt = patterns.getContext(); @@ -1651,6 +1724,7 @@ void populateGenericToSDFGConversionPatterns(RewritePatternSet &patterns, patterns.add(converter, ctxt); patterns.add(converter, ctxt); + patterns.add(converter, ctxt); patterns.add(converter, ctxt); patterns.add(converter, ctxt); patterns.add(converter, ctxt); @@ -1688,7 +1762,7 @@ struct GenericToSDFGPass }; } // namespace -// Gets the name of the first function that isn't called by any other function +/// Gets the name of the first function that isn't called by any other function. llvm::Optional getMainFunctionName(ModuleOp moduleOp) { for (func::FuncOp mainFuncOp : moduleOp.getOps()) { // No need to check function declarations @@ -1725,10 +1799,11 @@ llvm::Optional getMainFunctionName(ModuleOp moduleOp) { return std::nullopt; } +/// Runs the pass on the top-level module operation. void GenericToSDFGPass::runOnOperation() { ModuleOp module = getOperation(); - // TODO: Find a way to get func name via CLI instead of inferring + // FIXME: Find a way to get func name via CLI instead of inferring llvm::Optional mainFuncNameOpt = getMainFunctionName(module); if (mainFuncNameOpt) mainFuncName = *mainFuncNameOpt; @@ -1738,7 +1813,7 @@ void GenericToSDFGPass::runOnOperation() { module->removeAttr(a.getName()); SDFGTarget target(getContext()); - MemrefToMemletConverter converter; + ToArrayConverter converter; RewritePatternSet patterns(&getContext()); populateGenericToSDFGConversionPatterns(patterns, converter); @@ -1747,6 +1822,7 @@ void GenericToSDFGPass::runOnOperation() { signalPassFailure(); } +/// Returns a unique pointer to this pass. std::unique_ptr conversion::createGenericToSDFGPass(StringRef getMainFuncName) { return std::make_unique(getMainFuncName); diff --git a/lib/SDFG/Conversion/LinalgToSDFG/CMakeLists.txt b/lib/SDFG/Conversion/LinalgToSDFG/CMakeLists.txt index 396055f9d..dfe0004b1 100644 --- a/lib/SDFG/Conversion/LinalgToSDFG/CMakeLists.txt +++ b/lib/SDFG/Conversion/LinalgToSDFG/CMakeLists.txt @@ -1,3 +1,5 @@ +# Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + add_mlir_dialect_library( LinalgToSDFG ConvertLinalgToSDFG.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/SDFG/Conversion/LinalgToSDFG DEPENDS diff --git a/lib/SDFG/Conversion/LinalgToSDFG/ConvertLinalgToSDFG.cpp b/lib/SDFG/Conversion/LinalgToSDFG/ConvertLinalgToSDFG.cpp index bab211f7c..89ee6da9f 100644 --- a/lib/SDFG/Conversion/LinalgToSDFG/ConvertLinalgToSDFG.cpp +++ b/lib/SDFG/Conversion/LinalgToSDFG/ConvertLinalgToSDFG.cpp @@ -1,3 +1,7 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// This file defines a converter from the linalg dialect to the SDFG dialect. + #include "SDFG/Conversion/LinalgToSDFG/PassDetail.h" #include "SDFG/Conversion/LinalgToSDFG/Passes.h" #include "SDFG/Dialect/Dialect.h" @@ -17,6 +21,7 @@ using namespace conversion; // Target & Type Converter //===----------------------------------------------------------------------===// +/// Defines the target to convert to. struct SDFGTarget : public ConversionTarget { SDFGTarget(MLIRContext &ctx) : ConversionTarget(ctx) { // Every operation is legal (best effort) @@ -28,6 +33,7 @@ struct SDFGTarget : public ConversionTarget { // Pass //===----------------------------------------------------------------------===// +/// Registers all the patterns above in a RewritePatternSet. void populateLinalgToSDFGConversionPatterns(RewritePatternSet &patterns) {} namespace { @@ -39,6 +45,7 @@ struct LinalgToSDFGPass }; } // namespace +/// Runs the pass on the top-level module operation. void LinalgToSDFGPass::runOnOperation() { ModuleOp module = getOperation(); @@ -51,6 +58,7 @@ void LinalgToSDFGPass::runOnOperation() { signalPassFailure(); } +/// Returns a unique pointer to this pass. std::unique_ptr conversion::createLinalgToSDFGPass() { return std::make_unique(); } diff --git a/lib/SDFG/Conversion/SDFGToGeneric/CMakeLists.txt b/lib/SDFG/Conversion/SDFGToGeneric/CMakeLists.txt new file mode 100644 index 000000000..2f19c31f9 --- /dev/null +++ b/lib/SDFG/Conversion/SDFGToGeneric/CMakeLists.txt @@ -0,0 +1,16 @@ +# Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +add_mlir_dialect_library( + SDFGToGeneric + ConvertSDFGToGeneric.cpp + SymbolicParser.cpp + OpCreators.cpp + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/SDFG/Conversion/SDFGToGeneric + DEPENDS + MLIRSDFGToGenericPassIncGen) + +target_link_libraries(SDFGToGeneric PUBLIC MLIRIR) + +target_sources(SOURCE_FILES_CPP PRIVATE ConvertSDFGToGeneric.cpp + SymbolicParser.cpp OpCreators.cpp) diff --git a/lib/SDFG/Conversion/SDFGToGeneric/ConvertSDFGToGeneric.cpp b/lib/SDFG/Conversion/SDFGToGeneric/ConvertSDFGToGeneric.cpp new file mode 100644 index 000000000..808a1b9c8 --- /dev/null +++ b/lib/SDFG/Conversion/SDFGToGeneric/ConvertSDFGToGeneric.cpp @@ -0,0 +1,713 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// This file defines a converter from the SDFG dialect to the builtin dialects. + +// +// SDFG -> func.func +// States -> block +// Edges -> +// Default: cf.br +// Assignment: Insert block, add assignments at the end, cf.br +// Condition: Insert blocks (true/false), compute condition, cf.cond_br +// +// Alloc -> memref.alloc +// Load -> memref.load +// Store -> memref.store +// Copy -> memref.copy +// +// Alloc Symbol -> memref.alloc (int64) +// Sym -> +// If single symbol: memref.load +// If expression: parse, build AST, create ops +// +// Return -> func.return +// Tasklet -> func.func + func.call +// +// Map -> scf.parallel (or: affine.parallel, affine.for, scf.forall, scf.for) +// +// Consume -> TBD +// + +#include "SDFG/Conversion/SDFGToGeneric/OpCreators.h" +#include "SDFG/Conversion/SDFGToGeneric/PassDetail.h" +#include "SDFG/Conversion/SDFGToGeneric/Passes.h" +#include "SDFG/Conversion/SDFGToGeneric/SymbolicParser.h" +#include "SDFG/Dialect/Dialect.h" +#include "SDFG/Utils/Utils.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/AsmState.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/Transforms/DialectConversion.h" + +using namespace mlir; +using namespace sdfg; +using namespace conversion; + +/// Maps state name to their generated block +llvm::StringMap blockMap; + +/// For each function scope, map symbols to values +/// Function scope is determined by the function name +llvm::StringMap> symbolMap; + +// HACK: Keeps track of processed EdgeOps +llvm::DenseSet processedEdges; + +//===----------------------------------------------------------------------===// +// Target & Type Converter +//===----------------------------------------------------------------------===// + +/// Defines the target to convert to. +struct GenericTarget : public ConversionTarget { + GenericTarget(MLIRContext &ctx) : ConversionTarget(ctx) { + // Every Op in the SDFG Dialect is illegal + addIllegalDialect(); + // Implicit top level module operation is legal + addLegalOp(); + // Dialects generated by this pass are legal + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + // All other operations are illegal + markUnknownOpDynamicallyLegal([](Operation *op) { return false; }); + } +}; + +/// Defines a type converter, converting input types to MemRef types. +class ToMemrefConverter : public TypeConverter { +public: + ToMemrefConverter() { + addConversion([](Type type) { return type; }); + addConversion(convertArrayTypes); + } + + /// Attempts to convert array types to MemRef types. + static Optional convertArrayTypes(Type type) { + if (ArrayType array = type.dyn_cast()) { + SmallVector shape; + unsigned intIdx = 0; + + for (unsigned i = 0; i < array.getShape().size(); ++i) { + int64_t val = array.getShape()[i] ? array.getIntegers()[intIdx++] + : ShapedType::kDynamic; + + if (val < 0) + val = ShapedType::kDynamic; + shape.push_back(val); + } + + return MemRefType::get(shape, array.getElementType()); + } + + return std::nullopt; + } +}; + +//===----------------------------------------------------------------------===// +// Helpers +//===----------------------------------------------------------------------===// + +/// Gets the current function scope. +llvm::StringRef getFunctionScope(Operation *op) { + Operation *parent = op->getParentOfType(); + if (parent == nullptr) + return ""; + + return cast(parent).getName(); +} + +/// Creates operations that perform the symbolic expression. +static Value symbolicExpressionToMLIR(PatternRewriter &rewriter, Operation *op, + StringRef symExpr, + llvm::StringMap refMap = {}) { + std::unique_ptr ast = SymbolicParser().parse(symExpr); + if (!ast) + emitError(op->getLoc(), "failed to parse symbolic expression"); + + return ast->codegen(rewriter, op->getLoc(), symbolMap[getFunctionScope(op)], + refMap); +} + +/// Converts a numberlist (symbol, integer, operand) to Values. +static SmallVector numberListToMLIR(PatternRewriter &rewriter, + Operation *op, StringRef attrName) { + ArrayAttr attrList = op->getAttr(attrName).cast(); + ArrayAttr numList = + op->getAttr(attrName.str() + "_numList").cast(); + SmallVector values; + + for (unsigned i = 0; i < numList.size(); ++i) { + IntegerAttr num = numList[i].cast(); + + if (num.getValue().isNegative()) { + // Number is a symbol or integer + Attribute attr = attrList[-num.getInt() - 1]; + std::string expression; + + if (attr.isa()) + expression = attr.cast().str(); + else + expression = std::to_string(attr.cast().getInt()); + + Value val = symbolicExpressionToMLIR(rewriter, op, expression); + val = + createIndexCast(rewriter, op->getLoc(), rewriter.getIndexType(), val); + values.push_back(val); + } else { + // Number is a operand + values.push_back(op->getOperand(num.getInt())); + } + } + + return values; +} + +//===----------------------------------------------------------------------===// +// SDFG, State & Edge Patterns +//===----------------------------------------------------------------------===// + +/// Converts a SDFG node to func::FuncOp. +class SDFGToFunc : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(SDFGNode op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Mark the entry state + op.getEntryState()->setAttr("entry", rewriter.getBoolAttr(true)); + + // Create a function and clone the sdfg body + SmallVector convertedTypes; + if (getTypeConverter() + ->convertTypes(op.getBody().getArgumentTypes(), convertedTypes) + .failed()) + return failure(); + + // Add symbols to signature + SmallVector symbols; + + for (Type t : op.getBody().getArgumentTypes()) + if (ArrayType arrayT = t.dyn_cast()) + for (StringAttr sym : arrayT.getSymbols()) + if (find(symbols, sym) == symbols.end()) { + convertedTypes.push_back(rewriter.getIndexType()); + symbols.push_back(sym); + } + + func::FuncOp funcOp = + createFunc(rewriter, op.getLoc(), "sdfg", convertedTypes, {}, "public"); + funcOp.getBody().takeBody(op.getBody()); + + // Add symbols to scope + for (StringAttr sym : symbols) + symbolMap[getFunctionScope(op)][sym] = + funcOp.getBody().addArgument(rewriter.getIndexType(), op.getLoc()); + + if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), + *getTypeConverter()))) + return failure(); + + rewriter.eraseOp(op); + return success(); + } +}; + +/// Converts a nested SDFG node to func::FuncOp and func::CallOp. +class NestedSDFGToFunc : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(NestedSDFGNode op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Create call + std::string name = sdfg::utils::generateName("nested_sdfg"); + + // Propagate symbols + SmallVector operands = adaptor.getOperands(); + for (llvm::StringMapEntry &v : symbolMap[getFunctionScope(op)]) + operands.push_back(v.getValue()); + + createCall(rewriter, op.getLoc(), {}, name, operands); + + // Set insertion point before the current func + Operation *parent = op->getParentOfType(); + if (parent == nullptr) + return failure(); + + rewriter.setInsertionPoint(parent); + + // Mark the entry state + op.getEntryState()->setAttr("entry", rewriter.getBoolAttr(true)); + + // Create a function and clone the sdfg body + SmallVector operandTypes; + if (getTypeConverter() + ->convertTypes(op.getOperandTypes(), operandTypes) + .failed()) + return failure(); + + // Add symbols to signature + for (llvm::StringMapEntry &v : symbolMap[getFunctionScope(op)]) + operandTypes.push_back(v.getValue().getType()); + + func::FuncOp funcOp = + createFunc(rewriter, op.getLoc(), name, operandTypes, {}, "private"); + funcOp.getBody().takeBody(op.getBody()); + + // Add symbols to scope + for (llvm::StringMapEntry &v : symbolMap[getFunctionScope(op)]) + symbolMap[name][v.getKey()] = + funcOp.getBody().addArgument(v.getValue().getType(), op.getLoc()); + + if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), + *getTypeConverter()))) + return failure(); + + rewriter.eraseOp(op); + return success(); + } +}; + +/// Converts a state to a basic block. +class StateToBlock : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(StateNode op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Split the current basic block at the current position + Block *newBlock = rewriter.createBlock(rewriter.getBlock()->getParent()); + + // Add the mapping from the sdfg.state's name to the new basic block + blockMap[op.getName()] = newBlock; + + // Connect to init block if it's an entry state + if (op->hasAttrOfType("entry") && + op->getAttrOfType("entry").getValue()) { + rewriter.setInsertionPointToEnd(&newBlock->getParent()->front()); + createBranch(rewriter, op.getLoc(), {}, newBlock); + } + + // Move the operations from the sdfg.state's body into the new basic block + rewriter.setInsertionPointToStart(newBlock); + + // Collect all operations + SmallVector ops; + for (Operation &operation : op.getBody().getOps()) { + ops.push_back(&operation); + } + + // Move them + for (Operation *operation : ops) { + operation->moveBefore(newBlock, newBlock->end()); + } + + // If there is an outward edge, do not add a return op + for (EdgeOp edge : op->getParentRegion()->getOps()) { + if (edge.getSrc().equals(op.getSymName())) { + rewriter.eraseOp(op); + return success(); + } + } + + createReturn(rewriter, op.getLoc(), {}); + rewriter.eraseOp(op); + return success(); + } +}; + +/// Converts an edge to basic blocks (for assignments and conditions) and +/// (conditional) branches. +class EdgeToBranch : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(EdgeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // If we don't have a condition or assignments, add a simple branch + if (adaptor.getCondition().equals("1") && adaptor.getAssign().empty()) { + rewriter.setInsertionPointToEnd(blockMap[adaptor.getSrc()]); + createBranch(rewriter, op.getLoc(), {}, blockMap[adaptor.getDest()]); + rewriter.eraseOp(op); + return success(); + } + + Block *takenBlock = rewriter.createBlock(rewriter.getBlock()->getParent()); + llvm::StringMap refMap; + // FIXME: Extend to variable amount of references + refMap["ref"] = adaptor.getRef(); + + if (!adaptor.getCondition().equals("1")) { + // If we have a condition, create a second block (not taken path) + Block *notTakenBlock = + rewriter.createBlock(rewriter.getBlock()->getParent()); + rewriter.setInsertionPointToEnd(blockMap[adaptor.getSrc()]); + // Compute condition + Value condition = symbolicExpressionToMLIR( + rewriter, op, adaptor.getCondition(), refMap); + // Add conditional branch + createCondBranch(rewriter, op.getLoc(), condition, takenBlock, + notTakenBlock); + + // Update blockMap + blockMap[adaptor.getSrc()] = notTakenBlock; + + // If there is no other edge op for the source state, add return statement + // to the new block + bool hasEdge = false; + + for (EdgeOp edge : rewriter.getBlock()->getParent()->getOps()) { + if (edge.getSrc().equals(adaptor.getSrc()) && edge != op && + !processedEdges.contains(edge)) { + hasEdge = true; + break; + } + } + + if (!hasEdge) { + rewriter.setInsertionPointToEnd(notTakenBlock); + createReturn(rewriter, op.getLoc(), {}); + } + } else { + rewriter.setInsertionPointToEnd(blockMap[adaptor.getSrc()]); + createBranch(rewriter, op.getLoc(), {}, takenBlock); + // No blockMap update because only one unconditial edge allowed per state + } + + // Add assignments + rewriter.setInsertionPointToStart(takenBlock); + + for (Attribute assignment : adaptor.getAssign()) + symbolicExpressionToMLIR(rewriter, op, cast(assignment), + refMap); + + // Create simple branch to destination + createBranch(rewriter, op.getLoc(), {}, blockMap[adaptor.getDest()]); + + rewriter.eraseOp(op); + processedEdges.insert(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Access Node Patterns +//===----------------------------------------------------------------------===// + +/// Converts an allocation operation to memref::AllocOp. +class AllocToAlloc : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(AllocOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type memrefType = getTypeConverter()->convertType(op.getType()); + if (!memrefType || !memrefType.isa()) + return failure(); + + SmallVector operands; + + if (ArrayType array = op.getType().dyn_cast()) { + unsigned intIdx = 0; + unsigned symIdx = 0; + unsigned opIdx = 0; + + for (unsigned i = 0; i < array.getShape().size(); ++i) { + Value val; + if (array.getShape()[i] && array.getIntegers()[intIdx] < 0) { + intIdx++; + val = op.getOperand(opIdx++); + } else if (array.getShape()[i]) { + intIdx++; + continue; + } else { + val = symbolicExpressionToMLIR(rewriter, op, + array.getSymbols()[symIdx++]); + val = createIndexCast(rewriter, op.getLoc(), rewriter.getIndexType(), + val); + } + operands.push_back(val); + } + + } else { + operands = adaptor.getOperands(); + } + + memref::AllocOp allocOp = createAlloc( + rewriter, op.getLoc(), memrefType.cast(), operands); + rewriter.replaceOp(op, {allocOp}); + return success(); + } +}; + +/// Converts a load operation to memref::LoadOp. +class LoadToLoad : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(LoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector indices = numberListToMLIR(rewriter, op, "indices"); + + memref::LoadOp loadOp = + createLoad(rewriter, op.getLoc(), adaptor.getArr(), indices); + rewriter.replaceOp(op, {loadOp}); + return success(); + } +}; + +/// Converts a store operation to memref::StoreOp. +class StoreToStore : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(StoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector indices = numberListToMLIR(rewriter, op, "indices"); + createStore(rewriter, op.getLoc(), adaptor.getVal(), adaptor.getArr(), + indices); + rewriter.eraseOp(op); + return success(); + } +}; + +/// Converts a copy operation to memref::CopyOp. +class CopyToCopy : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(CopyOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + createCopy(rewriter, op.getLoc(), adaptor.getSrc(), adaptor.getDest()); + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Symbol Patterns +//===----------------------------------------------------------------------===// + +/// Converts a symbol allocation operation to memref::AllocOp. +class AllocSymbolToAlloc : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(AllocSymbolOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + allocSymbol(rewriter, op.getLoc(), op.getSym(), + symbolMap[getFunctionScope(op)]); + rewriter.eraseOp(op); + return success(); + } +}; + +/// Converts a symbolic expression to multiple builtin operations. +class SymToOps : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(SymOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value val = symbolicExpressionToMLIR(rewriter, op, op.getExpr()); + + if (op.getType().isIndex()) + val = createIndexCast(rewriter, op.getLoc(), op.getType(), val); + + if (op.getType().isIntOrIndex() && !op.getType().isIndex() && + !op.getType().isInteger(64)) + val = createTruncI(rewriter, op.getLoc(), op.getType(), val); + + rewriter.replaceOp(op, {val}); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Tasklet Patterns +//===----------------------------------------------------------------------===// + +/// Converts a tasklet to func::FuncOp and func::CallOp. +class TaskletToFunc : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(TaskletNode op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Create call + std::string name = sdfg::utils::generateName("tasklet"); + + SmallVector resultTypes; + if (getTypeConverter() + ->convertTypes(op.getResultTypes(), resultTypes) + .failed()) + return failure(); + + // Propagate symbols + SmallVector operands = adaptor.getOperands(); + for (llvm::StringMapEntry &v : symbolMap[getFunctionScope(op)]) + operands.push_back(v.getValue()); + + func::CallOp callOp = + createCall(rewriter, op.getLoc(), resultTypes, name, operands); + rewriter.replaceOp(op, callOp.getResults()); + + // Set insertion point before the current func + Operation *parent = op->getParentOfType(); + if (parent == nullptr) + return failure(); + + rewriter.setInsertionPoint(parent); + + // Create function + SmallVector operandTypes; + if (getTypeConverter() + ->convertTypes(op.getOperandTypes(), operandTypes) + .failed()) + return failure(); + + // Add symbols to signature + for (llvm::StringMapEntry &v : symbolMap[getFunctionScope(op)]) + operandTypes.push_back(v.getValue().getType()); + + func::FuncOp funcOp = createFunc(rewriter, op.getLoc(), name, operandTypes, + resultTypes, "private"); + funcOp.getBody().takeBody(op.getBody()); + + // Add symbols to scope + for (llvm::StringMapEntry &v : symbolMap[getFunctionScope(op)]) + symbolMap[name][v.getKey()] = + funcOp.getBody().addArgument(v.getValue().getType(), op.getLoc()); + + if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), + *getTypeConverter()))) + return failure(); + + return success(); + } +}; + +/// Converts a return operation to func::ReturnOp. +class ReturnToReturn : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ReturnOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + createReturn(rewriter, op.getLoc(), op.getOperands()); + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Map & Consume Patterns +//===----------------------------------------------------------------------===// + +/// Converts a map scope to scf::ParallelOp. +class MapToParallel : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(MapNode op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector lowerBounds = + numberListToMLIR(rewriter, op, "lowerBounds"); + SmallVector upperBounds = + numberListToMLIR(rewriter, op, "upperBounds"); + SmallVector steps = numberListToMLIR(rewriter, op, "steps"); + + scf::ParallelOp parallelOp = + createParallel(rewriter, op.getLoc(), lowerBounds, upperBounds, steps); + + if (!op.getBodyRegion().empty() && !op.getBodyRegion().front().empty()) { + parallelOp.getBodyRegion().takeBody(op.getBodyRegion()); + rewriter.setInsertionPointToEnd(parallelOp.getBody()); + createYield(rewriter, op.getLoc()); + } + + rewriter.eraseOp(op); + return success(); + } +}; + +/// Converts a consume scope to TBD. +class ConsumeToTODO : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ConsumeNode op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // TODO: Write lowering for consume nodes + return failure(); + } +}; + +//===----------------------------------------------------------------------===// +// Pass +//===----------------------------------------------------------------------===// + +/// Registers all the patterns above in a RewritePatternSet. +void populateSDFGToGenericConversionPatterns(RewritePatternSet &patterns, + TypeConverter &converter) { + MLIRContext *ctxt = patterns.getContext(); + + patterns.add(converter, ctxt); + patterns.add(converter, ctxt); + patterns.add(converter, ctxt); + patterns.add(converter, ctxt); + patterns.add(converter, ctxt); + patterns.add(converter, ctxt); + patterns.add(converter, ctxt); + patterns.add(converter, ctxt); + patterns.add(converter, ctxt); + patterns.add(converter, ctxt); + patterns.add(converter, ctxt); + patterns.add(converter, ctxt); + patterns.add(converter, ctxt); + patterns.add(converter, ctxt); +} + +namespace { +struct SDFGToGenericPass + : public sdfg::conversion::SDFGToGenericPassBase { + void runOnOperation() override; +}; +} // namespace + +/// Runs the pass on the top-level module operation. +void SDFGToGenericPass::runOnOperation() { + ModuleOp module = getOperation(); + + GenericTarget target(getContext()); + ToMemrefConverter converter; + + RewritePatternSet patterns(&getContext()); + populateSDFGToGenericConversionPatterns(patterns, converter); + + if (applyFullConversion(module, target, std::move(patterns)).failed()) + signalPassFailure(); +} + +/// Returns a unique pointer to this pass. +std::unique_ptr conversion::createSDFGToGenericPass() { + return std::make_unique(); +} diff --git a/lib/SDFG/Conversion/SDFGToGeneric/OpCreators.cpp b/lib/SDFG/Conversion/SDFGToGeneric/OpCreators.cpp new file mode 100644 index 000000000..629c94cd2 --- /dev/null +++ b/lib/SDFG/Conversion/SDFGToGeneric/OpCreators.cpp @@ -0,0 +1,316 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// This file contains convenience functions that build, create and insert +/// various operations. + +#include "SDFG/Conversion/SDFGToGeneric/OpCreators.h" + +using namespace mlir; +using namespace sdfg; + +/// Builds, creates and inserts a func::FuncOp. +func::FuncOp conversion::createFunc(PatternRewriter &rewriter, Location loc, + StringRef name, TypeRange inputTypes, + TypeRange resultTypes, + StringRef visibility) { + OpBuilder builder(loc->getContext()); + OperationState state(loc, func::FuncOp::getOperationName()); + + FunctionType func_type = builder.getFunctionType(inputTypes, resultTypes); + StringAttr visAttr = builder.getStringAttr(visibility); + + func::FuncOp::build(builder, state, name, func_type, visAttr, {}, {}); + return cast(rewriter.create(state)); +} + +/// Builds, creates and inserts a func::CallOp. +func::CallOp conversion::createCall(PatternRewriter &rewriter, Location loc, + TypeRange resultTypes, StringRef callee, + ValueRange operands) { + OpBuilder builder(loc->getContext()); + OperationState state(loc, func::CallOp::getOperationName()); + + func::CallOp::build(builder, state, resultTypes, callee, operands); + return cast(rewriter.create(state)); +} + +/// Builds, creates and inserts a func::ReturnOp. +func::ReturnOp conversion::createReturn(PatternRewriter &rewriter, Location loc, + ValueRange operands) { + OpBuilder builder(loc->getContext()); + OperationState state(loc, func::ReturnOp::getOperationName()); + + func::ReturnOp::build(builder, state, operands); + return cast(rewriter.create(state)); +} + +/// Builds, creates and inserts a cf::BranchOp. +cf::BranchOp conversion::createBranch(PatternRewriter &rewriter, Location loc, + ValueRange operands, Block *dest) { + OpBuilder builder(loc->getContext()); + OperationState state(loc, cf::BranchOp::getOperationName()); + + cf::BranchOp::build(builder, state, operands, dest); + return cast(rewriter.create(state)); +} + +/// Builds, creates and inserts a cf::CondBranchOp. +cf::CondBranchOp conversion::createCondBranch(PatternRewriter &rewriter, + Location loc, Value condition, + Block *trueDest, + Block *falseDest) { + OpBuilder builder(loc->getContext()); + OperationState state(loc, cf::CondBranchOp::getOperationName()); + + cf::CondBranchOp::build(builder, state, condition, trueDest, falseDest); + return cast(rewriter.create(state)); +} + +/// Builds, creates and inserts a memref::AllocOp. +memref::AllocOp conversion::createAlloc(PatternRewriter &rewriter, Location loc, + MemRefType memrefType, + ValueRange dynamicSizes) { + OpBuilder builder(loc->getContext()); + OperationState state(loc, memref::AllocOp::getOperationName()); + + memref::AllocaOp::build(builder, state, memrefType, dynamicSizes); + return cast(rewriter.create(state)); +} + +/// Builds, creates and inserts a memref::LoadOp. +memref::LoadOp conversion::createLoad(PatternRewriter &rewriter, Location loc, + Value memref, ValueRange indices) { + OpBuilder builder(loc->getContext()); + OperationState state(loc, memref::LoadOp::getOperationName()); + + memref::LoadOp::build(builder, state, memref, indices); + return cast(rewriter.create(state)); +} + +/// Builds, creates and inserts a memref::StoreOp. +memref::StoreOp conversion::createStore(PatternRewriter &rewriter, Location loc, + Value value, Value memref, + ValueRange indices) { + OpBuilder builder(loc->getContext()); + OperationState state(loc, memref::StoreOp::getOperationName()); + + memref::StoreOp::build(builder, state, value, memref, indices); + return cast(rewriter.create(state)); +} + +/// Builds, creates and inserts a memref::CopyOp. +memref::CopyOp conversion::createCopy(PatternRewriter &rewriter, Location loc, + Value source, Value target) { + OpBuilder builder(loc->getContext()); + OperationState state(loc, memref::CopyOp::getOperationName()); + + memref::CopyOp::build(builder, state, source, target); + return cast(rewriter.create(state)); +} + +/// Allocates a symbol as a memref if it's not already allocated and +/// populates the symbol map. +void conversion::allocSymbol(PatternRewriter &rewriter, Location loc, + StringRef symName, + llvm::StringMap &symbolMap) { + if (symbolMap.find(symName) != symbolMap.end()) + return; + + OpBuilder::InsertPoint insertionPoint = rewriter.saveInsertionPoint(); + + // Set insertion point to the beginning of the first block (top of func) + rewriter.setInsertionPointToStart(&rewriter.getBlock()->getParent()->front()); + + IntegerType intType = IntegerType::get(loc->getContext(), 64); + MemRefType memrefType = MemRefType::get({}, intType); + memref::AllocOp allocOp = createAlloc(rewriter, loc, memrefType, {}); + + // Update symbol map + symbolMap[symName] = allocOp; + rewriter.restoreInsertionPoint(insertionPoint); +} + +/// Builds, creates and inserts an arith::ConstantIntOp. +arith::ConstantIntOp conversion::createConstantInt(PatternRewriter &rewriter, + Location loc, int val, + int width) { + OpBuilder builder(loc->getContext()); + OperationState state(loc, arith::ConstantIntOp::getOperationName()); + + arith::ConstantIntOp::build(builder, state, val, width); + return cast(rewriter.create(state)); +} + +/// Builds, creates and inserts an arith::AddIOp. +arith::AddIOp conversion::createAddI(PatternRewriter &rewriter, Location loc, + Value a, Value b) { + OpBuilder builder(loc->getContext()); + OperationState state(loc, arith::AddIOp::getOperationName()); + + arith::AddIOp::build(builder, state, a, b); + return cast(rewriter.create(state)); +} + +/// Builds, creates and inserts an arith::SubIOp. +arith::SubIOp conversion::createSubI(PatternRewriter &rewriter, Location loc, + Value a, Value b) { + OpBuilder builder(loc->getContext()); + OperationState state(loc, arith::SubIOp::getOperationName()); + + arith::SubIOp::build(builder, state, a, b); + return cast(rewriter.create(state)); +} + +/// Builds, creates and inserts an arith::MulIOp. +arith::MulIOp conversion::createMulI(PatternRewriter &rewriter, Location loc, + Value a, Value b) { + OpBuilder builder(loc->getContext()); + OperationState state(loc, arith::MulIOp::getOperationName()); + + arith::MulIOp::build(builder, state, a, b); + return cast(rewriter.create(state)); +} + +/// Builds, creates and inserts an arith::DivSIOp. +arith::DivSIOp conversion::createDivSI(PatternRewriter &rewriter, Location loc, + Value a, Value b) { + OpBuilder builder(loc->getContext()); + OperationState state(loc, arith::DivSIOp::getOperationName()); + + arith::DivSIOp::build(builder, state, a, b); + return cast(rewriter.create(state)); +} + +/// Builds, creates and inserts an arith::FloorDivSIOp. +arith::FloorDivSIOp conversion::createFloorDivSI(PatternRewriter &rewriter, + Location loc, Value a, + Value b) { + OpBuilder builder(loc->getContext()); + OperationState state(loc, arith::FloorDivSIOp::getOperationName()); + + arith::FloorDivSIOp::build(builder, state, a, b); + return cast(rewriter.create(state)); +} + +/// Builds, creates and inserts an arith::RemSIOp. +arith::RemSIOp conversion::createRemSI(PatternRewriter &rewriter, Location loc, + Value a, Value b) { + OpBuilder builder(loc->getContext()); + OperationState state(loc, arith::RemSIOp::getOperationName()); + + arith::RemSIOp::build(builder, state, a, b); + return cast(rewriter.create(state)); +} + +/// Builds, creates and inserts an arith::OrIOp. +arith::OrIOp conversion::createOrI(PatternRewriter &rewriter, Location loc, + Value a, Value b) { + OpBuilder builder(loc->getContext()); + OperationState state(loc, arith::OrIOp::getOperationName()); + + arith::OrIOp::build(builder, state, a, b); + return cast(rewriter.create(state)); +} + +/// Builds, creates and inserts an arith::AndIOp. +arith::AndIOp conversion::createAndI(PatternRewriter &rewriter, Location loc, + Value a, Value b) { + OpBuilder builder(loc->getContext()); + OperationState state(loc, arith::AndIOp::getOperationName()); + + arith::AndIOp::build(builder, state, a, b); + return cast(rewriter.create(state)); +} + +/// Builds, creates and inserts an arith::XOrIOp. +arith::XOrIOp conversion::createXOrI(PatternRewriter &rewriter, Location loc, + Value a, Value b) { + OpBuilder builder(loc->getContext()); + OperationState state(loc, arith::XOrIOp::getOperationName()); + + arith::XOrIOp::build(builder, state, a, b); + return cast(rewriter.create(state)); +} + +/// Builds, creates and inserts an arith::ShLIOp. +arith::ShLIOp conversion::createShLI(PatternRewriter &rewriter, Location loc, + Value a, Value b) { + OpBuilder builder(loc->getContext()); + OperationState state(loc, arith::ShLIOp::getOperationName()); + + arith::ShLIOp::build(builder, state, a, b); + return cast(rewriter.create(state)); +} + +/// Builds, creates and inserts an arith::ShRSIOp. +arith::ShRSIOp conversion::createShRSI(PatternRewriter &rewriter, Location loc, + Value a, Value b) { + OpBuilder builder(loc->getContext()); + OperationState state(loc, arith::ShRSIOp::getOperationName()); + + arith::ShRSIOp::build(builder, state, a, b); + return cast(rewriter.create(state)); +} + +/// Builds, creates and inserts an arith::CmpIOp. +arith::CmpIOp conversion::createCmpI(PatternRewriter &rewriter, Location loc, + arith::CmpIPredicate predicate, Value lhs, + Value rhs) { + OpBuilder builder(loc->getContext()); + OperationState state(loc, arith::CmpIOp::getOperationName()); + + arith::CmpIOp::build(builder, state, predicate, lhs, rhs); + return cast(rewriter.create(state)); +} + +/// Builds, creates and inserts an arith::ExtSIOp. +arith::ExtSIOp conversion::createExtSI(PatternRewriter &rewriter, Location loc, + Type out, Value in) { + OpBuilder builder(loc->getContext()); + OperationState state(loc, arith::ExtSIOp::getOperationName()); + + arith::ExtSIOp::build(builder, state, out, in); + return cast(rewriter.create(state)); +} + +/// Builds, creates and inserts an arith::TruncIOp. +arith::TruncIOp conversion::createTruncI(PatternRewriter &rewriter, + Location loc, Type out, Value in) { + OpBuilder builder(loc->getContext()); + OperationState state(loc, arith::TruncIOp::getOperationName()); + + arith::TruncIOp::build(builder, state, out, in); + return cast(rewriter.create(state)); +} + +/// Builds, creates and inserts an arith::IndexCastOp. +arith::IndexCastOp conversion::createIndexCast(PatternRewriter &rewriter, + Location loc, Type out, + Value in) { + OpBuilder builder(loc->getContext()); + OperationState state(loc, arith::IndexCastOp::getOperationName()); + + arith::IndexCastOp::build(builder, state, out, in); + return cast(rewriter.create(state)); +} + +/// Builds, creates and inserts a scf::ParallelOp. +scf::ParallelOp conversion::createParallel(PatternRewriter &rewriter, + Location loc, ValueRange lowerBounds, + ValueRange upperBounds, + ValueRange steps) { + OpBuilder builder(loc->getContext()); + OperationState state(loc, scf::ParallelOp::getOperationName()); + + scf::ParallelOp::build(builder, state, lowerBounds, upperBounds, steps); + return cast(rewriter.create(state)); +} + +/// Builds, creates and inserts a scf::YieldOp. +scf::YieldOp conversion::createYield(PatternRewriter &rewriter, Location loc) { + OpBuilder builder(loc->getContext()); + OperationState state(loc, scf::YieldOp::getOperationName()); + + scf::YieldOp::build(builder, state); + return cast(rewriter.create(state)); +} diff --git a/lib/SDFG/Conversion/SDFGToGeneric/SymbolicParser.cpp b/lib/SDFG/Conversion/SDFGToGeneric/SymbolicParser.cpp new file mode 100644 index 000000000..a5a8a9887 --- /dev/null +++ b/lib/SDFG/Conversion/SDFGToGeneric/SymbolicParser.cpp @@ -0,0 +1,762 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// This file implements a simple LL(1) parser for symbolic expressions. + +/* +// Grammar +stmt ::= assignment | log_or_expr +assignment ::= IDENT ASSIGN log_or_expr + +log_or_expr ::= log_and_expr ( LOG_OR log_and_expr )* +log_and_expr ::= eq_expr ( LOG_AND eq_expr )* + +eq_expr ::= rel_expr ( ( EQ | NE ) rel_expr )* +rel_expr ::= shift_expr ( ( LT | LE | GT | GE ) shift_expr )* +shift_expr ::= bit_or_expr ( ( LSHIFT | RSHIFT ) bit_or_expr )* + +bit_or_expr ::= bit_xor_expr ( BIT_OR bit_xor_expr )* +bit_xor_expr ::= bit_and_expr ( BIT_XOR bit_and_expr )* +bit_and_expr ::= add_expr ( BIT_AND add_expr )* + +add_expr ::= mul_expr ( ( ADD | SUB ) mul_expr )* +mul_expr ::= exp_expr ( ( MUL | DIV | FLOORDIV | MOD ) exp_expr )* +exp_expr ::= unary_expr ( EXP unary_expr )* +unary_expr ::= ( ADD | SUB | LOG_NOT | BIT_NOT )? factor +factor ::= LPAREN log_or_expr RPAREN | const_expr | IDENT + +const_expr ::= bool_const | INT_CONST +bool_const ::= TRUE | FALSE + +// Tokens +EQ ::= '=='; +NE ::= '!='; +LT ::= '<'; +LE ::= '<='; +GT ::= '>'; +GE ::= '>='; +ASSIGN ::= ':'; +LOG_OR ::= 'or'; +LOG_AND ::= 'and'; +LOG_NOT ::= 'not'; +ADD ::= '+'; +SUB ::= '-'; +MUL ::= '*'; +DIV ::= '/'; +FLOORDIV ::= '//'; +MOD ::= '%'; +EXP ::= '**'; +TRUE ::= 'True'; +FALSE ::= 'False'; +BIT_OR ::= '|'; +BIT_XOR ::= '^'; +BIT_AND ::= '&'; +BIT_NOT ::= '~'; +LSHIFT ::= '<<'; +RSHIFT ::= '>>'; +LPAREN ::= '('; +RPAREN ::= ')'; +INT_CONST ::= DIGIT+ +IDENT ::= LETTER ( LETTER | DIGIT )* +WS ::= [ \t\r\n]+ -> skip; + +// Helpers +DIGIT ::= [0-9]; +LETTER ::= [_a-zA-Z]; + +*/ + +#include "SDFG/Conversion/SDFGToGeneric/SymbolicParser.h" +#include "SDFG/Conversion/SDFGToGeneric/OpCreators.h" +#include + +using namespace mlir; +using namespace sdfg::conversion; + +namespace mlir::sdfg::conversion { + +//===----------------------------------------------------------------------===// +// AST Nodes +//===----------------------------------------------------------------------===// + +/// Converts the integer node into MLIR code. SymbolMap is used for permanent +/// mapping of symbols to values. RefMap is a temporary mapping overriding +/// SymbolMap. +Value IntNode::codegen(PatternRewriter &rewriter, Location loc, + llvm::StringMap &symbolMap, + llvm::StringMap &refMap) { + return createConstantInt(rewriter, loc, value, 64); +} + +/// Converts the boolean node into MLIR code. SymbolMap is used for permanent +/// mapping of symbols to values. RefMap is a temporary mapping overriding +/// SymbolMap. +Value BoolNode::codegen(PatternRewriter &rewriter, Location loc, + llvm::StringMap &symbolMap, + llvm::StringMap &refMap) { + return createConstantInt(rewriter, loc, value, 1); +} + +/// Converts the variable node into MLIR code. SymbolMap is used for permanent +/// mapping of symbols to values. RefMap is a temporary mapping overriding +/// SymbolMap. +Value VarNode::codegen(PatternRewriter &rewriter, Location loc, + llvm::StringMap &symbolMap, + llvm::StringMap &refMap) { + if (refMap.find(name) != refMap.end()) { + Value val = refMap[name]; + + if (val.getType().isIndex()) + return createIndexCast(rewriter, loc, rewriter.getI64Type(), val); + + if (val.getType().isIntOrIndex() && + val.getType().getIntOrFloatBitWidth() != 64) + return createExtSI(rewriter, loc, rewriter.getI64Type(), val); + + return val; + } + + allocSymbol(rewriter, loc, name, symbolMap); + return createLoad(rewriter, loc, symbolMap[name], {}); +} + +/// Converts the assignment node into MLIR code. SymbolMap is used for +/// permanent mapping of symbols to values. RefMap is a temporary mapping +/// overriding SymbolMap. +Value AssignNode::codegen(PatternRewriter &rewriter, Location loc, + llvm::StringMap &symbolMap, + llvm::StringMap &refMap) { + allocSymbol(rewriter, loc, variable->name, symbolMap); + Value eVal = expr->codegen(rewriter, loc, symbolMap, refMap); + createStore(rewriter, loc, eVal, symbolMap[variable->name], {}); + return nullptr; +} + +/// Converts the unary operation node into MLIR code. SymbolMap is used for +/// permanent mapping of symbols to values. RefMap is a temporary mapping +/// overriding SymbolMap. +Value UnOpNode::codegen(PatternRewriter &rewriter, Location loc, + llvm::StringMap &symbolMap, + llvm::StringMap &refMap) { + Value eVal = expr->codegen(rewriter, loc, symbolMap, refMap); + + switch (op) { + case ADD: + return eVal; + case SUB: { + Value zero = createConstantInt(rewriter, loc, 0, 64); + return createSubI(rewriter, loc, zero, eVal); + } + case LOG_NOT: { + Value zero = createConstantInt(rewriter, loc, 0, 1); + return createCmpI(rewriter, loc, arith::CmpIPredicate::eq, zero, eVal); + } + case BIT_NOT: { + Value negOne = createConstantInt(rewriter, loc, -1, 64); + return createXOrI(rewriter, loc, negOne, eVal); + } + } + + return eVal; +} + +/// Converts the binary operation node into MLIR code. SymbolMap is used for +/// permanent mapping of symbols to values. RefMap is a temporary mapping +/// overriding SymbolMap. +Value BinOpNode::codegen(PatternRewriter &rewriter, Location loc, + llvm::StringMap &symbolMap, + llvm::StringMap &refMap) { + Value lVal = left->codegen(rewriter, loc, symbolMap, refMap); + Value rVal = right->codegen(rewriter, loc, symbolMap, refMap); + + switch (op) { + case ADD: + return createAddI(rewriter, loc, lVal, rVal); + case SUB: + return createSubI(rewriter, loc, lVal, rVal); + case MUL: + return createMulI(rewriter, loc, lVal, rVal); + case DIV: + return createDivSI(rewriter, loc, lVal, rVal); + case FLOORDIV: + return createFloorDivSI(rewriter, loc, lVal, rVal); + case MOD: + return createRemSI(rewriter, loc, lVal, rVal); + case EXP: + break; + // TODO: Implement EXP case + case BIT_OR: + return createOrI(rewriter, loc, lVal, rVal); + case BIT_XOR: + return createXOrI(rewriter, loc, lVal, rVal); + case BIT_AND: + return createAndI(rewriter, loc, lVal, rVal); + case LSHIFT: + return createShLI(rewriter, loc, lVal, rVal); + case RSHIFT: + return createShRSI(rewriter, loc, lVal, rVal); + case LOG_OR: + break; + // TODO: Implement LOG_OR case + case LOG_AND: + break; + // TODO: Implement LOG_AND case + case EQ: + return createCmpI(rewriter, loc, arith::CmpIPredicate::eq, lVal, rVal); + case NE: + return createCmpI(rewriter, loc, arith::CmpIPredicate::ne, lVal, rVal); + case LT: + return createCmpI(rewriter, loc, arith::CmpIPredicate::slt, lVal, rVal); + case LE: + return createCmpI(rewriter, loc, arith::CmpIPredicate::sle, lVal, rVal); + case GT: + return createCmpI(rewriter, loc, arith::CmpIPredicate::sgt, lVal, rVal); + case GE: + return createCmpI(rewriter, loc, arith::CmpIPredicate::sge, lVal, rVal); + } + + return lVal; +} + +//===----------------------------------------------------------------------===// +// Tokenizer +//===----------------------------------------------------------------------===// + +/// Converts the symbolic expression to individual tokens. +Optional> SymbolicParser::tokenize(StringRef input) { + std::vector> tokenDefinitions = { + {"==", EQ}, + {"!=", NE}, + {"<<", LSHIFT}, + {">>", RSHIFT}, + {"<=", LE}, + {">=", GE}, + {"<", LT}, + {">", GT}, + {":", ASSIGN}, + {"or", LOG_OR}, + {"and", LOG_AND}, + {"not", LOG_NOT}, + {"\\+", ADD}, + {"-", SUB}, + {"\\*\\*", EXP}, + {"\\*", MUL}, + {"//", FLOORDIV}, + {"/", DIV}, + {"%", MOD}, + {"True", TRUE}, + {"False", FALSE}, + {"\\|", BIT_OR}, + {"\\^", BIT_XOR}, + {"&", BIT_AND}, + {"~", BIT_NOT}, + {"\\(", LPAREN}, + {"\\)", RPAREN}, + {"\\d+", INT_CONST}, + {"[_a-zA-Z][_a-zA-Z0-9]*", IDENT}, + {"[ \\t\\r\\n]+", WS}}; + + SmallVector tokens; + std::string remaining = input.str(); + + while (!remaining.empty()) { + bool matched = false; + + for (std::pair &definition : tokenDefinitions) { + std::regex pattern("^" + definition.first); + std::smatch match; + if (std::regex_search(remaining, match, pattern)) { + if (definition.second != WS) // Skip whitespace tokens + tokens.push_back(Token{definition.second, match.str()}); + + remaining = match.suffix().str(); + matched = true; + break; + } + } + + if (!matched) + return std::nullopt; + } + + return tokens; +} + +//===----------------------------------------------------------------------===// +// Parser +//===----------------------------------------------------------------------===// + +/// Attempts to parse a statement: +/// stmt ::= assignment | log_or_expr +std::unique_ptr SymbolicParser::stmt() { + std::unique_ptr assignNode = assignment(); + + if (assignNode != nullptr) + return assignNode; + + return log_or_expr(); +} + +/// Attempts to parse an assignment: +/// assignment ::= IDENT ASSIGN log_or_expr +std::unique_ptr SymbolicParser::assignment() { + if (pos + 2 >= tokens.size() || tokens[pos].type != TokenType::IDENT || + tokens[pos + 1].type != TokenType::ASSIGN) + return nullptr; + + std::string varName = tokens[pos].value; + pos += 2; + std::unique_ptr orNode = log_or_expr(); + + if (orNode == nullptr) + return nullptr; + + std::unique_ptr varNode = std::make_unique(varName); + return std::make_unique(std::move(varNode), std::move(orNode)); +} + +/// Attempts to parse a logical OR expression: +/// log_or_expr ::= log_and_expr ( LOG_OR log_and_expr )* +std::unique_ptr SymbolicParser::log_or_expr() { + std::unique_ptr leftNode = log_and_expr(); + + if (leftNode == nullptr) + return nullptr; + + while (pos + 1 < tokens.size() && tokens[pos].type == TokenType::LOG_OR) { + pos++; + std::unique_ptr rightNode = log_and_expr(); + + if (rightNode == nullptr) + return nullptr; + + leftNode = std::make_unique( + std::move(leftNode), BinOpNode::BinOp::LOG_OR, std::move(rightNode)); + } + + return leftNode; +} + +/// Attempts to parse a logical AND expression: +/// log_and_expr ::= eq_expr ( LOG_AND eq_expr )* +std::unique_ptr SymbolicParser::log_and_expr() { + std::unique_ptr leftNode = eq_expr(); + + if (leftNode == nullptr) + return nullptr; + + while (pos + 1 < tokens.size() && tokens[pos].type == TokenType::LOG_AND) { + pos++; + std::unique_ptr rightNode = eq_expr(); + + if (rightNode == nullptr) + return nullptr; + + leftNode = std::make_unique( + std::move(leftNode), BinOpNode::BinOp::LOG_AND, std::move(rightNode)); + } + + return leftNode; +} + +/// Attempts to parse an equality expression: +/// eq_expr ::= rel_expr ( ( EQ | NE ) rel_expr )* +std::unique_ptr SymbolicParser::eq_expr() { + std::unique_ptr leftNode = rel_expr(); + + if (leftNode == nullptr) + return nullptr; + + while (pos + 1 < tokens.size() && (tokens[pos].type == TokenType::EQ || + tokens[pos].type == TokenType::NE)) { + BinOpNode::BinOp binOp; + + switch (tokens[pos].type) { + case TokenType::EQ: + binOp = BinOpNode::BinOp::EQ; + break; + case TokenType::NE: + binOp = BinOpNode::BinOp::NE; + break; + default: + break; + } + + pos++; + std::unique_ptr rightNode = rel_expr(); + + if (rightNode == nullptr) + return nullptr; + + leftNode = std::make_unique(std::move(leftNode), binOp, + std::move(rightNode)); + } + + return leftNode; +} + +/// Attempts to parse an inequality expression: +/// rel_expr ::= shift_expr ( ( LT | LE | GT | GE ) shift_expr )* +std::unique_ptr SymbolicParser::rel_expr() { + std::unique_ptr leftNode = shift_expr(); + + if (leftNode == nullptr) + return nullptr; + + while (pos + 1 < tokens.size() && (tokens[pos].type == TokenType::LT || + tokens[pos].type == TokenType::LE || + tokens[pos].type == TokenType::GT || + tokens[pos].type == TokenType::GE)) { + BinOpNode::BinOp binOp; + + switch (tokens[pos].type) { + case TokenType::LT: + binOp = BinOpNode::BinOp::LT; + break; + case TokenType::LE: + binOp = BinOpNode::BinOp::LE; + break; + case TokenType::GT: + binOp = BinOpNode::BinOp::GT; + break; + case TokenType::GE: + binOp = BinOpNode::BinOp::GE; + break; + default: + break; + } + + pos++; + std::unique_ptr rightNode = shift_expr(); + + if (rightNode == nullptr) + return nullptr; + + leftNode = std::make_unique(std::move(leftNode), binOp, + std::move(rightNode)); + } + + return leftNode; +} + +/// Attempts to parse a shift expression: +/// shift_expr ::= bit_or_expr ( (LSHIFT | RSHIFT ) bit_or_expr )* +std::unique_ptr SymbolicParser::shift_expr() { + std::unique_ptr leftNode = bit_or_expr(); + + if (leftNode == nullptr) + return nullptr; + + while (pos + 1 < tokens.size() && (tokens[pos].type == TokenType::LSHIFT || + tokens[pos].type == TokenType::RSHIFT)) { + BinOpNode::BinOp binOp; + + switch (tokens[pos].type) { + case TokenType::LSHIFT: + binOp = BinOpNode::BinOp::LSHIFT; + break; + case TokenType::RSHIFT: + binOp = BinOpNode::BinOp::RSHIFT; + break; + default: + break; + } + + pos++; + std::unique_ptr rightNode = bit_or_expr(); + + if (rightNode == nullptr) + return nullptr; + + leftNode = std::make_unique(std::move(leftNode), binOp, + std::move(rightNode)); + } + + return leftNode; +} + +/// Attempts to parse a bitwise OR expression: +/// bit_or_expr ::= bit_xor_expr ( BIT_OR bit_xor_expr )* +std::unique_ptr SymbolicParser::bit_or_expr() { + std::unique_ptr leftNode = bit_xor_expr(); + + if (leftNode == nullptr) + return nullptr; + + while (pos + 1 < tokens.size() && tokens[pos].type == TokenType::BIT_OR) { + pos++; + std::unique_ptr rightNode = bit_xor_expr(); + + if (rightNode == nullptr) + return nullptr; + + leftNode = std::make_unique( + std::move(leftNode), BinOpNode::BinOp::BIT_OR, std::move(rightNode)); + } + + return leftNode; +} + +/// Attempts to parse a bitwise XOR expression: +/// bit_xor_expr ::= bit_and_expr ( BIT_XOR bit_and_expr )* +std::unique_ptr SymbolicParser::bit_xor_expr() { + std::unique_ptr leftNode = bit_and_expr(); + + if (leftNode == nullptr) + return nullptr; + + while (pos + 1 < tokens.size() && tokens[pos].type == TokenType::BIT_XOR) { + pos++; + std::unique_ptr rightNode = bit_and_expr(); + + if (rightNode == nullptr) + return nullptr; + + leftNode = std::make_unique( + std::move(leftNode), BinOpNode::BinOp::BIT_XOR, std::move(rightNode)); + } + + return leftNode; +} + +/// Attempts to parse a bitwise AND expression: +/// bit_and_expr ::= add_expr ( BIT_AND add_expr )* +std::unique_ptr SymbolicParser::bit_and_expr() { + std::unique_ptr leftNode = add_expr(); + + if (leftNode == nullptr) + return nullptr; + + while (pos + 1 < tokens.size() && tokens[pos].type == TokenType::BIT_AND) { + pos++; + std::unique_ptr rightNode = add_expr(); + + if (rightNode == nullptr) + return nullptr; + + leftNode = std::make_unique( + std::move(leftNode), BinOpNode::BinOp::BIT_AND, std::move(rightNode)); + } + + return leftNode; +} + +/// Attempts to parse an arithmetic addition / subtraction expression: +/// add_expr ::= mul_expr ( ( ADD | SUB ) mul_expr )* +std::unique_ptr SymbolicParser::add_expr() { + std::unique_ptr leftNode = mul_expr(); + + if (leftNode == nullptr) + return nullptr; + + while (pos + 1 < tokens.size() && (tokens[pos].type == TokenType::ADD || + tokens[pos].type == TokenType::SUB)) { + BinOpNode::BinOp binOp; + + switch (tokens[pos].type) { + case TokenType::ADD: + binOp = BinOpNode::BinOp::ADD; + break; + case TokenType::SUB: + binOp = BinOpNode::BinOp::SUB; + break; + default: + break; + } + + pos++; + std::unique_ptr rightNode = mul_expr(); + + if (rightNode == nullptr) + return nullptr; + + leftNode = std::make_unique(std::move(leftNode), binOp, + std::move(rightNode)); + } + + return leftNode; +} + +/// Attempts to parse an arithmetic multiplication / division / floor / modulo +/// expression: +/// mul_expr ::= exp_expr ( ( MUL | DIV | FLOORDIV | MOD ) exp_expr )* +std::unique_ptr SymbolicParser::mul_expr() { + std::unique_ptr leftNode = exp_expr(); + + if (leftNode == nullptr) + return nullptr; + + while (pos + 1 < tokens.size() && (tokens[pos].type == TokenType::MUL || + tokens[pos].type == TokenType::DIV || + tokens[pos].type == TokenType::FLOORDIV || + tokens[pos].type == TokenType::MOD)) { + BinOpNode::BinOp binOp; + + switch (tokens[pos].type) { + case TokenType::MUL: + binOp = BinOpNode::BinOp::MUL; + break; + case TokenType::DIV: + binOp = BinOpNode::BinOp::DIV; + break; + case TokenType::FLOORDIV: + binOp = BinOpNode::BinOp::FLOORDIV; + break; + case TokenType::MOD: + binOp = BinOpNode::BinOp::MOD; + break; + default: + break; + } + + pos++; + std::unique_ptr rightNode = exp_expr(); + + if (rightNode == nullptr) + return nullptr; + + leftNode = std::make_unique(std::move(leftNode), binOp, + std::move(rightNode)); + } + + return leftNode; +} + +/// Attempts to parse an arithmetic exponential expression: +/// exp_expr ::= unary_expr ( EXP unary_expr )* +std::unique_ptr SymbolicParser::exp_expr() { + std::unique_ptr leftNode = unary_expr(); + + if (leftNode == nullptr) + return nullptr; + + while (pos + 1 < tokens.size() && tokens[pos].type == TokenType::EXP) { + pos++; + std::unique_ptr rightNode = unary_expr(); + + if (rightNode == nullptr) + return nullptr; + + leftNode = std::make_unique( + std::move(leftNode), BinOpNode::BinOp::EXP, std::move(rightNode)); + } + + return leftNode; +} + +/// Attempts to parse an unary positive / negative / logical and bitwise NOT +/// expression: +/// unary_expr ::= ( ADD | SUB | LOG_NOT | BIT_NOT )? factor +std::unique_ptr SymbolicParser::unary_expr() { + if (pos >= tokens.size()) + return nullptr; + + UnOpNode::UnOp unop; + switch (tokens[pos].type) { + case TokenType::ADD: + unop = UnOpNode::UnOp::ADD; + break; + case TokenType::SUB: + unop = UnOpNode::UnOp::SUB; + break; + case TokenType::LOG_NOT: + unop = UnOpNode::UnOp::LOG_NOT; + break; + case TokenType::BIT_NOT: + unop = UnOpNode::UnOp::BIT_NOT; + break; + default: + return factor(); + } + + pos++; + std::unique_ptr node = factor(); + if (node == nullptr) + return nullptr; + + return std::make_unique(unop, std::move(node)); +} + +/// Attempts to parse a single factor: +/// factor ::= LPAREN log_or_expr RPAREN | const_expr | IDENT +std::unique_ptr SymbolicParser::factor() { + if (pos >= tokens.size()) + return nullptr; + + if (tokens[pos].type == TokenType::LPAREN) { + pos++; + std::unique_ptr expr = log_or_expr(); + if (expr == nullptr) + return nullptr; + + if (pos >= tokens.size() || tokens[pos].type != TokenType::RPAREN) + return nullptr; + + pos++; + return expr; + } + + std::unique_ptr constExpr = const_expr(); + if (constExpr != nullptr) + return constExpr; + + if (tokens[pos].type == TokenType::IDENT) { + std::unique_ptr varNode = + std::make_unique(tokens[pos].value); + pos++; + return varNode; + } + + return nullptr; +} + +/// Attempts to parse a constant expression: +/// const_expr ::= bool_const | INT_CONST +std::unique_ptr SymbolicParser::const_expr() { + if (pos >= tokens.size()) + return nullptr; + + std::unique_ptr boolExpr = bool_const(); + if (boolExpr != nullptr) + return boolExpr; + + if (tokens[pos].type == TokenType::INT_CONST) { + std::unique_ptr intNode = + std::make_unique(std::stoi(tokens[pos].value)); + pos++; + return intNode; + } + + return nullptr; +} + +/// Attempts to parse a constant boolean expression: +/// bool_const ::= TRUE | FALSE +std::unique_ptr SymbolicParser::bool_const() { + if (pos >= tokens.size()) + return nullptr; + + if (tokens[pos].type == TokenType::TRUE) { + pos++; + return std::make_unique(true); + } + + if (tokens[pos].type == TokenType::FALSE) { + pos++; + return std::make_unique(false); + } + + return nullptr; +} + +/// Parses a symbolic expression provided as a string to an AST. +std::unique_ptr SymbolicParser::parse(StringRef input) { + Optional> tokens = tokenize(input); + if (!tokens.has_value()) + return nullptr; + + this->tokens = tokens.value(); + this->pos = 0; + return stmt(); +} + +} // namespace mlir::sdfg::conversion diff --git a/lib/SDFG/Dialect/CMakeLists.txt b/lib/SDFG/Dialect/CMakeLists.txt index 7a645ddc9..31e3a042f 100644 --- a/lib/SDFG/Dialect/CMakeLists.txt +++ b/lib/SDFG/Dialect/CMakeLists.txt @@ -1,3 +1,5 @@ +# Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + add_mlir_dialect_library( MLIR_SDFG Dialect.cpp diff --git a/lib/SDFG/Dialect/Dialect.cpp b/lib/SDFG/Dialect/Dialect.cpp index 5ef602e0a..e669f0a54 100644 --- a/lib/SDFG/Dialect/Dialect.cpp +++ b/lib/SDFG/Dialect/Dialect.cpp @@ -1,3 +1,8 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// This file contains the SDFG dialect initializer and the type definitions, +/// such as parsing, printing and utility functions. + #include "SDFG/Dialect/Dialect.h" #include "mlir/IR/Builders.h" #include "llvm/ADT/TypeSwitch.h" @@ -11,6 +16,7 @@ using namespace sdfg; // SDFG Dialect //===----------------------------------------------------------------------===// +/// Initializes the SDFG dialect by adding all operation and type declarations. void SDFGDialect::initialize() { addOperations< #define GET_OP_LIST @@ -28,6 +34,8 @@ void SDFGDialect::initialize() { //===----------------------------------------------------------------------===// // FIXME: Rewrite to only use an ArrayAttr containing strings & ints +/// Parses a list of dimensions consisting of symbols, constants and question +/// marks. static ParseResult parseDimensionList(AsmParser &parser, Type &elemType, SmallVector &symbols, SmallVector &integers, @@ -54,7 +62,7 @@ static ParseResult parseDimensionList(AsmParser &parser, Type &elemType, continue; } - int32_t num = -1; + int64_t num = -1; OptionalParseResult intOPR = parser.parseOptionalInteger(num); if (intOPR.has_value() && intOPR.value().succeeded()) { integers.push_back(num); @@ -74,6 +82,7 @@ static ParseResult parseDimensionList(AsmParser &parser, Type &elemType, return failure(); } +/// Prints a list of dimensions in human-readable form. static void printDimensionList(AsmPrinter &printer, Type &elemType, ArrayRef &symbols, ArrayRef &integers, @@ -95,6 +104,7 @@ static void printDimensionList(AsmPrinter &printer, Type &elemType, printer << elemType << ">"; } +/// Attempts to parse an array type. ::mlir::Type ArrayType::parse(::mlir::AsmParser &odsParser) { Type elementType; SmallVector symbols; @@ -108,6 +118,7 @@ ::mlir::Type ArrayType::parse(::mlir::AsmParser &odsParser) { return get(odsParser.getContext(), sized); } +/// Prints an array type in human-readable form. void ArrayType::print(::mlir::AsmPrinter &odsPrinter) const { Type elemType = getDimensions().getElementType(); ArrayRef symbols = getDimensions().getSymbols(); @@ -117,6 +128,24 @@ void ArrayType::print(::mlir::AsmPrinter &odsPrinter) const { printDimensionList(odsPrinter, elemType, symbols, integers, shape); } +/// Returns the type of the elements in an array. +Type ArrayType::getElementType() { return getDimensions().getElementType(); } + +/// Returns a list of symbols in the array type. +ArrayRef ArrayType::getSymbols() { + return getDimensions().getSymbols(); +} + +/// Returns a list of integer constants in the array type. +ArrayRef ArrayType::getIntegers() { + return getDimensions().getIntegers(); +} + +/// Returns a list of booleans representing the shape of the array type. +/// (false = symbolic size, true = integer constant) +ArrayRef ArrayType::getShape() { return getDimensions().getShape(); } + +/// Attempts to parse a stream type. ::mlir::Type StreamType::parse(::mlir::AsmParser &odsParser) { Type elementType; SmallVector symbols; @@ -130,6 +159,7 @@ ::mlir::Type StreamType::parse(::mlir::AsmParser &odsParser) { return get(odsParser.getContext(), sized); } +/// Prints a stream type in human-readable form. void StreamType::print(::mlir::AsmPrinter &odsPrinter) const { Type elemType = getDimensions().getElementType(); ArrayRef symbols = getDimensions().getSymbols(); @@ -139,5 +169,6 @@ void StreamType::print(::mlir::AsmPrinter &odsPrinter) const { printDimensionList(odsPrinter, elemType, symbols, integers, shape); } +/// Generate the code for type definitions. #define GET_TYPEDEF_CLASSES #include "SDFG/Dialect/OpsTypes.cpp.inc" diff --git a/lib/SDFG/Dialect/Ops.cpp b/lib/SDFG/Dialect/Ops.cpp index 347c97dde..e4c2eeb8c 100644 --- a/lib/SDFG/Dialect/Ops.cpp +++ b/lib/SDFG/Dialect/Ops.cpp @@ -1,3 +1,8 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// This file contains SDFG operation definitions, such as parsing, printing and +/// utility functions. + #include "SDFG/Dialect/Dialect.h" #include "SDFG/Utils/Utils.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" @@ -13,6 +18,7 @@ using namespace sdfg; // Helpers //===----------------------------------------------------------------------===// +/// Parses a non-empty region. static ParseResult parseRegion(OpAsmParser &parser, OperationState &result, SmallVector &args, bool enableShadowing) { @@ -26,6 +32,7 @@ static ParseResult parseRegion(OpAsmParser &parser, OperationState &result, return success(); } +/// Parses a list of arguments. static ParseResult parseArgsList(OpAsmParser &parser, SmallVector &args) { if (parser.parseLParen()) @@ -46,6 +53,7 @@ static ParseResult parseArgsList(OpAsmParser &parser, return success(); } +/// Prints a list of arguments in human-readable form. static void printArgsList(OpAsmPrinter &p, Region::BlockArgListType args, unsigned lb, unsigned ub) { p << " ("; @@ -59,6 +67,8 @@ static void printArgsList(OpAsmPrinter &p, Region::BlockArgListType args, p << ")"; } +/// Parses arguments with an optional "as" keyword to compactly represent +/// arguments and parameters. static ParseResult parseAsArgs(OpAsmParser &parser, OperationState &result, SmallVector &args) { if (parser.parseLParen()) @@ -97,6 +107,8 @@ static ParseResult parseAsArgs(OpAsmParser &parser, OperationState &result, return success(); } +/// Prints a list of arguments with an optional "as" keyword in human-readable +/// form. static void printAsArgs(OpAsmPrinter &p, OperandRange opRange, Region::BlockArgListType args, unsigned lb, unsigned ub) { @@ -115,13 +127,13 @@ static void printAsArgs(OpAsmPrinter &p, OperandRange opRange, // InlineSymbol //===----------------------------------------------------------------------===// -// There are 3 possible values that can be used as a number: symbols, integers -// and operands. Operands are stored as regular operands. Symbols as stringAttr -// and integers as int32 Attr. In order to encode the correct order of values -// we use an auxiliary attr called [attrName]_numList. -// The numList contains int32 Attrs with the following encoding: -// Positive int n: nth operand -// Negative int n: -nth - 1 Attribute (symbol or integer) in [attrName] +/// There are 3 possible values that can be used as a number: symbols, integers +/// and operands. Operands are stored as regular operands. Symbols as stringAttr +/// and integers as int32 Attr. In order to encode the correct order of values +/// we use an auxiliary attr called [attrName]_numList. +/// The numList contains int32 Attrs with the following encoding: +/// Positive int n: nth operand +/// Negative int n: -nth - 1 Attribute (symbol or integer) in [attrName] static ParseResult parseNumberList(OpAsmParser &parser, OperationState &result, StringRef attrName) { SmallVector opList; @@ -181,6 +193,7 @@ static ParseResult parseNumberList(OpAsmParser &parser, OperationState &result, return success(); } +/// Prints a list of number arguments in human-readable form. static void printNumberList(OpAsmPrinter &p, Operation *op, StringRef attrName) { ArrayAttr attrList = op->getAttr(attrName).cast(); @@ -209,6 +222,8 @@ static void printNumberList(OpAsmPrinter &p, Operation *op, } } +/// Prints a list of optional attributes excluding the number list in +/// human-readable form. static void printOptionalAttrDictNoNumList(OpAsmPrinter &p, ArrayRef attrs, ArrayRef elidedAttrs = {}) { @@ -221,6 +236,8 @@ printOptionalAttrDictNoNumList(OpAsmPrinter &p, ArrayRef attrs, p.printOptionalAttrDict(attrs, /*elidedAttrs=*/numListAttrs); } +/// Returns the length of the number list, which is equivalent to the number of +/// numeric arguments. static size_t getNumListSize(Operation *op, StringRef attrName) { ArrayAttr numList = op->getAttr(attrName.str() + "_numList").cast(); @@ -231,10 +248,7 @@ static size_t getNumListSize(Operation *op, StringRef attrName) { // SDFGNode //===----------------------------------------------------------------------===// -SDFGNode SDFGNode::create(PatternRewriter &rewriter, Location loc) { - return create(rewriter, loc, 0, {}); -} - +/// Builds, creates and inserts a SDFG node using the provided PatternRewriter. SDFGNode SDFGNode::create(PatternRewriter &rewriter, Location loc, unsigned num_args, TypeRange args) { OpBuilder builder(loc->getContext()); @@ -250,6 +264,12 @@ SDFGNode SDFGNode::create(PatternRewriter &rewriter, Location loc, return sdfg; } +/// Builds, creates and inserts a SDFG node using the provided PatternRewriter. +SDFGNode SDFGNode::create(PatternRewriter &rewriter, Location loc) { + return create(rewriter, loc, 0, {}); +} + +/// Attempts to parse a SDFG node. ParseResult SDFGNode::parse(OpAsmParser &parser, OperationState &result) { if (parser.parseOptionalAttrDict(result.attributes)) return failure(); @@ -275,6 +295,7 @@ ParseResult SDFGNode::parse(OpAsmParser &parser, OperationState &result) { return success(); } +/// Prints a SDFG node in human-readable form. void SDFGNode::print(OpAsmPrinter &p) { p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"ID", "num_args"}); @@ -288,6 +309,7 @@ void SDFGNode::print(OpAsmPrinter &p) { /*printBlockTerminators=*/true, /*printEmptyBlock=*/true); } +/// Verifies the correct structure of a SDFG node. LogicalResult SDFGNode::verify() { // Verify that no other dialect is used in the body for (Operation &oper : getBody().getOps()) @@ -301,6 +323,7 @@ LogicalResult SDFGNode::verify() { return success(); } +/// Verifies the correct structure of symbols in a SDFG node. LogicalResult SDFGNode::verifySymbolUses(SymbolTableCollection &symbolTable) { // Check that the entry attribute references valid state FlatSymbolRefAttr entryAttr = @@ -317,23 +340,59 @@ LogicalResult SDFGNode::verifySymbolUses(SymbolTableCollection &symbolTable) { return success(); } +/// Returns the first state in the SDFG node. StateNode SDFGNode::getFirstState() { return *getBody().getOps().begin(); } +/// Returns the state with the provided name (symbol) in the SDFG node. StateNode SDFGNode::getStateBySymRef(StringRef symRef) { Operation *op = lookupSymbol(symRef); return dyn_cast(op); } +/// Returns the entry state of the SDFG node. +StateNode SDFGNode::getEntryState() { + if (this->getEntry().has_value()) + return getStateBySymRef(this->getEntry().value()); + + return this->getFirstState(); +} + +/// Returns the list of arguments in the SDFG node. +Block::BlockArgListType SDFGNode::getArgs() { + return this->getBody().getArguments().take_front(getNumArgs()); +} + +/// Returns a list of argument types in the SDFG node. +TypeRange SDFGNode::getArgTypes() { + SmallVector types = {}; + for (BlockArgument BArg : getArgs()) { + types.push_back(BArg.getType()); + } + return TypeRange(types); +} + +/// Returns the list of results in the SDFG node. +Block::BlockArgListType SDFGNode::getResults() { + return this->getBody().getArguments().drop_front(getNumArgs()); +} + +/// Returns a list of result types in the SDFG node. +TypeRange SDFGNode::getResultTypes() { + SmallVector types = {}; + for (BlockArgument BArg : getResults()) { + types.push_back(BArg.getType()); + } + return TypeRange(types); +} + //===----------------------------------------------------------------------===// // NestedSDFGNode //===----------------------------------------------------------------------===// -NestedSDFGNode NestedSDFGNode::create(PatternRewriter &rewriter, Location loc) { - return create(rewriter, loc, 0, {}); -} - +/// Builds, creates and inserts a nested SDFG node using the provided +/// PatternRewriter. NestedSDFGNode NestedSDFGNode::create(PatternRewriter &rewriter, Location loc, unsigned num_args, ValueRange args) { OpBuilder builder(loc->getContext()); @@ -350,6 +409,13 @@ NestedSDFGNode NestedSDFGNode::create(PatternRewriter &rewriter, Location loc, return sdfg; } +/// Builds, creates and inserts a nested SDFG node using the provided +/// PatternRewriter. +NestedSDFGNode NestedSDFGNode::create(PatternRewriter &rewriter, Location loc) { + return create(rewriter, loc, 0, {}); +} + +/// Attempts to parse a nested SDFG node. ParseResult NestedSDFGNode::parse(OpAsmParser &parser, OperationState &result) { if (parser.parseOptionalAttrDict(result.attributes)) return failure(); @@ -376,6 +442,7 @@ ParseResult NestedSDFGNode::parse(OpAsmParser &parser, OperationState &result) { return success(); } +/// Prints a nested SDFG node in human-readable form. void NestedSDFGNode::print(OpAsmPrinter &p) { p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"ID", "num_args"}); @@ -389,6 +456,7 @@ void NestedSDFGNode::print(OpAsmPrinter &p) { /*printBlockTerminators=*/true, /*printEmptyBlock=*/true); } +/// Verifies the correct structure of a nested SDFG node. LogicalResult NestedSDFGNode::verify() { // Verify that no other dialect is used in the body for (Operation &oper : getBody().getOps()) @@ -406,6 +474,7 @@ LogicalResult NestedSDFGNode::verify() { return success(); } +/// Verifies the correct structure of symbols in a nested SDFG node. LogicalResult NestedSDFGNode::verifySymbolUses(SymbolTableCollection &symbolTable) { // Check that the entry attribute references valid state @@ -423,23 +492,40 @@ NestedSDFGNode::verifySymbolUses(SymbolTableCollection &symbolTable) { return success(); } +/// Returns the first state in the nested SDFG node. StateNode NestedSDFGNode::getFirstState() { return *getBody().getOps().begin(); } +/// Returns the state with the provided name (symbol) in the nested SDFG node. StateNode NestedSDFGNode::getStateBySymRef(StringRef symRef) { Operation *op = lookupSymbol(symRef); return dyn_cast(op); } +/// Returns the entry state of the nested SDFG node. +StateNode NestedSDFGNode::getEntryState() { + if (this->getEntry().has_value()) + return getStateBySymRef(this->getEntry().value()); + + return this->getFirstState(); +} + +/// Returns the list of arguments in the nested SDFG node. +ValueRange NestedSDFGNode::getArgs() { + return this->getOperands().take_front(getNumArgs()); +} + +/// Returns the list of results in the nested SDFG node. +ValueRange NestedSDFGNode::getResults() { + return this->getOperands().drop_front(getNumArgs()); +} + //===----------------------------------------------------------------------===// // StateNode //===----------------------------------------------------------------------===// -StateNode StateNode::create(PatternRewriter &rewriter, Location loc) { - return create(rewriter, loc, "state"); -} - +/// Builds, creates and inserts a state node using the provided PatternRewriter. StateNode StateNode::create(PatternRewriter &rewriter, Location loc, StringRef name) { OpBuilder builder(loc->getContext()); @@ -450,6 +536,12 @@ StateNode StateNode::create(PatternRewriter &rewriter, Location loc, return stateNode; } +/// Builds, creates and inserts a state node using the provided PatternRewriter. +StateNode StateNode::create(PatternRewriter &rewriter, Location loc) { + return create(rewriter, loc, "state"); +} + +/// Builds, creates and inserts a state node using Operation::create. StateNode StateNode::create(Location loc, StringRef name) { OpBuilder builder(loc->getContext()); OperationState state(loc, getOperationName()); @@ -457,6 +549,7 @@ StateNode StateNode::create(Location loc, StringRef name) { return cast(Operation::create(state)); } +/// Attempts to parse a state node. ParseResult StateNode::parse(OpAsmParser &parser, OperationState &result) { if (parser.parseOptionalAttrDict(result.attributes)) return failure(); @@ -479,6 +572,7 @@ ParseResult StateNode::parse(OpAsmParser &parser, OperationState &result) { return success(); } +/// Prints a state node in human-readable form. void StateNode::print(OpAsmPrinter &p) { p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"ID", "sym_name"}); @@ -487,6 +581,7 @@ void StateNode::print(OpAsmPrinter &p) { p.printRegion(getBody()); } +/// Verifies the correct structure of a state node. LogicalResult StateNode::verify() { // Verify that no other dialect is used in the body // Except func operations @@ -501,6 +596,42 @@ LogicalResult StateNode::verify() { // TaskletNode //===----------------------------------------------------------------------===// +/// Builds, creates and inserts a tasklet node using the provided +/// PatternRewriter. +TaskletNode TaskletNode::create(PatternRewriter &rewriter, Location location, + ValueRange operands, TypeRange results) { + OpBuilder builder(location->getContext()); + OperationState state(location, getOperationName()); + build(builder, state, results, utils::generateID(), operands); + + TaskletNode task = cast(rewriter.create(state)); + + std::vector locs = {}; + for (unsigned i = 0; i < operands.size(); ++i) + locs.push_back(location); + + rewriter.createBlock(&task.getRegion(), {}, operands.getTypes(), locs); + return task; +} + +/// Builds, creates and inserts a tasklet node using Operation::create. +TaskletNode TaskletNode::create(Location location, ValueRange operands, + TypeRange results) { + OpBuilder builder(location->getContext()); + OperationState state(location, getOperationName()); + build(builder, state, results, utils::generateID(), operands); + + TaskletNode task = cast(Operation::create(state)); + + std::vector locs = {}; + for (unsigned i = 0; i < operands.size(); ++i) + locs.push_back(location); + + builder.createBlock(&task.getBody(), {}, operands.getTypes(), locs); + return task; +} + +/// Attempts to parse a tasklet node. ParseResult TaskletNode::parse(OpAsmParser &parser, OperationState &result) { if (parser.parseOptionalAttrDict(result.attributes)) return failure(); @@ -523,6 +654,7 @@ ParseResult TaskletNode::parse(OpAsmParser &parser, OperationState &result) { return success(); } +/// Prints a tasklet node in human-readable form. void TaskletNode::print(OpAsmPrinter &p) { p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"ID"}); printAsArgs(p, getOperands(), getBody().getArguments(), 0, getNumOperands()); @@ -531,6 +663,7 @@ void TaskletNode::print(OpAsmPrinter &p) { /*printBlockTerminators=*/true, /*printEmptyBlock=*/true); } +/// Verifies the correct structure of a tasklet node. LogicalResult TaskletNode::verify() { // Verify that operands and arguments line up if (getNumOperands() != getBody().getNumArguments()) @@ -539,42 +672,12 @@ LogicalResult TaskletNode::verify() { return success(); } -TaskletNode TaskletNode::create(PatternRewriter &rewriter, Location location, - ValueRange operands, TypeRange results) { - OpBuilder builder(location->getContext()); - OperationState state(location, getOperationName()); - build(builder, state, results, utils::generateID(), operands); - - TaskletNode task = cast(rewriter.create(state)); - - std::vector locs = {}; - for (unsigned i = 0; i < operands.size(); ++i) - locs.push_back(location); - - rewriter.createBlock(&task.getRegion(), {}, operands.getTypes(), locs); - return task; -} - -TaskletNode TaskletNode::create(Location location, ValueRange operands, - TypeRange results) { - OpBuilder builder(location->getContext()); - OperationState state(location, getOperationName()); - build(builder, state, results, utils::generateID(), operands); - - TaskletNode task = cast(Operation::create(state)); - - std::vector locs = {}; - for (unsigned i = 0; i < operands.size(); ++i) - locs.push_back(location); - - builder.createBlock(&task.getBody(), {}, operands.getTypes(), locs); - return task; -} - +/// Returns the input name of the provided index. std::string TaskletNode::getInputName(unsigned idx) { return utils::valueToString(getBody().getArgument(idx), *getOperation()); } +/// Returns the output name of the provided index. std::string TaskletNode::getOutputName(unsigned idx) { return "__out" + std::to_string(idx); } @@ -583,6 +686,7 @@ std::string TaskletNode::getOutputName(unsigned idx) { // MapNode //===----------------------------------------------------------------------===// +/// Attempts to parse a map node. ParseResult MapNode::parse(OpAsmParser &parser, OperationState &result) { IntegerAttr intAttr = parser.getBuilder().getI32IntegerAttr(utils::generateID()); @@ -628,6 +732,7 @@ ParseResult MapNode::parse(OpAsmParser &parser, OperationState &result) { return success(); } +/// Prints a map node in human-readable form. void MapNode::print(OpAsmPrinter &p) { printOptionalAttrDictNoNumList( p, (*this)->getAttrs(), @@ -651,6 +756,7 @@ void MapNode::print(OpAsmPrinter &p) { /*printBlockTerminators=*/false); } +/// Verifies the correct structure of a map node. LogicalResult MapNode::verify() { size_t var_count = getBody().getArguments().size(); @@ -674,12 +780,14 @@ LogicalResult MapNode::verify() { return success(); } +/// Returns the body of the map node. Region &MapNode::getLoopBody() { return getBody(); } //===----------------------------------------------------------------------===// // ConsumeNode //===----------------------------------------------------------------------===// +/// Attempts to parse a consume node. ParseResult ConsumeNode::parse(OpAsmParser &parser, OperationState &result) { IntegerAttr intAttr = parser.getBuilder().getI32IntegerAttr(utils::generateID()); @@ -732,6 +840,7 @@ ParseResult ConsumeNode::parse(OpAsmParser &parser, OperationState &result) { return success(); } +/// Prints a consume node in human-readable form. void ConsumeNode::print(OpAsmPrinter &p) { p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"entryID", "exitID"}); @@ -742,6 +851,7 @@ void ConsumeNode::print(OpAsmPrinter &p) { /*printBlockTerminators=*/false); } +/// Verifies the correct structure of a consume node. LogicalResult ConsumeNode::verify() { if (getNumPes().has_value() && getNumPes().value().isNonPositive()) return emitOpError("failed to verify that number of " @@ -755,6 +865,7 @@ LogicalResult ConsumeNode::verify() { return success(); } +/// Verifies the correct structure of symbols in a consume node. LogicalResult ConsumeNode::verifySymbolUses(SymbolTableCollection &symbolTable) { // Check that the condition attribute is specified. @@ -780,14 +891,18 @@ ConsumeNode::verifySymbolUses(SymbolTableCollection &symbolTable) { return success(); } +/// Returns the body of the consume node. Region &ConsumeNode::getLoopBody() { return getBody(); } +/// Returns the argument corresponding to the processing element. BlockArgument ConsumeNode::pe() { return getBody().getArgument(0); } +/// Returns the argument corresponding to the popped element. BlockArgument ConsumeNode::elem() { return getBody().getArgument(1); } //===----------------------------------------------------------------------===// // EdgeOp //===----------------------------------------------------------------------===// +/// Builds, creates and inserts an edge using the provided PatternRewriter. EdgeOp EdgeOp::create(PatternRewriter &rewriter, Location loc, StateNode &from, StateNode &to, ArrayAttr &assign, StringAttr &condition, Value ref) { @@ -798,6 +913,7 @@ EdgeOp EdgeOp::create(PatternRewriter &rewriter, Location loc, StateNode &from, return cast(rewriter.create(state)); } +/// Builds, creates and inserts an edge using the provided PatternRewriter. EdgeOp EdgeOp::create(PatternRewriter &rewriter, Location loc, StateNode &from, StateNode &to) { OpBuilder builder(loc->getContext()); @@ -807,6 +923,7 @@ EdgeOp EdgeOp::create(PatternRewriter &rewriter, Location loc, StateNode &from, return cast(rewriter.create(state)); } +/// Builds, creates and inserts an edge using Operation::create. EdgeOp EdgeOp::create(Location loc, StateNode &from, StateNode &to, ArrayAttr &assign, StringAttr &condition, Value ref) { OpBuilder builder(loc->getContext()); @@ -816,6 +933,7 @@ EdgeOp EdgeOp::create(Location loc, StateNode &from, StateNode &to, return cast(Operation::create(state)); } +/// Attempts to parse a edge operation. ParseResult EdgeOp::parse(OpAsmParser &parser, OperationState &result) { FlatSymbolRefAttr srcAttr; FlatSymbolRefAttr destAttr; @@ -848,6 +966,7 @@ ParseResult EdgeOp::parse(OpAsmParser &parser, OperationState &result) { return success(); } +/// Prints a edge operation in human-readable form. void EdgeOp::print(OpAsmPrinter &p) { p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"src", "dest"}); p << ' '; @@ -858,6 +977,7 @@ void EdgeOp::print(OpAsmPrinter &p) { p.printAttributeWithoutType(getDestAttr()); } +/// Verifies the correct structure of an edge operation. LogicalResult EdgeOp::verify() { // Check that condition is non-empty if (getCondition().empty()) @@ -866,6 +986,7 @@ LogicalResult EdgeOp::verify() { return success(); } +/// Verifies the correct structure of symbols in an edge operation. LogicalResult EdgeOp::verifySymbolUses(SymbolTableCollection &symbolTable) { // Check that the src/dest attributes are specified. FlatSymbolRefAttr srcAttr = (*this)->getAttrOfType("src"); @@ -896,6 +1017,41 @@ LogicalResult EdgeOp::verifySymbolUses(SymbolTableCollection &symbolTable) { // AllocOp //===----------------------------------------------------------------------===// +/// Builds, creates and inserts an allocation operation using the provided +/// PatternRewriter. +AllocOp AllocOp::create(PatternRewriter &rewriter, Location loc, Type res, + StringRef name, bool transient) { + OpBuilder builder(loc->getContext()); + OperationState state(loc, getOperationName()); + StringAttr nameAttr = rewriter.getStringAttr(utils::generateName(name.str())); + build(builder, state, res, {}, nameAttr, transient); + return cast(rewriter.create(state)); +} + +/// Builds, creates and inserts an allocation operation using the provided +/// PatternRewriter. +AllocOp AllocOp::create(PatternRewriter &rewriter, Location loc, Type res, + bool transient) { + return create(rewriter, loc, res, "arr", transient); +} + +/// Builds, creates and inserts an allocation operation using Operation::create. +AllocOp AllocOp::create(Location loc, Type res, StringRef name, + bool transient) { + OpBuilder builder(loc->getContext()); + OperationState state(loc, getOperationName()); + StringAttr nameAttr = builder.getStringAttr(name); + + if (!res.isa()) { + SizedType sized = SizedType::get(res.getContext(), res, {}, {}, {}); + res = ArrayType::get(res.getContext(), sized); + } + + build(builder, state, res, {}, nameAttr, transient); + return cast(Operation::create(state)); +} + +/// Attempts to parse an allocation operation. ParseResult AllocOp::parse(OpAsmParser &parser, OperationState &result) { if (parser.parseOptionalAttrDict(result.attributes)) return failure(); @@ -916,6 +1072,7 @@ ParseResult AllocOp::parse(OpAsmParser &parser, OperationState &result) { return success(); } +/// Prints an allocation operation in human-readable form. void AllocOp::print(OpAsmPrinter &p) { p.printOptionalAttrDict((*this)->getAttrs()); p << " ("; @@ -924,6 +1081,7 @@ void AllocOp::print(OpAsmPrinter &p) { p << getOperation()->getResultTypes(); } +/// Verifies the correct structure of an allocation operation. LogicalResult AllocOp::verify() { SizedType result = utils::getSizedType(getRes().getType()); @@ -938,49 +1096,25 @@ LogicalResult AllocOp::verify() { return success(); } -AllocOp AllocOp::create(PatternRewriter &rewriter, Location loc, Type res, - StringRef name, bool transient) { - OpBuilder builder(loc->getContext()); - OperationState state(loc, getOperationName()); - StringAttr nameAttr = rewriter.getStringAttr(utils::generateName(name.str())); - build(builder, state, res, {}, nameAttr, transient); - return cast(rewriter.create(state)); -} - -AllocOp AllocOp::create(PatternRewriter &rewriter, Location loc, Type res, - bool transient) { - return create(rewriter, loc, res, "arr", transient); -} - -AllocOp AllocOp::create(Location loc, Type res, StringRef name, - bool transient) { - OpBuilder builder(loc->getContext()); - OperationState state(loc, getOperationName()); - StringAttr nameAttr = builder.getStringAttr(name); - - if (!res.isa()) { - SizedType sized = SizedType::get(res.getContext(), res, {}, {}, {}); - res = ArrayType::get(res.getContext(), sized); - } - - build(builder, state, res, {}, nameAttr, transient); - return cast(Operation::create(state)); -} - +/// Returns the type of the elements in the allocated data container. Type AllocOp::getElementType() { return utils::getSizedType(getType()).getElementType(); } +/// Returns true if the allocated data container is a scalar. bool AllocOp::isScalar() { return utils::getSizedType(getType()).getShape().empty(); } +/// Returns true if the allocated data container is a stream. bool AllocOp::isStream() { return getType().isa(); } +/// Returns true if the allocation operation is inside a state. bool AllocOp::isInState() { return utils::getParentState(*this->getOperation()) != nullptr; } +/// Returns the name of the allocated data container. std::string AllocOp::getContainerName() { if ((*this)->hasAttr("name")) { Attribute nameAttr = (*this)->getAttr("name"); @@ -998,11 +1132,8 @@ std::string AllocOp::getContainerName() { // LoadOp //===----------------------------------------------------------------------===// -LoadOp LoadOp::create(PatternRewriter &rewriter, Location loc, AllocOp alloc, - ValueRange indices) { - return create(rewriter, loc, alloc.getType(), alloc, indices); -} - +/// Builds, creates and inserts a load operation using the provided +/// PatternRewriter. LoadOp LoadOp::create(PatternRewriter &rewriter, Location loc, Type t, Value mem, ValueRange indices) { OpBuilder builder(loc->getContext()); @@ -1026,10 +1157,14 @@ LoadOp LoadOp::create(PatternRewriter &rewriter, Location loc, Type t, return cast(rewriter.create(state)); } -LoadOp LoadOp::create(Location loc, AllocOp alloc, ValueRange indices) { - return create(loc, alloc.getType(), alloc, indices); +/// Builds, creates and inserts a load operation using the provided +/// PatternRewriter. +LoadOp LoadOp::create(PatternRewriter &rewriter, Location loc, AllocOp alloc, + ValueRange indices) { + return create(rewriter, loc, alloc.getType(), alloc, indices); } +/// Builds, creates and inserts a load operation using Operation::create. LoadOp LoadOp::create(Location loc, Type t, Value mem, ValueRange indices) { OpBuilder builder(loc->getContext()); OperationState state(loc, getOperationName()); @@ -1050,6 +1185,12 @@ LoadOp LoadOp::create(Location loc, Type t, Value mem, ValueRange indices) { return cast(Operation::create(state)); } +/// Builds, creates and inserts a load operation using Operation::create. +LoadOp LoadOp::create(Location loc, AllocOp alloc, ValueRange indices) { + return create(loc, alloc.getType(), alloc, indices); +} + +/// Attempts to parse a load operation. ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) { if (parser.parseOptionalAttrDict(result.attributes)) return failure(); @@ -1085,6 +1226,7 @@ ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) { return success(); } +/// Prints a load operation in human-readable form. void LoadOp::print(OpAsmPrinter &p) { printOptionalAttrDictNoNumList(p, (*this)->getAttrs(), /*elidedAttrs*/ {"indices"}); @@ -1098,6 +1240,7 @@ void LoadOp::print(OpAsmPrinter &p) { p << ArrayRef(getRes().getType()); } +/// Verifies the correct structure of a load operation. LogicalResult LoadOp::verify() { size_t idx_size = getNumListSize(getOperation(), "indices"); size_t mem_size = utils::getSizedType(getArr().getType()).getRank(); @@ -1108,12 +1251,15 @@ LogicalResult LoadOp::verify() { return success(); } +/// Returns true if the load operation has non-constant indices. bool LoadOp::isIndirect() { return !getIndices().empty(); } //===----------------------------------------------------------------------===// // StoreOp //===----------------------------------------------------------------------===// +/// Builds, creates and inserts a store operation using the provided +/// PatternRewriter. StoreOp StoreOp::create(PatternRewriter &rewriter, Location loc, Value val, Value mem, ValueRange indices) { OpBuilder builder(loc->getContext()); @@ -1134,6 +1280,7 @@ StoreOp StoreOp::create(PatternRewriter &rewriter, Location loc, Value val, return cast(rewriter.create(state)); } +/// Builds, creates and inserts a store operation using Operation::create. StoreOp StoreOp::create(Location loc, Value val, Value mem, ValueRange indices) { OpBuilder builder(loc->getContext()); @@ -1154,6 +1301,7 @@ StoreOp StoreOp::create(Location loc, Value val, Value mem, return cast(Operation::create(state)); } +/// Builds, creates and inserts a store operation using Operation::create. StoreOp StoreOp::create(Location loc, Value val, Value mem, ArrayRef indices) { OpBuilder builder(loc->getContext()); @@ -1177,6 +1325,7 @@ StoreOp StoreOp::create(Location loc, Value val, Value mem, return cast(Operation::create(state)); } +/// Builds, creates and inserts a store operation using Operation::create. StoreOp StoreOp::create(Location loc, Value val, Value mem) { OpBuilder builder(loc->getContext()); OperationState state(loc, getOperationName()); @@ -1196,6 +1345,7 @@ StoreOp StoreOp::create(Location loc, Value val, Value mem) { return cast(Operation::create(state)); } +/// Attempts to parse a store operation. ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) { if (parser.parseOptionalAttrDict(result.attributes)) return failure(); @@ -1239,6 +1389,7 @@ ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) { return success(); } +/// Prints a store operation in human-readable form. void StoreOp::print(OpAsmPrinter &p) { printOptionalAttrDictNoNumList(p, (*this)->getAttrs(), /*elidedAttrs=*/{"indices"}); @@ -1252,6 +1403,7 @@ void StoreOp::print(OpAsmPrinter &p) { p << ArrayRef(getArr().getType()); } +/// Verifies the correct structure of a store operation. LogicalResult StoreOp::verify() { size_t idx_size = getNumListSize(getOperation(), "indices"); size_t mem_size = utils::getSizedType(getArr().getType()).getRank(); @@ -1262,12 +1414,15 @@ LogicalResult StoreOp::verify() { return success(); } +/// Returns true if the store operation has non-constant indices. bool StoreOp::isIndirect() { return !getIndices().empty(); } //===----------------------------------------------------------------------===// // CopyOp //===----------------------------------------------------------------------===// +/// Builds, creates and inserts a copy operation using the provided +/// PatternRewriter. CopyOp CopyOp::create(PatternRewriter &rewriter, Location loc, Value src, Value dst) { OpBuilder builder(loc->getContext()); @@ -1280,6 +1435,7 @@ CopyOp CopyOp::create(PatternRewriter &rewriter, Location loc, Value src, return cast(rewriter.create(state)); } +/// Attempts to parse a copy operation. ParseResult CopyOp::parse(OpAsmParser &parser, OperationState &result) { if (parser.parseOptionalAttrDict(result.attributes)) return failure(); @@ -1308,6 +1464,7 @@ ParseResult CopyOp::parse(OpAsmParser &parser, OperationState &result) { return success(); } +/// Prints a copy operation in human-readable form. void CopyOp::print(OpAsmPrinter &p) { p.printOptionalAttrDict((*this)->getAttrs()); p << ' ' << getSrc() << " -> " << getDest(); @@ -1315,12 +1472,15 @@ void CopyOp::print(OpAsmPrinter &p) { p << ArrayRef(getSrc().getType()); } +/// Verifies the correct structure of a copy operation. LogicalResult CopyOp::verify() { return success(); } //===----------------------------------------------------------------------===// // ViewCastOp //===----------------------------------------------------------------------===// +/// Builds, creates and inserts a viewcast operation using the provided +/// PatternRewriter. ViewCastOp ViewCastOp::create(PatternRewriter &rewriter, Location loc, Value array, Type type) { OpBuilder builder(loc->getContext()); @@ -1329,6 +1489,7 @@ ViewCastOp ViewCastOp::create(PatternRewriter &rewriter, Location loc, return cast(rewriter.create(state)); } +/// Attempts to parse a viewcast operation. ParseResult ViewCastOp::parse(OpAsmParser &parser, OperationState &result) { if (parser.parseOptionalAttrDict(result.attributes)) return failure(); @@ -1355,6 +1516,7 @@ ParseResult ViewCastOp::parse(OpAsmParser &parser, OperationState &result) { return success(); } +/// Prints a viewcast operation in human-readable form. void ViewCastOp::print(OpAsmPrinter &p) { p.printOptionalAttrDict((*this)->getAttrs()); p << ' ' << getSrc(); @@ -1364,6 +1526,7 @@ void ViewCastOp::print(OpAsmPrinter &p) { p << getOperation()->getResultTypes(); } +/// Verifies the correct structure of a viewcast operation. LogicalResult ViewCastOp::verify() { size_t src_size = utils::getSizedType(getSrc().getType()).getRank(); size_t res_size = utils::getSizedType(getRes().getType()).getRank(); @@ -1378,6 +1541,8 @@ LogicalResult ViewCastOp::verify() { // SubviewOp //===----------------------------------------------------------------------===// +/// Builds, creates and inserts a subview operation using the provided +/// PatternRewriter. SubviewOp SubviewOp::create(PatternRewriter &rewriter, Location loc, Type res, Value src, ArrayAttr offsets, ArrayAttr sizes, ArrayAttr strides) { @@ -1394,6 +1559,7 @@ SubviewOp SubviewOp::create(PatternRewriter &rewriter, Location loc, Type res, return cast(rewriter.create(state)); } +/// Attempts to parse a subview operation. ParseResult SubviewOp::parse(OpAsmParser &parser, OperationState &result) { if (parser.parseOptionalAttrDict(result.attributes)) return failure(); @@ -1432,6 +1598,7 @@ ParseResult SubviewOp::parse(OpAsmParser &parser, OperationState &result) { return success(); } +/// Prints a subview operation in human-readable form. void SubviewOp::print(OpAsmPrinter &p) { printOptionalAttrDictNoNumList(p, (*this)->getAttrs(), {"offsets", "sizes", "strides"}); @@ -1448,12 +1615,14 @@ void SubviewOp::print(OpAsmPrinter &p) { p << getOperation()->getResultTypes(); } +/// Verifies the correct structure of a subview operation. LogicalResult SubviewOp::verify() { return success(); } //===----------------------------------------------------------------------===// // StreamPopOp //===----------------------------------------------------------------------===// +/// Attempts to parse a stream pop operation. ParseResult StreamPopOp::parse(OpAsmParser &parser, OperationState &result) { if (parser.parseOptionalAttrDict(result.attributes)) return failure(); @@ -1480,6 +1649,7 @@ ParseResult StreamPopOp::parse(OpAsmParser &parser, OperationState &result) { return success(); } +/// Prints a stream pop operation in human-readable form. void StreamPopOp::print(OpAsmPrinter &p) { p.printOptionalAttrDict((*this)->getAttrs()); p << ' ' << getStr(); @@ -1489,12 +1659,14 @@ void StreamPopOp::print(OpAsmPrinter &p) { p << ArrayRef(getRes().getType()); } +/// Verifies the correct structure of a stream pop operation. LogicalResult StreamPopOp::verify() { return success(); } //===----------------------------------------------------------------------===// // StreamPushOp //===----------------------------------------------------------------------===// +/// Attempts to parse a stream push operation. ParseResult StreamPushOp::parse(OpAsmParser &parser, OperationState &result) { if (parser.parseOptionalAttrDict(result.attributes)) return failure(); @@ -1529,6 +1701,7 @@ ParseResult StreamPushOp::parse(OpAsmParser &parser, OperationState &result) { return success(); } +/// Prints a stream push operation in human-readable form. void StreamPushOp::print(OpAsmPrinter &p) { p.printOptionalAttrDict((*this)->getAttrs()); p << ' ' << getVal() << ", " << getStr(); @@ -1538,12 +1711,14 @@ void StreamPushOp::print(OpAsmPrinter &p) { p << ArrayRef(getStr().getType()); } +/// Verifies the correct structure of a stream push operation. LogicalResult StreamPushOp::verify() { return success(); } //===----------------------------------------------------------------------===// // StreamLengthOp //===----------------------------------------------------------------------===// +/// Attempts to parse a stream length operation. ParseResult StreamLengthOp::parse(OpAsmParser &parser, OperationState &result) { if (parser.parseOptionalAttrDict(result.attributes)) return failure(); @@ -1570,6 +1745,7 @@ ParseResult StreamLengthOp::parse(OpAsmParser &parser, OperationState &result) { return success(); } +/// Prints a stream length operation in human-readable form. void StreamLengthOp::print(OpAsmPrinter &p) { p.printOptionalAttrDict((*this)->getAttrs()); p << ' ' << getStr(); @@ -1579,6 +1755,7 @@ void StreamLengthOp::print(OpAsmPrinter &p) { p << getOperation()->getResultTypes(); } +/// Verifies the correct structure of a stream length operation. LogicalResult StreamLengthOp::verify() { Operation *parent = (*this)->getParentOp(); if (parent == nullptr) @@ -1596,6 +1773,25 @@ LogicalResult StreamLengthOp::verify() { // ReturnOp //===----------------------------------------------------------------------===// +/// Builds, creates and inserts a return operation using the provided +/// PatternRewriter. +sdfg::ReturnOp sdfg::ReturnOp::create(PatternRewriter &rewriter, Location loc, + mlir::ValueRange input) { + OpBuilder builder(loc->getContext()); + OperationState state(loc, getOperationName()); + build(builder, state, input); + return cast(rewriter.create(state)); +} + +/// Builds, creates and inserts a return operation using Operation::create. +sdfg::ReturnOp sdfg::ReturnOp::create(Location loc, mlir::ValueRange input) { + OpBuilder builder(loc->getContext()); + OperationState state(loc, getOperationName()); + build(builder, state, input); + return cast(Operation::create(state)); +} + +/// Attempts to parse a return operation. ParseResult sdfg::ReturnOp::parse(OpAsmParser &parser, OperationState &result) { if (parser.parseOptionalAttrDict(result.attributes)) return failure(); @@ -1615,12 +1811,14 @@ ParseResult sdfg::ReturnOp::parse(OpAsmParser &parser, OperationState &result) { return success(); } +/// Prints a return operation in human-readable form. void sdfg::ReturnOp::print(OpAsmPrinter &p) { p.printOptionalAttrDict((*this)->getAttrs()); if (getNumOperands() > 0) p << ' ' << getInput() << " : " << getInput().getTypes(); } +/// Verifies the correct structure of a return operation. LogicalResult sdfg::ReturnOp::verify() { TaskletNode task = dyn_cast((*this)->getParentOp()); @@ -1630,25 +1828,12 @@ LogicalResult sdfg::ReturnOp::verify() { return success(); } -sdfg::ReturnOp sdfg::ReturnOp::create(PatternRewriter &rewriter, Location loc, - mlir::ValueRange input) { - OpBuilder builder(loc->getContext()); - OperationState state(loc, getOperationName()); - build(builder, state, input); - return cast(rewriter.create(state)); -} - -sdfg::ReturnOp sdfg::ReturnOp::create(Location loc, mlir::ValueRange input) { - OpBuilder builder(loc->getContext()); - OperationState state(loc, getOperationName()); - build(builder, state, input); - return cast(Operation::create(state)); -} - //===----------------------------------------------------------------------===// // LibCallOp //===----------------------------------------------------------------------===// +/// Builds, creates and inserts a library call operation using the provided +/// PatternRewriter. LibCallOp LibCallOp::create(PatternRewriter &rewriter, Location loc, TypeRange result, StringRef callee, ValueRange operands) { @@ -1658,6 +1843,7 @@ LibCallOp LibCallOp::create(PatternRewriter &rewriter, Location loc, return cast(rewriter.create(state)); } +/// Attempts to parse a library call operation. ParseResult LibCallOp::parse(OpAsmParser &parser, OperationState &result) { if (parser.parseOptionalAttrDict(result.attributes)) return failure(); @@ -1686,6 +1872,7 @@ ParseResult LibCallOp::parse(OpAsmParser &parser, OperationState &result) { return success(); } +/// Prints a library call operation in human-readable form. void LibCallOp::print(OpAsmPrinter &p) { p.printOptionalAttrDict(getOperation()->getAttrs(), /*elidedAttrs=*/{"callee"}); @@ -1697,8 +1884,10 @@ void LibCallOp::print(OpAsmPrinter &p) { getOperation()->getResultTypes()); } +/// Verifies the correct structure of a library call operation. LogicalResult LibCallOp::verify() { return success(); } +/// Returns the input name of the provided index. std::string LibCallOp::getInputName(unsigned idx) { if (getOperation()->hasAttr("inputs")) { if (ArrayAttr inputs = @@ -1712,6 +1901,7 @@ std::string LibCallOp::getInputName(unsigned idx) { return utils::valueToString(getOperand(idx), *getOperation()); } +/// Returns the output name of the provided index. std::string LibCallOp::getOutputName(unsigned idx) { if (getOperation()->hasAttr("outputs")) { if (ArrayAttr outputs = @@ -1729,6 +1919,8 @@ std::string LibCallOp::getOutputName(unsigned idx) { // AllocSymbolOp //===----------------------------------------------------------------------===// +/// Builds, creates and inserts a symbol allocation operation using the +/// provided PatternRewriter. AllocSymbolOp AllocSymbolOp::create(PatternRewriter &rewriter, Location loc, StringRef sym) { OpBuilder builder(loc->getContext()); @@ -1737,6 +1929,8 @@ AllocSymbolOp AllocSymbolOp::create(PatternRewriter &rewriter, Location loc, return cast(rewriter.create(state)); } +/// Builds, creates and inserts a symbol allocation operation using +/// Operation::create. AllocSymbolOp AllocSymbolOp::create(Location loc, StringRef sym) { OpBuilder builder(loc->getContext()); OperationState state(loc, getOperationName()); @@ -1744,6 +1938,7 @@ AllocSymbolOp AllocSymbolOp::create(Location loc, StringRef sym) { return cast(Operation::create(state)); } +/// Attempts to parse a symbol allocation operation. ParseResult AllocSymbolOp::parse(OpAsmParser &parser, OperationState &result) { StringAttr symAttr; if (parser.parseOptionalAttrDict(result.attributes)) @@ -1760,6 +1955,7 @@ ParseResult AllocSymbolOp::parse(OpAsmParser &parser, OperationState &result) { return success(); } +/// Prints a symbol allocation operation in human-readable form. void AllocSymbolOp::print(OpAsmPrinter &p) { p.printOptionalAttrDict(getOperation()->getAttrs(), /*elidedAttrs=*/{"sym"}); p << " ("; @@ -1767,6 +1963,7 @@ void AllocSymbolOp::print(OpAsmPrinter &p) { p << ")"; } +/// Verifies the correct structure of a symbol allocation operation. LogicalResult AllocSymbolOp::verify() { if (getSym().empty()) return emitOpError("failed to verify that input string is not empty"); @@ -1787,6 +1984,8 @@ LogicalResult AllocSymbolOp::verify() { // SymOp //===----------------------------------------------------------------------===// +/// Builds, creates and inserts a symbolic expression operation using the +/// provided PatternRewriter. SymOp SymOp::create(PatternRewriter &rewriter, Location loc, Type type, StringRef expr) { OpBuilder builder(loc->getContext()); @@ -1795,6 +1994,8 @@ SymOp SymOp::create(PatternRewriter &rewriter, Location loc, Type type, return cast(rewriter.create(state)); } +/// Builds, creates and inserts a symbolic expression operation using +/// Operation::create. SymOp SymOp::create(Location loc, Type type, StringRef expr) { OpBuilder builder(loc->getContext()); OperationState state(loc, getOperationName()); @@ -1802,6 +2003,7 @@ SymOp SymOp::create(Location loc, Type type, StringRef expr) { return cast(Operation::create(state)); } +/// Attempts to parse a symbolic expression operation. ParseResult SymOp::parse(OpAsmParser &parser, OperationState &result) { if (parser.parseOptionalAttrDict(result.attributes)) return failure(); @@ -1820,6 +2022,7 @@ ParseResult SymOp::parse(OpAsmParser &parser, OperationState &result) { return success(); } +/// Prints a symbolic expression operation in human-readable form. void SymOp::print(OpAsmPrinter &p) { p.printOptionalAttrDict(getOperation()->getAttrs(), /*elidedAttrs=*/{"expr"}); p << " ("; @@ -1827,11 +2030,13 @@ void SymOp::print(OpAsmPrinter &p) { p << ") : " << getOperation()->getResultTypes(); } +/// Verifies the correct structure of a symbolic expression operation. LogicalResult SymOp::verify() { return success(); } //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// +/// Generate the code for operation definitions. #define GET_OP_CLASSES #include "SDFG/Dialect/Ops.cpp.inc" diff --git a/lib/SDFG/Translate/CMakeLists.txt b/lib/SDFG/Translate/CMakeLists.txt index 6104a6e12..f84db447b 100644 --- a/lib/SDFG/Translate/CMakeLists.txt +++ b/lib/SDFG/Translate/CMakeLists.txt @@ -1,3 +1,5 @@ +# Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + add_mlir_translation_library( MLIRTargetSDFG registration.cpp diff --git a/lib/SDFG/Translate/JsonEmitter.cpp b/lib/SDFG/Translate/JsonEmitter.cpp index cdaeea417..df4066a98 100644 --- a/lib/SDFG/Translate/JsonEmitter.cpp +++ b/lib/SDFG/Translate/JsonEmitter.cpp @@ -1,3 +1,8 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// This file contains a JSON emitter, which constraints the generated output +/// stream to valid JSON. + #include "SDFG/Translate/JsonEmitter.h" #include "mlir/IR/Types.h" @@ -5,6 +10,7 @@ using namespace mlir; using namespace sdfg; using namespace emitter; +/// Creates a new JSON emitter. JsonEmitter::JsonEmitter(raw_ostream &os) : os(os) { indentation = 0; error = false; @@ -13,6 +19,8 @@ JsonEmitter::JsonEmitter(raw_ostream &os) : os(os) { symStack.clear(); } +/// Checks for errors (open objects/lists) and adds trailing newline. Returns +/// a LogicalResult indicating success or failure. LogicalResult JsonEmitter::finish() { while (!symStack.empty()) { SYM sym = symStack.pop_back_val(); @@ -31,11 +39,14 @@ LogicalResult JsonEmitter::finish() { return failure(error); } +/// Increases the indentation level. void JsonEmitter::indent() { indentation += 2; } +/// Decreases the indentation level. void JsonEmitter::unindent() { indentation = indentation >= 2 ? indentation - 2 : 0; } +/// Starts a new line in the output stream. void JsonEmitter::newLine() { if (emptyLine) return; @@ -43,6 +54,7 @@ void JsonEmitter::newLine() { emptyLine = true; } +/// Prints a literal string to the output stream. void JsonEmitter::printLiteral(StringRef str) { if (emptyLine) os.indent(indentation); @@ -50,18 +62,22 @@ void JsonEmitter::printLiteral(StringRef str) { emptyLine = false; } +/// Prints a string to the output stream, surrounding it with quotation marks. void JsonEmitter::printString(StringRef str) { printLiteral("\""); printLiteral(str); printLiteral("\""); } +/// Prints an integer to the output stream, surrounding it with quotation +/// marks. void JsonEmitter::printInt(int i) { printLiteral("\""); os << i; printLiteral("\""); } +/// Starts a new JSON object. void JsonEmitter::startObject() { startEntry(); printLiteral("{"); @@ -78,6 +94,7 @@ void JsonEmitter::startObject() { firstEntry = true; } +/// Starts a new named (keyed) JSON object. void JsonEmitter::startNamedObject(StringRef name) { startEntry(); printString(name); @@ -96,6 +113,7 @@ void JsonEmitter::startNamedObject(StringRef name) { firstEntry = true; } +/// Ends the current JSON object. void JsonEmitter::endObject() { newLine(); unindent(); @@ -104,6 +122,35 @@ void JsonEmitter::endObject() { firstEntry = false; } +/// Starts a new named JSON list. +void JsonEmitter::startNamedList(StringRef name) { + startEntry(); + printString(name); + printLiteral(": "); + printLiteral("["); + if (!symStack.empty() && symStack.back() == SYM::SQUARE) { + // Can't have keyed values in a list + os.changeColor(os.RED, /*Bold=*/true); + printLiteral(" <<<<<<<<<<<< Started keyed list in a list"); + os.resetColor(); + error = true; + } + symStack.push_back(SYM::SQUARE); + indent(); + newLine(); + firstEntry = true; +} + +/// Ends the current JSON list. +void JsonEmitter::endList() { + newLine(); + unindent(); + tryPop(SYM::SQUARE); + printLiteral("]"); + firstEntry = false; +} + +/// Starts a new entry in the current JSON object or list. void JsonEmitter::startEntry() { if (!firstEntry) printLiteral(","); @@ -111,6 +158,8 @@ void JsonEmitter::startEntry() { newLine(); } +/// Prints a key-value pair to the output stream. If desired, turns the value +/// into string. void JsonEmitter::printKVPair(StringRef key, StringRef val, bool stringify) { startEntry(); printString(key); @@ -121,6 +170,8 @@ void JsonEmitter::printKVPair(StringRef key, StringRef val, bool stringify) { printLiteral(val); } +/// Prints a key-value pair to the output stream. If desired, turns the value +/// into string. void JsonEmitter::printKVPair(StringRef key, int val, bool stringify) { startEntry(); printString(key); @@ -131,6 +182,8 @@ void JsonEmitter::printKVPair(StringRef key, int val, bool stringify) { os << val; } +/// Prints a key-value pair to the output stream. If desired, turns the value +/// into string. void JsonEmitter::printKVPair(StringRef key, Attribute val, bool stringify) { startEntry(); printString(key); @@ -146,32 +199,21 @@ void JsonEmitter::printKVPair(StringRef key, Attribute val, bool stringify) { } } -void JsonEmitter::startNamedList(StringRef name) { - startEntry(); - printString(name); - printLiteral(": "); - printLiteral("["); - if (!symStack.empty() && symStack.back() == SYM::SQUARE) { - // Can't have keyed values in a list - os.changeColor(os.RED, /*Bold=*/true); - printLiteral(" <<<<<<<<<<<< Started keyed list in a list"); - os.resetColor(); - error = true; - } - symStack.push_back(SYM::SQUARE); - indent(); - newLine(); - firstEntry = true; -} +/// Prints a list of NamedAttributes as key-value pairs. +void JsonEmitter::printAttributes(ArrayRef arr, + ArrayRef elidedAttrs) { -void JsonEmitter::endList() { - newLine(); - unindent(); - tryPop(SYM::SQUARE); - printLiteral("]"); - firstEntry = false; + llvm::SmallDenseSet elidedAttrsSet(elidedAttrs.begin(), + elidedAttrs.end()); + + for (NamedAttribute attr : arr) { + if (elidedAttrsSet.contains(attr.getName().strref())) + continue; + printKVPair(attr.getName().strref(), attr.getValue()); + } } +/// Tries to pop a symbol from the symStack, checking for matching symbols. void JsonEmitter::tryPop(SYM sym) { if (symStack.empty()) { if (sym == SYM::BRACE) @@ -200,16 +242,3 @@ void JsonEmitter::tryPop(SYM sym) { symStack.pop_back(); } } - -void JsonEmitter::printAttributes(ArrayRef arr, - ArrayRef elidedAttrs) { - - llvm::SmallDenseSet elidedAttrsSet(elidedAttrs.begin(), - elidedAttrs.end()); - - for (NamedAttribute attr : arr) { - if (elidedAttrsSet.contains(attr.getName().strref())) - continue; - printKVPair(attr.getName().strref(), attr.getValue()); - } -} diff --git a/lib/SDFG/Translate/Node.cpp b/lib/SDFG/Translate/Node.cpp index 22b6cfb6d..f906e2586 100644 --- a/lib/SDFG/Translate/Node.cpp +++ b/lib/SDFG/Translate/Node.cpp @@ -1,3 +1,7 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// This file contains the nodes of the internal IR used by the translator. + #include "SDFG/Translate/Node.h" using namespace mlir; @@ -8,6 +12,7 @@ using namespace translation; // Helpers //===----------------------------------------------------------------------===// +/// Converts a MLIR type to a DaCe DType. DType typeToDtype(Type t) { if (t.isInteger(1)) return DType::boolean; @@ -15,12 +20,18 @@ DType typeToDtype(Type t) { if (t.isInteger(8)) return DType::int8; + if (t.isInteger(16)) + return DType::int16; + if (t.isInteger(32)) return DType::int32; if (t.isInteger(64)) return DType::int64; + if (t.isF16()) + return DType::float16; + if (t.isF32()) return DType::float32; @@ -40,16 +51,21 @@ DType typeToDtype(Type t) { return DType::null; } +/// Converts a DType to a string. std::string dtypeToString(DType t) { switch (t) { case DType::boolean: return "bool"; case DType::int8: return "int8"; + case DType::int16: + return "int16"; case DType::int32: return "int32"; case DType::int64: return "int64"; + case DType::float16: + return "float16"; case DType::float32: return "float32"; case DType::float64: @@ -61,10 +77,13 @@ std::string dtypeToString(DType t) { return "Unsupported DType"; } +/// Converts a CodeLanguage to a string. std::string codeLanguageToString(CodeLanguage lang) { switch (lang) { case CodeLanguage::Python: return "Python"; + case CodeLanguage::CPP: + return "CPP"; case CodeLanguage::MLIR: return "MLIR"; } @@ -72,6 +91,7 @@ std::string codeLanguageToString(CodeLanguage lang) { return "Unsupported CodeLanguage"; } +/// Prints an array of ranges to the output stream. void printRangeVector(std::vector ranges, std::string name, emitter::JsonEmitter &jemit) { if (ranges.empty()) { @@ -88,6 +108,7 @@ void printRangeVector(std::vector ranges, std::string name, jemit.endObject(); // name } +/// Prints source location information as debug information. void printLocation(Location loc, emitter::JsonEmitter &jemit) { jemit.startNamedObject("debuginfo"); jemit.printKVPair("type", "DebugInfo"); @@ -121,6 +142,7 @@ void printLocation(Location loc, emitter::JsonEmitter &jemit) { // Array //===----------------------------------------------------------------------===// +/// Emits this array to the output stream. void Array::emit(emitter::JsonEmitter &jemit) { jemit.startNamedObject(name); @@ -206,6 +228,7 @@ void Array::emit(emitter::JsonEmitter &jemit) { // Range //===----------------------------------------------------------------------===// +/// Emits this range to the output stream. void translation::Range::emit(emitter::JsonEmitter &jemit) { jemit.startObject(); jemit.printKVPair("start", start); @@ -219,24 +242,30 @@ void translation::Range::emit(emitter::JsonEmitter &jemit) { // InterstateEdge //===----------------------------------------------------------------------===// +/// Sets the condition of the interstate edge. void InterstateEdge::setCondition(Condition condition) { ptr->setCondition(condition); } +/// Adds an assignment to the interstate edge. void InterstateEdge::addAssignment(Assignment assignment) { ptr->addAssignment(assignment); } +/// Emits the interstate edge to the output stream. void InterstateEdge::emit(emitter::JsonEmitter &jemit) { ptr->emit(jemit); } +/// Sets the condition of the interstate edge. void InterstateEdgeImpl::setCondition(Condition condition) { this->condition = condition; } +/// Adds an assignment to the interstate edge. void InterstateEdgeImpl::addAssignment(Assignment assignment) { assignments.push_back(assignment); } +/// Emits the interstate edge to the output stream. void InterstateEdgeImpl::emit(emitter::JsonEmitter &jemit) { jemit.startObject(); jemit.printKVPair("type", "Edge"); @@ -271,6 +300,7 @@ void InterstateEdgeImpl::emit(emitter::JsonEmitter &jemit) { // MultiEdge //===----------------------------------------------------------------------===// +/// Emits this edge to the output stream. void MultiEdge::emit(emitter::JsonEmitter &jemit) { jemit.startObject(); jemit.printKVPair("type", "MultiConnectorEdge"); @@ -313,19 +343,34 @@ void MultiEdge::emit(emitter::JsonEmitter &jemit) { // Node //===----------------------------------------------------------------------===// +/// Sets the ID of the Node. void Node::setID(unsigned id) { ptr->setID(id); } + +/// Returns the ID of the Node. unsigned Node::getID() { return ptr->getID(); } +/// Returns the source code location. Location Node::getLocation() { return ptr->getLocation(); } + +/// Returns the type of the Node. NType Node::getType() { return type; } +/// Sets the name of the node. void Node::setName(StringRef name) { ptr->setName(name); } + +/// Returns the name of the node. StringRef Node::getName() { return ptr->getName(); } +/// Sets the parent of the node. void Node::setParent(Node parent) { ptr->setParent(parent); } + +/// Returns the parent of the node. Node Node::getParent() { return ptr->getParent(); } + +/// Return true if this node has a parent node. bool Node::hasParent() { return getParent().ptr != nullptr; } +/// Returns the top-level SDFG. SDFG Node::getSDFG() { if (type == NType::SDFG) { return SDFG(std::static_pointer_cast(ptr)); @@ -333,6 +378,7 @@ SDFG Node::getSDFG() { return ptr->getParent().getSDFG(); } +/// Returns the surrounding state. State Node::getState() { if (type == NType::State) { return State(std::static_pointer_cast(ptr)); @@ -340,48 +386,74 @@ State Node::getState() { return ptr->getParent().getState(); } +/// Adds an attribute to this node, replaces existing attributes with the same +/// name. void Node::addAttribute(Attribute attribute) { ptr->addAttribute(attribute); } + +/// Emits this node to the output stream. void Node::emit(emitter::JsonEmitter &jemit) { ptr->emit(jemit); } +/// Sets the ID of the node. void NodeImpl::setID(unsigned id) { this->id = id; } + +/// Returns the ID of the node. unsigned NodeImpl::getID() { return id; } +/// Returns the source code location. Location NodeImpl::getLocation() { return location; } +/// Sets the name of the node. void NodeImpl::setName(StringRef name) { this->name = name.str(); utils::sanitizeName(this->name); } +/// Returns the name of the node. StringRef NodeImpl::getName() { return name; } +/// Sets the parent of the node. void NodeImpl::setParent(Node parent) { this->parent = parent; } + +/// Returns the parent of the node. Node NodeImpl::getParent() { return parent; } +/// Adds an attribute to this node, replaces existing attributes with the same +/// name. void NodeImpl::addAttribute(Attribute attribute) { attributes.push_back(attribute); } +/// Emits this node to the output stream. void NodeImpl::emit(emitter::JsonEmitter &jemit) {} //===----------------------------------------------------------------------===// // ConnectorNode //===----------------------------------------------------------------------===// +/// Adds an incoming connector. void ConnectorNode::addInConnector(Connector connector) { ptr->addInConnector(connector); } + +/// Adds an outgoing connector. void ConnectorNode::addOutConnector(Connector connector) { ptr->addOutConnector(connector); } + +/// Returns to number of incoming connectors. unsigned ConnectorNode::getInConnectorCount() { return ptr->getInConnectorCount(); } + +/// Returns to number of outgoing connectors. unsigned ConnectorNode::getOutConnectorCount() { return ptr->getOutConnectorCount(); } + +/// Emits the connectors to the output stream. void ConnectorNode::emit(emitter::JsonEmitter &jemit) { ptr->emit(jemit); } +/// Adds an incoming connector. void ConnectorNodeImpl::addInConnector(Connector connector) { if (std::find(inConnectors.begin(), inConnectors.end(), connector) != inConnectors.end()) { @@ -393,6 +465,7 @@ void ConnectorNodeImpl::addInConnector(Connector connector) { inConnectors.push_back(connector); } +/// Adds an outgoing connector. void ConnectorNodeImpl::addOutConnector(Connector connector) { if (std::find(outConnectors.begin(), outConnectors.end(), connector) != outConnectors.end()) { @@ -404,14 +477,17 @@ void ConnectorNodeImpl::addOutConnector(Connector connector) { outConnectors.push_back(connector); } +/// Returns to number of incoming connectors. unsigned ConnectorNodeImpl::getInConnectorCount() { return inConnectors.size(); } +/// Returns to number of outgoing connectors. unsigned ConnectorNodeImpl::getOutConnectorCount() { return outConnectors.size(); } +/// Emits the connectors to the output stream. void ConnectorNodeImpl::emit(emitter::JsonEmitter &jemit) { jemit.startNamedObject("in_connectors"); for (Connector c : inConnectors) { @@ -434,6 +510,7 @@ void ConnectorNodeImpl::emit(emitter::JsonEmitter &jemit) { // ScopeNode //===----------------------------------------------------------------------===// +/// Adds a connector node to the scope. void ScopeNode::addNode(ConnectorNode node) { if (!node.hasParent()) { node.setParent(*this); @@ -441,16 +518,20 @@ void ScopeNode::addNode(ConnectorNode node) { ptr->addNode(node); } +/// Adds a multiedge from the source to the destination connector. void ScopeNode::routeWrite(Connector from, Connector to) { ptr->routeWrite(from, to); } +/// Adds an edge to the scope. void ScopeNode::addEdge(MultiEdge edge) { ptr->addEdge(edge); } +/// Maps the MLIR value to the specified connector. void ScopeNode::mapConnector(Value value, Connector connector) { ptr->mapConnector(value, connector); } +/// Returns the connector associated with a MLIR value. Connector ScopeNode::lookup(Value value) { if (type == NType::MapEntry) { return MapEntry(*this).lookup(value); @@ -462,20 +543,26 @@ Connector ScopeNode::lookup(Value value) { return ptr->lookup(value); } + +/// Emits all nodes and edges to the output stream. void ScopeNode::emit(emitter::JsonEmitter &jemit) { ptr->emit(jemit); } +/// Adds a connector node to the scope. void ScopeNodeImpl::addNode(ConnectorNode node) { node.setID(nodes.size()); nodes.push_back(node); } +/// Adds a multiedge from the source to the destination connector. void ScopeNodeImpl::routeWrite(Connector from, Connector to) { MultiEdge edge(location, from, to); addEdge(edge); } +/// Adds an edge to the scope. void ScopeNodeImpl::addEdge(MultiEdge edge) { edges.push_back(edge); } +/// Maps the MLIR value to the specified connector. void ScopeNodeImpl::mapConnector(Value value, Connector connector) { auto res = lut.insert({utils::valueToString(value), connector}); @@ -483,6 +570,7 @@ void ScopeNodeImpl::mapConnector(Value value, Connector connector) { res.first->second = connector; } +/// Returns the connector associated with a MLIR value. Connector ScopeNodeImpl::lookup(Value value) { if (lut.find(utils::valueToString(value)) == lut.end()) { emitError(location, @@ -491,6 +579,7 @@ Connector ScopeNodeImpl::lookup(Value value) { return lut.find(utils::valueToString(value))->second; } +/// Emits all nodes and edges to the output stream. void ScopeNodeImpl::emit(emitter::JsonEmitter &jemit) { jemit.startNamedList("nodes"); for (ConnectorNode cn : nodes) @@ -507,23 +596,46 @@ void ScopeNodeImpl::emit(emitter::JsonEmitter &jemit) { // SDFG //===----------------------------------------------------------------------===// +/// Returns the state associated with the provided name. +State SDFG::lookup(StringRef name) { return ptr->lookup(name); } + +/// Adds a state to the SDFG. void SDFG::addState(State state) { state.setParent(*this); ptr->addState(state); } -State SDFG::lookup(StringRef name) { return ptr->lookup(name); } +/// Adds a state to the SDFG and marks it as the entry state. void SDFG::setStartState(State state) { ptr->setStartState(state); } + +/// Adds an interstate edge to the SDFG, connecting two states. void SDFG::addEdge(InterstateEdge edge) { ptr->addEdge(edge); } + +/// Adds an array (data container) to the SDFG. void SDFG::addArray(Array array) { ptr->addArray(array); } + +/// Adds an array (data container) to the SDFG and marks it as an argument. void SDFG::addArg(Array arg) { ptr->addArg(arg); } + +/// Adds a symbol to the SDFG. void SDFG::addSymbol(Symbol symbol) { ptr->addSymbol(symbol); } + +/// Returns an array of all symbols in the SDFG. std::vector SDFG::getSymbols() { return ptr->getSymbols(); } + +/// Emits the SDFG to the output stream. void SDFG::emit(emitter::JsonEmitter &jemit) { ptr->emit(jemit); }; + +/// Emits the SDFG as a nested SDFG to the output stream. void SDFG::emitNested(emitter::JsonEmitter &jemit) { ptr->emitNested(jemit); }; +/// Global counter for the ID of SDFGs. unsigned SDFGImpl::list_id = 0; +/// Returns the state associated with the provided name. +State SDFGImpl::lookup(StringRef name) { return lut.find(name.str())->second; } + +/// Adds a state to the SDFG. void SDFGImpl::addState(State state) { state.setID(states.size()); states.push_back(state); @@ -532,6 +644,7 @@ void SDFGImpl::addState(State state) { emitError(location, "Duplicate ID in SDFGImpl::addState"); } +/// Adds a state to the SDFG and marks it as the entry state. void SDFGImpl::setStartState(State state) { if (std::find(states.begin(), states.end(), state) == states.end()) emitError( @@ -541,28 +654,37 @@ void SDFGImpl::setStartState(State state) { this->startState = state; } +/// Adds an interstate edge to the SDFG, connecting two states. void SDFGImpl::addEdge(InterstateEdge edge) { edges.push_back(edge); } -State SDFGImpl::lookup(StringRef name) { return lut.find(name.str())->second; } +/// Adds an array (data container) to the SDFG. void SDFGImpl::addArray(Array array) { arrays.push_back(array); } + +/// Adds an array (data container) to the SDFG and marks it as an argument. void SDFGImpl::addArg(Array arg) { args.push_back(arg); addArray(arg); } +/// Adds a symbol to the SDFG. void SDFGImpl::addSymbol(Symbol symbol) { symbols.push_back(symbol); } + +/// Returns an array of all symbols in the SDFG. std::vector SDFGImpl::getSymbols() { return symbols; } +/// Emits the SDFG to the output stream. void SDFGImpl::emit(emitter::JsonEmitter &jemit) { jemit.startObject(); emitBody(jemit); } +/// Emits the SDFG as a nested SDFG to the output stream. void SDFGImpl::emitNested(emitter::JsonEmitter &jemit) { jemit.startNamedObject("sdfg"); emitBody(jemit); } +/// Emits the body of the SDFG to the output stream. void SDFGImpl::emitBody(emitter::JsonEmitter &jemit) { jemit.printKVPair("type", "SDFG"); jemit.printKVPair("sdfg_list_id", id, /*stringify=*/false); @@ -612,8 +734,10 @@ void SDFGImpl::emitBody(emitter::JsonEmitter &jemit) { // NestedSDFG //===----------------------------------------------------------------------===// +/// Emits the nested SDFG to the output stream. void NestedSDFG::emit(emitter::JsonEmitter &jemit) { ptr->emit(jemit); } +/// Emits the nested SDFG to the output stream. void NestedSDFGImpl::emit(emitter::JsonEmitter &jemit) { jemit.startObject(); jemit.printKVPair("type", "NestedSDFG"); @@ -641,9 +765,15 @@ void NestedSDFGImpl::emit(emitter::JsonEmitter &jemit) { // State //===----------------------------------------------------------------------===// +/// Modified lookup function creates access nodes if the value could not be +/// found. Connector State::lookup(Value value) { return ptr->lookup(value); } + +/// Emits the state node to the output stream. void State::emit(emitter::JsonEmitter &jemit) { ptr->emit(jemit); } +/// Modified lookup function creates access nodes if the value could not be +/// found. Connector StateImpl::lookup(Value value) { if (lut.find(utils::valueToString(value)) == lut.end()) { Access access(location); @@ -667,6 +797,7 @@ Connector StateImpl::lookup(Value value) { return ScopeNodeImpl::lookup(value); } +/// Emits the state node to the output stream. void StateImpl::emit(emitter::JsonEmitter &jemit) { jemit.startObject(); jemit.printKVPair("type", "SDFGState"); @@ -685,11 +816,36 @@ void StateImpl::emit(emitter::JsonEmitter &jemit) { // Tasklet //===----------------------------------------------------------------------===// +/// Sets the code of the tasklet. void Tasklet::setCode(Code code) { ptr->setCode(code); } + +/// Sets the global code of the tasklet. +void Tasklet::setGlobalCode(Code code_global) { + ptr->setGlobalCode(code_global); +} + +/// Sets the side effect flag of the tasklet. +void Tasklet::setHasSideEffect(bool hasSideEffect) { + ptr->setHasSideEffect(hasSideEffect); +} + +/// Emits the tasklet to the output stream. void Tasklet::emit(emitter::JsonEmitter &jemit) { ptr->emit(jemit); } +/// Sets the code of the tasklet. void TaskletImpl::setCode(Code code) { this->code = code; } +/// Sets the global code of the tasklet. +void TaskletImpl::setGlobalCode(Code code_global) { + this->code_global = code_global; +} + +/// Sets the side effect flag of the tasklet. +void TaskletImpl::setHasSideEffect(bool hasSideEffect) { + this->hasSideEffect = hasSideEffect; +} + +/// Emits the tasklet to the output stream. void TaskletImpl::emit(emitter::JsonEmitter &jemit) { jemit.startObject(); jemit.printKVPair("type", "Tasklet"); @@ -705,6 +861,13 @@ void TaskletImpl::emit(emitter::JsonEmitter &jemit) { jemit.printKVPair("language", codeLanguageToString(code.language)); jemit.endObject(); // code + jemit.startNamedObject("code_global"); + jemit.printKVPair("string_data", code_global.data); + jemit.printKVPair("language", codeLanguageToString(code_global.language)); + jemit.endObject(); // code_global + + jemit.printKVPair("side_effects", hasSideEffect ? "true" : "false", + /*stringify=*/false); ConnectorNodeImpl::emit(jemit); jemit.endObject(); // attributes @@ -715,16 +878,20 @@ void TaskletImpl::emit(emitter::JsonEmitter &jemit) { // Library //===----------------------------------------------------------------------===// +/// Sets the library code path. void Library::setClasspath(StringRef classpath) { ptr->setClasspath(classpath); } +/// Emits the library node to the output stream. void Library::emit(emitter::JsonEmitter &jemit) { ptr->emit(jemit); } +/// Sets the library code path. void LibraryImpl::setClasspath(StringRef classpath) { this->classpath = classpath.str(); } +/// Emits the library node to the output stream. void LibraryImpl::emit(emitter::JsonEmitter &jemit) { jemit.startObject(); jemit.printKVPair("type", "LibraryNode"); @@ -745,8 +912,10 @@ void LibraryImpl::emit(emitter::JsonEmitter &jemit) { // Access //===----------------------------------------------------------------------===// +/// Emits the access node to the output stream void Access::emit(emitter::JsonEmitter &jemit) { ptr->emit(jemit); } +/// Emits the access node to the output stream void AccessImpl::emit(emitter::JsonEmitter &jemit) { jemit.startObject(); jemit.printKVPair("type", "AccessNode"); @@ -766,37 +935,63 @@ void AccessImpl::emit(emitter::JsonEmitter &jemit) { // Map //===----------------------------------------------------------------------===// +/// Adds a parameter to the map entry. +void MapEntry::addParam(StringRef param) { ptr->addParam(param); } + +/// Adds a range for a parameter. +void MapEntry::addRange(Range range) { ptr->addRange(range); } + +/// Sets the map exit this map entry belongs to. +void MapEntry::setExit(MapExit exit) { ptr->setExit(exit); } + +/// Returns the matching map exit. +MapExit MapEntry::getExit() { return ptr->getExit(); } + +/// Adds a connector node to the scope. void MapEntry::addNode(ConnectorNode node) { node.setParent(*this); ptr->addNode(node); } +/// Adds a multiedge from the source to the destination connector. void MapEntry::routeWrite(Connector from, Connector to) { ptr->routeWrite(from, to); } +/// Adds an edge to the scope. void MapEntry::addEdge(MultiEdge edge) { ptr->addEdge(edge); } + +/// Maps the MLIR value to the specified connector. void MapEntry::mapConnector(Value value, Connector connector) { ptr->mapConnector(value, connector); } +/// Returns the connector associated with a MLIR value, inserting map +/// connectors when needed. Connector MapEntry::lookup(Value value) { return ptr->lookup(value, *this); } -void MapEntry::addParam(StringRef param) { ptr->addParam(param); } -void MapEntry::addRange(Range range) { ptr->addRange(range); } -void MapEntry::setExit(MapExit exit) { ptr->setExit(exit); } -MapExit MapEntry::getExit() { return ptr->getExit(); } + +/// Emits the map entry to the output stream. void MapEntry::emit(emitter::JsonEmitter &jemit) { ptr->emit(jemit); } +/// Adds a parameter to the map entry. void MapEntryImpl::addParam(StringRef param) { params.push_back(param.str()); } + +/// Adds a range for a parameter. void MapEntryImpl::addRange(Range range) { ranges.push_back(range); } + +/// Sets the map exit this map entry belongs to. void MapEntryImpl::setExit(MapExit exit) { this->exit = exit; } + +/// Returns the matching map exit. MapExit MapEntryImpl::getExit() { return exit; } +/// Adds a connector node to the scope. void MapEntryImpl::addNode(ConnectorNode node) { ScopeNode scope(parent); scope.addNode(node); } +/// Adds a multiedge from the source to the destination connector. void MapEntryImpl::routeWrite(Connector from, Connector to) { MapExit mapExit = getExit(); Connector in(mapExit, "IN_" + std::to_string(mapExit.getInConnectorCount())); @@ -817,11 +1012,13 @@ void MapEntryImpl::routeWrite(Connector from, Connector to) { scope.routeWrite(out, to); } +/// Adds an edge to the scope. void MapEntryImpl::addEdge(MultiEdge edge) { ScopeNode scope(parent); scope.addEdge(edge); } +/// Maps the MLIR value to the specified connector. void MapEntryImpl::mapConnector(Value value, Connector connector) { auto res = lut.insert({utils::valueToString(value), connector}); @@ -829,6 +1026,8 @@ void MapEntryImpl::mapConnector(Value value, Connector connector) { res.first->second = connector; } +/// Returns the connector associated with a MLIR value, inserting map +/// connectors when needed. Connector MapEntryImpl::lookup(Value value, MapEntry mapEntry) { if (lut.find(utils::valueToString(value)) == lut.end()) { ScopeNode scope(parent); @@ -852,6 +1051,7 @@ Connector MapEntryImpl::lookup(Value value, MapEntry mapEntry) { return ScopeNodeImpl::lookup(value); } +/// Emits the map entry to the output stream. void MapEntryImpl::emit(emitter::JsonEmitter &jemit) { jemit.startObject(); jemit.printKVPair("type", "MapEntry"); @@ -878,11 +1078,16 @@ void MapEntryImpl::emit(emitter::JsonEmitter &jemit) { jemit.endObject(); } +/// Sets the map entry this map exit belongs to. void MapExit::setEntry(MapEntry entry) { ptr->setEntry(entry); } + +/// Emits the map exit to the output stream. void MapExit::emit(emitter::JsonEmitter &jemit) { ptr->emit(jemit); } +/// Sets the map entry this map exit belongs to. void MapExitImpl::setEntry(MapEntry entry) { this->entry = entry; } +/// Emits the map exit to the output stream. void MapExitImpl::emit(emitter::JsonEmitter &jemit) { jemit.startObject(); jemit.printKVPair("type", "MapExit"); @@ -903,51 +1108,63 @@ void MapExitImpl::emit(emitter::JsonEmitter &jemit) { // Consume //===----------------------------------------------------------------------===// +/// Sets the consume exit this consume entry belongs to. +void ConsumeEntry::setExit(ConsumeExit exit) { ptr->setExit(exit); } + +/// Returns the matching consume exit. +ConsumeExit ConsumeEntry::getExit() { return ptr->getExit(); } + +/// Adds a connector node to the scope. void ConsumeEntry::addNode(ConnectorNode node) { node.setParent(*this); ptr->addNode(node); } +/// Adds a multiedge from the source to the destination connector. void ConsumeEntry::routeWrite(Connector from, Connector to) { ptr->routeWrite(from, to); } +/// Adds an edge to the scope. void ConsumeEntry::addEdge(MultiEdge edge) { ptr->addEdge(edge); } + +/// Maps the MLIR value to the specified connector. void ConsumeEntry::mapConnector(Value value, Connector connector) { ptr->mapConnector(value, connector); } +/// Returns the connector associated with a MLIR value, inserting consume +/// connectors when needed. Connector ConsumeEntry::lookup(Value value) { return ptr->lookup(value, *this); } +/// Sets the number of processing elements. void ConsumeEntry::setNumPes(StringRef pes) { ptr->setNumPes(pes); } + +/// Sets the name of the processing element index. void ConsumeEntry::setPeIndex(StringRef pe) { ptr->setPeIndex(pe); } + +/// Sets the condition to continue stream consumption. void ConsumeEntry::setCondition(Code condition) { ptr->setCondition(condition); } -void ConsumeEntry::setExit(ConsumeExit exit) { ptr->setExit(exit); } -ConsumeExit ConsumeEntry::getExit() { return ptr->getExit(); } +/// Emits the consume entry to the output stream. void ConsumeEntry::emit(emitter::JsonEmitter &jemit) { ptr->emit(jemit); } +/// Sets the consume exit this consume entry belongs to. void ConsumeEntryImpl::setExit(ConsumeExit exit) { this->exit = exit; } + +/// Returns the matching consume exit. ConsumeExit ConsumeEntryImpl::getExit() { return exit; } +/// Adds a connector node to the scope. void ConsumeEntryImpl::addNode(ConnectorNode node) { getParent().getState().addNode(node); } -void ConsumeEntryImpl::setNumPes(StringRef pes) { num_pes = pes.str(); } -void ConsumeEntryImpl::setPeIndex(StringRef pe) { - pe_index = pe.str(); - utils::sanitizeName(pe_index); -} - -void ConsumeEntryImpl::setCondition(Code condition) { - this->condition = condition; -} - +/// Adds a multiedge from the source to the destination connector. void ConsumeEntryImpl::routeWrite(Connector from, Connector to) { ConsumeExit consumeExit = getExit(); Connector in(consumeExit, @@ -969,10 +1186,12 @@ void ConsumeEntryImpl::routeWrite(Connector from, Connector to) { scope.routeWrite(out, to); } +/// Adds an edge to the scope. void ConsumeEntryImpl::addEdge(MultiEdge edge) { getParent().getState().addEdge(edge); } +/// Maps the MLIR value to the specified connector. void ConsumeEntryImpl::mapConnector(Value value, Connector connector) { auto res = lut.insert({utils::valueToString(value), connector}); @@ -980,6 +1199,8 @@ void ConsumeEntryImpl::mapConnector(Value value, Connector connector) { res.first->second = connector; } +/// Returns the connector associated with a MLIR value, inserting consume +/// connectors when needed. Connector ConsumeEntryImpl::lookup(Value value, ConsumeEntry entry) { if (lut.find(utils::valueToString(value)) == lut.end()) { ScopeNode scope(parent); @@ -1003,6 +1224,21 @@ Connector ConsumeEntryImpl::lookup(Value value, ConsumeEntry entry) { return ScopeNodeImpl::lookup(value); } +/// Sets the number of processing elements. +void ConsumeEntryImpl::setNumPes(StringRef pes) { num_pes = pes.str(); } + +/// Sets the name of the processing element index. +void ConsumeEntryImpl::setPeIndex(StringRef pe) { + pe_index = pe.str(); + utils::sanitizeName(pe_index); +} + +/// Sets the condition to continue stream consumption. +void ConsumeEntryImpl::setCondition(Code condition) { + this->condition = condition; +} + +/// Emits the consume entry to the output stream. void ConsumeEntryImpl::emit(emitter::JsonEmitter &jemit) { jemit.startObject(); jemit.printKVPair("type", "ConsumeEntry"); @@ -1033,11 +1269,16 @@ void ConsumeEntryImpl::emit(emitter::JsonEmitter &jemit) { jemit.endObject(); } +/// Sets the consume entry this consume exit belongs to. void ConsumeExit::setEntry(ConsumeEntry entry) { ptr->setEntry(entry); } + +/// Emits the consume exit to the output stream. void ConsumeExit::emit(emitter::JsonEmitter &jemit) { ptr->emit(jemit); } +/// Sets the consume entry this consume exit belongs to. void ConsumeExitImpl::setEntry(ConsumeEntry entry) { this->entry = entry; } +/// Emits the consume exit to the output stream. void ConsumeExitImpl::emit(emitter::JsonEmitter &jemit) { jemit.startObject(); jemit.printKVPair("type", "ConsumeExit"); diff --git a/lib/SDFG/Translate/liftToPython.cpp b/lib/SDFG/Translate/liftToPython.cpp index c82ab3e28..88b1d0e7c 100644 --- a/lib/SDFG/Translate/liftToPython.cpp +++ b/lib/SDFG/Translate/liftToPython.cpp @@ -1,3 +1,8 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// This file contains a Python lifter, which lifts MLIR operations to Python +/// code. + #include "SDFG/Translate/liftToPython.h" #include "SDFG/Utils/Utils.h" @@ -5,6 +10,8 @@ using namespace mlir; using namespace sdfg; // TODO(later): Temporary auto-lifting. Will be included into DaCe +/// Converts a single operation to a single line of Python code. If successful, +/// returns Python code as s string. Optional liftOperationToPython(Operation &op, Operation &source) { // FIXME: Support multiple return values if (op.getNumResults() > 1) @@ -90,10 +97,11 @@ Optional liftOperationToPython(Operation &op, Operation &source) { sdfg::utils::valueToString(op.getOperand(0), op) + ")"; } - if (arith::MaxFOp maxFOp = dyn_cast(op)) { + if (isa(op) || isa(op) || + isa(op)) { return nameOut + " = max(" + - sdfg::utils::valueToString(maxFOp.getLhs(), op) + ", " + - sdfg::utils::valueToString(maxFOp.getRhs(), op) + ")"; + sdfg::utils::valueToString(op.getOperand(0), op) + ", " + + sdfg::utils::valueToString(op.getOperand(1), op) + ")"; } if (isa(op) || isa(op)) { @@ -342,7 +350,8 @@ Optional liftOperationToPython(Operation &op, Operation &source) { return std::nullopt; } -// If successful returns Python code as string +/// Converts the operations in the first region of op to Python code. If +/// successful, returns Python code as a string. Optional translation::liftToPython(Operation &op) { std::string code = ""; @@ -360,6 +369,7 @@ Optional translation::liftToPython(Operation &op) { return code; } +/// Provides a name for the tasklet. std::string translation::getTaskletName(Operation &op) { Operation &firstOp = *op.getRegion(0).getOps().begin(); return sdfg::utils::operationToString(firstOp); diff --git a/lib/SDFG/Translate/registration.cpp b/lib/SDFG/Translate/registration.cpp index 35eaec6b1..447148250 100644 --- a/lib/SDFG/Translate/registration.cpp +++ b/lib/SDFG/Translate/registration.cpp @@ -1,3 +1,7 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// This file contains the translation pass registration. + #include "SDFG/Translate/Translation.h" #include "mlir/InitAllDialects.h" #include "mlir/Tools/mlir-translate/Translation.h" @@ -6,6 +10,7 @@ // SDFG registration //===----------------------------------------------------------------------===// +/// Registers SDFG to SDFG IR translation. void mlir::sdfg::translation::registerToSDFGTranslation() { mlir::TranslateFromMLIRRegistration registration( "mlir-to-sdfg", "Generates a SDFG JSON", diff --git a/lib/SDFG/Translate/translateToSDFG.cpp b/lib/SDFG/Translate/translateToSDFG.cpp index 02e0842be..b69aadfcd 100644 --- a/lib/SDFG/Translate/translateToSDFG.cpp +++ b/lib/SDFG/Translate/translateToSDFG.cpp @@ -1,3 +1,9 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// This file contains function to translate the SDFG dialect to the SDFG IR. It +/// performs the translation in two passes. First it collects all operations and +/// generates an internal IR, which in the second pass is used to generate JSON. + #include "SDFG/Translate/Node.h" #include "SDFG/Translate/Translation.h" #include "SDFG/Translate/liftToPython.h" @@ -13,8 +19,11 @@ using namespace sdfg; // Helpers //===----------------------------------------------------------------------===// -void insertTransientArray(Location location, translation::Connector connector, - Value value, translation::ScopeNode &scope) { +/// Inserts a transient array connecting to the provided connector and mapping +/// to the provided value. +static void insertTransientArray(Location location, + translation::Connector connector, Value value, + translation::ScopeNode &scope) { using namespace translation; Array array(sdfg::utils::generateName("tmp"), /*transient=*/true, @@ -46,6 +55,7 @@ void insertTransientArray(Location location, translation::Connector connector, scope.mapConnector(value, accOut); } +/// Collects a operation by performing a case distinction on the operation type. LogicalResult collectOperations(Operation &op, translation::ScopeNode &scope) { using namespace translation; @@ -139,6 +149,7 @@ LogicalResult collectOperations(Operation &op, translation::ScopeNode &scope) { return success(); } +/// Collects all operations in a SDFG. LogicalResult collectSDFG(Operation &op, translation::SDFG &sdfg) { using namespace translation; @@ -200,6 +211,8 @@ LogicalResult collectSDFG(Operation &op, translation::SDFG &sdfg) { // Module //===----------------------------------------------------------------------===// +/// Translates a module containing SDFG dialect to SDFG IR, outputs the result +/// to the provided output stream. LogicalResult translation::translateToSDFG(ModuleOp &op, JsonEmitter &jemit) { if (++op.getOps().begin() != op.getOps().end()) { emitError(op.getLoc(), "Must have exactly one top-level SDFGNode"); @@ -220,6 +233,7 @@ LogicalResult translation::translateToSDFG(ModuleOp &op, JsonEmitter &jemit) { // State //===----------------------------------------------------------------------===// +/// Collects state node information in a top-level SDFG. LogicalResult translation::collect(StateNode &op, SDFG &sdfg) { State state(op.getLoc()); state.setName(op.getName()); @@ -235,6 +249,7 @@ LogicalResult translation::collect(StateNode &op, SDFG &sdfg) { // EdgeOp //===----------------------------------------------------------------------===// +/// Collects edge information in a top-level SDFG. LogicalResult translation::collect(EdgeOp &op, SDFG &sdfg) { Operation *sdfgNode = sdfg::utils::getParentSDFG(*op); @@ -286,6 +301,7 @@ LogicalResult translation::collect(EdgeOp &op, SDFG &sdfg) { // AllocOp //===----------------------------------------------------------------------===// +/// Collects array/stream allocation information in a top-level SDFG. LogicalResult translation::collect(AllocOp &op, SDFG &sdfg) { Array array(op.getContainerName(), op.getTransient(), op.isStream(), sdfg::utils::getSizedType(op.getType())); @@ -294,6 +310,7 @@ LogicalResult translation::collect(AllocOp &op, SDFG &sdfg) { return success(); } +/// Collects array/stream allocation information in a scope. LogicalResult translation::collect(AllocOp &op, ScopeNode &scope) { Array array(op.getContainerName(), op.getTransient(), op.isStream(), sdfg::utils::getSizedType(op.getType())); @@ -306,6 +323,7 @@ LogicalResult translation::collect(AllocOp &op, ScopeNode &scope) { // AllocSymbolOp //===----------------------------------------------------------------------===// +/// Collects symbol allocation information in a top-level SDFG. LogicalResult translation::collect(AllocSymbolOp &op, SDFG &sdfg) { // IDEA: Support other types? Symbol sym(op.getSym(), DType::int64); @@ -317,6 +335,7 @@ LogicalResult translation::collect(AllocSymbolOp &op, SDFG &sdfg) { // TaskletNode //===----------------------------------------------------------------------===// +/// Collects tasklet information in a scope. LogicalResult translation::collect(TaskletNode &op, ScopeNode &scope) { Tasklet tasklet(op.getLoc()); tasklet.setName(getTaskletName(*op)); @@ -347,11 +366,30 @@ LogicalResult translation::collect(TaskletNode &op, ScopeNode &scope) { Code code(nameOut + " = " + nameIn + " ** (1. / 3.)", CodeLanguage::Python); tasklet.setCode(code); - } - - if (operation == "exit") { + } else if (operation == "exit") { Code code("sys.exit()", CodeLanguage::Python); tasklet.setCode(code); + } else { + // TODO: Support inputs & outputs + if (tasklet.getOutConnectorCount() > 0) { + emitError(op.getLoc(), "return types not supported"); + return failure(); + } + + if (tasklet.getInConnectorCount() > 0) { + emitError(op.getLoc(), "input types not supported"); + return failure(); + } + + std::string declString = "extern \\\"C\\\" void " + operation + "();\\n"; + std::string codeString = operation + "();"; + + Code code_global(declString, CodeLanguage::CPP); + Code code(codeString, CodeLanguage::CPP); + + tasklet.setGlobalCode(code_global); + tasklet.setCode(code); + tasklet.setHasSideEffect(true); } } else { @@ -371,6 +409,7 @@ LogicalResult translation::collect(TaskletNode &op, ScopeNode &scope) { // LibCallOp //===----------------------------------------------------------------------===// +/// Collects library call information in a scope. LogicalResult translation::collect(LibCallOp &op, ScopeNode &scope) { Library lib(op.getLoc()); lib.setName(sdfg::utils::generateName(op.getCallee().str())); @@ -399,6 +438,7 @@ LogicalResult translation::collect(LibCallOp &op, ScopeNode &scope) { // NestedSDFGNode //===----------------------------------------------------------------------===// +/// Collects nested SDFG node information in a scope. LogicalResult translation::collect(NestedSDFGNode &op, ScopeNode &scope) { SDFG sdfg(op.getLoc()); @@ -450,6 +490,7 @@ LogicalResult translation::collect(NestedSDFGNode &op, ScopeNode &scope) { // MapNode //===----------------------------------------------------------------------===// +/// Collects map node information in a scope. LogicalResult translation::collect(MapNode &op, ScopeNode &scope) { MapEntry mapEntry(op.getLoc()); mapEntry.setName(sdfg::utils::generateName("mapEntry")); @@ -540,6 +581,7 @@ LogicalResult translation::collect(MapNode &op, ScopeNode &scope) { // ConsumeNode //===----------------------------------------------------------------------===// +/// Collects consume node information in a scope. LogicalResult translation::collect(ConsumeNode &op, ScopeNode &scope) { ConsumeEntry consumeEntry(op.getLoc()); consumeEntry.setName(sdfg::utils::generateName("consumeEntry")); @@ -612,6 +654,7 @@ LogicalResult translation::collect(ConsumeNode &op, ScopeNode &scope) { // CopyOp //===----------------------------------------------------------------------===// +/// Collects copy operation information in a scope. LogicalResult translation::collect(CopyOp &op, ScopeNode &scope) { Access access(op.getLoc()); @@ -639,6 +682,7 @@ LogicalResult translation::collect(CopyOp &op, ScopeNode &scope) { // StoreOp //===----------------------------------------------------------------------===// +/// Collects store operation information in a scope. LogicalResult translation::collect(StoreOp &op, ScopeNode &scope) { std::string name = sdfg::utils::valueToString(op.getArr()); @@ -774,6 +818,7 @@ LogicalResult translation::collect(StoreOp &op, ScopeNode &scope) { // LoadOp //===----------------------------------------------------------------------===// +/// Collects load operation information in a scope. LogicalResult translation::collect(LoadOp &op, ScopeNode &scope) { // TODO: Implement a dce pass if (op.use_empty()) @@ -907,6 +952,7 @@ LogicalResult translation::collect(LoadOp &op, ScopeNode &scope) { // AllocSymbolOp //===----------------------------------------------------------------------===// +/// Collects symbol allocation information in a scope. LogicalResult translation::collect(AllocSymbolOp &op, ScopeNode &scope) { // IDEA: Support other types? Symbol sym(op.getSym(), DType::int64); @@ -919,6 +965,7 @@ LogicalResult translation::collect(AllocSymbolOp &op, ScopeNode &scope) { // SymOp //===----------------------------------------------------------------------===// +/// Collects symbolic expression information in a scope. LogicalResult translation::collect(SymOp &op, ScopeNode &scope) { Tasklet task(op.getLoc()); task.setName("SYM_" + op.getExpr().str()); @@ -939,6 +986,7 @@ LogicalResult translation::collect(SymOp &op, ScopeNode &scope) { // StreamPushOp //===----------------------------------------------------------------------===// +/// Collects stream push operation information in a scope. LogicalResult translation::collect(StreamPushOp &op, ScopeNode &scope) { Access access(op.getLoc()); std::string name = sdfg::utils::valueToString(op.getStr()); @@ -965,6 +1013,7 @@ LogicalResult translation::collect(StreamPushOp &op, ScopeNode &scope) { // StreamPopOp //===----------------------------------------------------------------------===// +/// Collects stream pop operation information in a scope. LogicalResult translation::collect(StreamPopOp &op, ScopeNode &scope) { Connector connector = scope.lookup(op.getStr()); scope.mapConnector(op.getRes(), connector); diff --git a/lib/SDFG/Utils/AttributeToString.cpp b/lib/SDFG/Utils/AttributeToString.cpp index dfcb49278..d13b14995 100644 --- a/lib/SDFG/Utils/AttributeToString.cpp +++ b/lib/SDFG/Utils/AttributeToString.cpp @@ -1,8 +1,13 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// This file contains the attribute to string utility functions. + #include "SDFG/Utils/AttributeToString.h" #include "SDFG/Utils/Utils.h" namespace mlir::sdfg::utils { +/// Prints an attribute to a string. std::string attributeToString(Attribute attribute, Operation &op) { std::string name; llvm::raw_string_ostream nameStream(name); diff --git a/lib/SDFG/Utils/CMakeLists.txt b/lib/SDFG/Utils/CMakeLists.txt index 939ebc03c..204937c59 100644 --- a/lib/SDFG/Utils/CMakeLists.txt +++ b/lib/SDFG/Utils/CMakeLists.txt @@ -1,3 +1,5 @@ +# Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + add_library( SDFG_UTILS Sanitizer.cpp diff --git a/lib/SDFG/Utils/GetParents.cpp b/lib/SDFG/Utils/GetParents.cpp index fe0884f6f..638edba3b 100644 --- a/lib/SDFG/Utils/GetParents.cpp +++ b/lib/SDFG/Utils/GetParents.cpp @@ -1,9 +1,13 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// This file contains the parent utility functions. + #include "SDFG/Utils/GetParents.h" namespace mlir::sdfg::utils { -// Returns the parent SDFG node, NestedSDFG node or nullptr if a parent does not -// exist +/// Returns the parent SDFG node, NestedSDFG node or nullptr if a parent does +/// not exist. Operation *getParentSDFG(Operation &op) { Operation *parent = op.getParentOp(); @@ -20,7 +24,7 @@ Operation *getParentSDFG(Operation &op) { return nullptr; } -// Returns the parent State node or nullptr if a parent does not exist +/// Returns the parent State node or nullptr if a parent does not exist. StateNode getParentState(Operation &op, bool ignoreSDFGs) { Operation *parent = op.getParentOp(); @@ -37,4 +41,17 @@ StateNode getParentState(Operation &op, bool ignoreSDFGs) { return nullptr; } +/// Returns top-level module operation or nullptr if a parent does not exist. +ModuleOp getTopModuleOp(Operation *op) { + Operation *parent = op->getParentOp(); + + if (parent == nullptr) + return nullptr; + + if (isa(parent)) + return cast(parent); + + return getTopModuleOp(parent); +} + } // namespace mlir::sdfg::utils diff --git a/lib/SDFG/Utils/GetSizedType.cpp b/lib/SDFG/Utils/GetSizedType.cpp index 07b4f805d..8868af0d8 100644 --- a/lib/SDFG/Utils/GetSizedType.cpp +++ b/lib/SDFG/Utils/GetSizedType.cpp @@ -1,7 +1,12 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// This file contains the sized type utility functions. + #include "SDFG/Utils/GetSizedType.h" namespace mlir::sdfg::utils { +/// Extracts the sized type from an array or stream type. SizedType getSizedType(Type t) { if (ArrayType arr = t.dyn_cast()) return arr.getDimensions(); @@ -9,11 +14,7 @@ SizedType getSizedType(Type t) { return t.cast().getDimensions(); } -bool isSizedType(Type t) { - if (t.isa() || t.isa()) - return true; - - return false; -} +/// Returns true if the provided type is a sized type. +bool isSizedType(Type t) { return t.isa() || t.isa(); } } // namespace mlir::sdfg::utils diff --git a/lib/SDFG/Utils/IDGenerator.cpp b/lib/SDFG/Utils/IDGenerator.cpp index 327bd2936..f0752cf17 100644 --- a/lib/SDFG/Utils/IDGenerator.cpp +++ b/lib/SDFG/Utils/IDGenerator.cpp @@ -1,3 +1,7 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// This file contains the ID generator utility functions. + #include "SDFG/Utils/IDGenerator.h" namespace mlir::sdfg::utils { @@ -5,8 +9,10 @@ namespace { unsigned idGeneratorID = 0; } +/// Returns a globally unique ID. unsigned generateID() { return idGeneratorID++; } +/// Resets the ID generator. void resetIDGenerator() { idGeneratorID = 0; } } // namespace mlir::sdfg::utils diff --git a/lib/SDFG/Utils/NameGenerator.cpp b/lib/SDFG/Utils/NameGenerator.cpp index e894a57dd..ff2e7e3af 100644 --- a/lib/SDFG/Utils/NameGenerator.cpp +++ b/lib/SDFG/Utils/NameGenerator.cpp @@ -1,3 +1,7 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// This file contains the name generator utility functions. + #include "SDFG/Utils/NameGenerator.h" namespace mlir::sdfg::utils { @@ -5,6 +9,7 @@ namespace { int nameGeneratorID = 0; } +/// Converts the provided string to a globally unique one. std::string generateName(std::string base) { return base + "_" + std::to_string(nameGeneratorID++); } diff --git a/lib/SDFG/Utils/OperationToString.cpp b/lib/SDFG/Utils/OperationToString.cpp index 1dc353bee..c6c7c4220 100644 --- a/lib/SDFG/Utils/OperationToString.cpp +++ b/lib/SDFG/Utils/OperationToString.cpp @@ -1,9 +1,14 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// This file contains the operation to string utility functions. + #include "SDFG/Utils/OperationToString.h" #include "SDFG/Utils/Utils.h" #include "mlir/IR/AsmState.h" namespace mlir::sdfg::utils { +/// Prints an operation to a string. std::string operationToString(Operation &op) { std::string name = op.getName().stripDialect().str(); utils::sanitizeName(name); diff --git a/lib/SDFG/Utils/Sanitizer.cpp b/lib/SDFG/Utils/Sanitizer.cpp index 8eb50036c..2e55ccbfc 100644 --- a/lib/SDFG/Utils/Sanitizer.cpp +++ b/lib/SDFG/Utils/Sanitizer.cpp @@ -1,8 +1,14 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// This file contains the sanitizer utility functions. + #include "SDFG/Utils/Sanitizer.h" using namespace mlir; using namespace sdfg; +/// Sanitizes the provided string to only include alphanumericals and +/// underscores. void utils::sanitizeName(std::string &name) { for (unsigned i = 0; i < name.size(); ++i) { if (!(name[i] >= 'a' && name[i] <= 'z') && diff --git a/lib/SDFG/Utils/ValueToString.cpp b/lib/SDFG/Utils/ValueToString.cpp index 7337b07a2..354a4a2c7 100644 --- a/lib/SDFG/Utils/ValueToString.cpp +++ b/lib/SDFG/Utils/ValueToString.cpp @@ -1,9 +1,14 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// This file contains the value to string utility functions. + #include "SDFG/Utils/ValueToString.h" #include "SDFG/Utils/Utils.h" #include "mlir/IR/AsmState.h" namespace mlir::sdfg::utils { +/// Prints a value to a string. Optionally takes a context operation. std::string valueToString(Value value) { if (value.getDefiningOp() != nullptr) return valueToString(value, *value.getDefiningOp()); @@ -11,6 +16,7 @@ std::string valueToString(Value value) { return valueToString(value, *value.getParentBlock()->getParentOp()); } +/// Prints a value to a string. Optionally takes a context operation. std::string valueToString(Value value, Operation &op) { Operation *sdfg; diff --git a/sdfg-opt/CMakeLists.txt b/sdfg-opt/CMakeLists.txt index 15fbf5378..e7920dbd1 100644 --- a/sdfg-opt/CMakeLists.txt +++ b/sdfg-opt/CMakeLists.txt @@ -1,8 +1,16 @@ +# Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) -set(LIBS ${dialect_libs} ${conversion_libs} MLIROptLib MLIR_SDFG GenericToSDFG - LinalgToSDFG) +set(LIBS + ${dialect_libs} + ${conversion_libs} + MLIROptLib + MLIR_SDFG + GenericToSDFG + LinalgToSDFG + SDFGToGeneric) add_llvm_executable(sdfg-opt sdfg-opt.cpp) llvm_update_compile_flags(sdfg-opt) diff --git a/sdfg-opt/sdfg-opt.cpp b/sdfg-opt/sdfg-opt.cpp index bad65648f..968fdb1c8 100644 --- a/sdfg-opt/sdfg-opt.cpp +++ b/sdfg-opt/sdfg-opt.cpp @@ -1,3 +1,7 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// This file contains the SDFG optimizer with the conversion passes. + #include "mlir/InitAllDialects.h" #include "mlir/InitAllPasses.h" #include "mlir/Pass/Pass.h" @@ -7,12 +11,14 @@ #include "SDFG/Conversion/GenericToSDFG/Passes.h" #include "SDFG/Conversion/LinalgToSDFG/Passes.h" +#include "SDFG/Conversion/SDFGToGeneric/Passes.h" #include "SDFG/Dialect/Dialect.h" int main(int argc, char **argv) { // Register SDFG passes mlir::sdfg::conversion::registerGenericToSDFGPasses(); mlir::sdfg::conversion::registerLinalgToSDFGPasses(); + mlir::sdfg::conversion::registerSDFGToGenericPasses(); mlir::DialectRegistry registry; registry.insert(); diff --git a/sdfg-translate/CMakeLists.txt b/sdfg-translate/CMakeLists.txt index 3485bd5d1..349d623dd 100644 --- a/sdfg-translate/CMakeLists.txt +++ b/sdfg-translate/CMakeLists.txt @@ -1,3 +1,5 @@ +# Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + set(LLVM_LINK_COMPONENTS Support) get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) diff --git a/sdfg-translate/sdfg-translate.cpp b/sdfg-translate/sdfg-translate.cpp index 2ba8d38dc..8f2181dd1 100644 --- a/sdfg-translate/sdfg-translate.cpp +++ b/sdfg-translate/sdfg-translate.cpp @@ -1,3 +1,7 @@ +// Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + +/// This file contains the SDFG translator with the translation passes. + #include "mlir/InitAllTranslations.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Tools/mlir-translate/MlirTranslateMain.h" diff --git a/sdfg_gen/gen_all.sh b/sdfg_gen/gen_all.sh deleted file mode 100755 index 99833a02d..000000000 --- a/sdfg_gen/gen_all.sh +++ /dev/null @@ -1,5 +0,0 @@ -#!/bin/bash -cd "$(dirname "$0")" || exit -rm -rf gen/json/* -rm -rf gen/sdfg/* -for f in python/*.py; do python3 "$f"; done diff --git a/sdfg_gen/python/design.py b/sdfg_gen/python/design.py deleted file mode 100644 index 780c83a01..000000000 --- a/sdfg_gen/python/design.py +++ /dev/null @@ -1,219 +0,0 @@ -# This file was used to create the images for the design of SDFG - -import dace -from dace import dtypes -from utils import export_sdfg -import numpy as np - -# State only -sdfg = dace.SDFG("design") -state = sdfg.add_state() -export_sdfg(sdfg, "state") - -# Simple movement -A = state.add_read('A') -C = state.add_write('C') -e1 = state.add_edge(A, None, C, None, dace.Memlet('A[0]')) - -export_sdfg(sdfg, "simple_movement") - -# Tasklet -state.remove_node(A) -state.remove_node(C) - -tasklet = state.add_tasklet(name='add', - inputs={'a', 'b'}, - outputs={'c', 'd'}, - code='c = a + b\nd = a - b', - language=dace.Language.Python) - -export_sdfg(sdfg, "tasklet") - -# Parallel Memlets -state.remove_node(tasklet) - -A = state.add_read('A') -B = state.add_read('B') - -C = state.add_write('C') -D = state.add_write('D') - -e1 = state.add_edge(A, None, C, None, dace.Memlet('A[0]')) -e2 = state.add_edge(B, None, D, None, dace.Memlet('B[0]')) - -export_sdfg(sdfg, "parallel") - -# Memlets -state.remove_edge(e1) -state.remove_edge(e2) -state.remove_node(D) -export_sdfg(sdfg, "memlets") - -# Full Graph -tasklet = state.add_tasklet(name='add', - inputs={'a', 'b'}, - outputs={'c'}, - code='c = a + b', - language=dace.Language.Python) - -state.add_edge(A, None, tasklet, 'a', dace.Memlet('A[0]')) -state.add_edge(B, None, tasklet, 'b', dace.Memlet('B[0]')) -state.add_edge(tasklet, 'c', C, None, dace.Memlet('C[0]')) - -export_sdfg(sdfg, "full") - -# Self-write -sdfg = dace.SDFG("design") -state = sdfg.add_state() - -A = state.add_read('A') -A2 = state.add_write('A') - -tasklet = state.add_tasklet(name='add', - inputs={'a'}, - outputs={'a_1'}, - code='a_1 = a + 1', - language=dace.Language.Python) - -state.add_edge(A, None, tasklet, 'a', dace.Memlet('A[0]')) -state.add_edge(tasklet, 'a_1', A2, None, dace.Memlet('A[0]')) -export_sdfg(sdfg, "self-write") - -# data race -sdfg = dace.SDFG("design") -state = sdfg.add_state() - -A = state.add_read('A') -A2 = state.add_write('A') -B = state.add_read('B') -C = state.add_write('C') - -tasklet = state.add_tasklet(name='add', - inputs={'a'}, - outputs={'c'}, - code='c = a + 1', - language=dace.Language.Python) - -tasklet2 = state.add_tasklet(name='add', - inputs={'b'}, - outputs={'a'}, - code='a = b + 1', - language=dace.Language.Python) - -state.add_edge(A, None, tasklet, 'a', dace.Memlet('A[0]')) -state.add_edge(tasklet, 'c', C, None, dace.Memlet('C[0]')) -state.add_edge(B, None, tasklet2, 'b', dace.Memlet('B[0]')) -state.add_edge(tasklet2, 'a', A2, None, dace.Memlet('A[0]')) -export_sdfg(sdfg, "data_race") - -# Map -sdfg = dace.SDFG("design") -state = sdfg.add_state() -A = state.add_read('A') -B = state.add_read('B') -C = state.add_write('C') - -tasklet, map_entry, map_exit = state.add_mapped_tasklet( - name='add', - map_ranges=dict(i='0:2', j='0:2'), - inputs=dict(a=dace.Memlet('A[i, j]'), b=dace.Memlet('B[i, j]')), - code='c = a + b', - outputs=dict(c=dace.Memlet('C[i, j]')) -) - -state.add_edge(A, None, map_entry, None, memlet=dace.Memlet('A[0:2,0:2]')) -state.add_edge(B, None, map_entry, None, memlet=dace.Memlet('B[0:2,0:2]')) -state.add_edge(map_exit, None, C, None, memlet=dace.Memlet('C[0:2,0:2]')) - -sdfg.fill_scope_connectors() -export_sdfg(sdfg, "map") - -# nested & lib -@dace.program -def design_mmm(A, B): - return A @ B - -@dace.program -def design_nested(A, B): - C = design_mmm(A, B) - -a = np.random.rand(2, 2) -sdfg = design_nested.to_sdfg(a, a) -export_sdfg(sdfg) - -sdfg = design_mmm.to_sdfg(a, a) -export_sdfg(sdfg) - -# Symbols -sdfg = dace.SDFG("design") -state = sdfg.add_state() -A = state.add_read('A') -C = state.add_write('C') - -tasklet, map_entry, map_exit = state.add_mapped_tasklet( - name='add', - map_ranges=dict(i='0:N'), - inputs=dict(a=dace.Memlet('A[i]')), - code='c = a + 1', - outputs=dict(c=dace.Memlet('C[i]')) -) - -state.add_edge(A, None, map_entry, None, memlet=dace.Memlet('A[0:N]')) -state.add_edge(map_exit, None, C, None, memlet=dace.Memlet('C[0:N]')) - -sdfg.fill_scope_connectors() -export_sdfg(sdfg, "sym") - -# Multistate -sdfg = dace.SDFG("design") -state = sdfg.add_state(is_start_state=True) -state2 = sdfg.add_state_after(state) -state3 = sdfg.add_state_after(state2) -state4 = sdfg.add_state() - - -sdfg.add_edge(state2, state4, dace.InterstateEdge()) - -export_sdfg(sdfg, "multistate") - -# Streams -sdfg = dace.SDFG("design") -state = sdfg.add_state() -A = state.add_stream('A', dtypes.int32) -C = state.add_write('C') - -consEntry, consExit = state.add_consume("add_one", ("p","P"), "len(A) == 0") - -tasklet = state.add_tasklet(name='add', - inputs={'a'}, - outputs={'c'}, - code='c = a + 1', - language=dace.Language.Python) - -state.add_edge(A, None, consEntry, None, memlet=dace.Memlet('A[0:2]')) -state.add_edge(consEntry, None, tasklet, None, memlet=dace.Memlet('A[p]')) -state.add_edge(tasklet, None, consExit, None, memlet=dace.Memlet('C[p]')) -state.add_edge(consExit, None, C, None, memlet=dace.Memlet('C[0:2]')) - -export_sdfg(sdfg, "stream") - -@dace.program -def design_ex(in1: dtypes.vector(dtypes.int32,5)): - sum = 0 - - for i in range(5): - sum = sum + in1[i] - - return sum - -sdfg = design_ex.to_sdfg() -export_sdfg(sdfg) - -''' -@dace.program -def design_fail(in1: dtypes.vector(dtypes.int32,5)): - return in1[0] + in1[1] - -sdfg = design_fail.to_sdfg() -export_sdfg(sdfg) -''' diff --git a/sdfg_gen/python/nested.py b/sdfg_gen/python/nested.py deleted file mode 100644 index 2ec1e1ab0..000000000 --- a/sdfg_gen/python/nested.py +++ /dev/null @@ -1,26 +0,0 @@ -# Taken and modified from: https://github.com/spcl/dace/blob/master/tests/nest_subgraph_test.py - -import dace -from dace.sdfg.nodes import MapEntry, Tasklet -from dace.sdfg.graph import NodeNotFoundError, SubgraphView -from dace.transformation.helpers import nest_state_subgraph -from dace.transformation.dataflow import tiling - -def create_sdfg(): - sdfg = dace.SDFG('badscope_test') - sdfg.add_array('A', [2], dace.float32) - sdfg.add_array('B', [2], dace.float32) - state = sdfg.add_state() - t, me, mx = state.add_mapped_tasklet('map', - dict(i='0:2'), - dict(a=dace.Memlet.simple('A', 'i')), - 'b = a * 2', - dict(b=dace.Memlet.simple('B', 'i')), - external_edges=True) - return sdfg, state, t, me, mx - -sdfg, state, t, me, mx = create_sdfg() -nest_state_subgraph(sdfg, state, SubgraphView(state, [t])) - -from utils import export_sdfg -export_sdfg(sdfg) diff --git a/sdfg_gen/python/single_empty_state.py b/sdfg_gen/python/single_empty_state.py deleted file mode 100644 index 3b5060fb1..000000000 --- a/sdfg_gen/python/single_empty_state.py +++ /dev/null @@ -1,7 +0,0 @@ -from dace import SDFG - -sdfg = SDFG("single_empty_state") -state = sdfg.add_state() - -from utils import export_sdfg -export_sdfg(sdfg) \ No newline at end of file diff --git a/sdfg_gen/python/utils.py b/sdfg_gen/python/utils.py deleted file mode 100644 index 253494405..000000000 --- a/sdfg_gen/python/utils.py +++ /dev/null @@ -1,9 +0,0 @@ -def export_sdfg(sdfg, name=None): - if name is None: name = sdfg.name - else: name = sdfg.name + "_" + name - - out_file_json = "gen/json/" + name + ".json" - out_file_sdfg = "gen/sdfg/" + name + ".sdfg" - - sdfg.save(filename=out_file_json) - sdfg.save(filename=out_file_sdfg) diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index c4feb2a26..a26945b68 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -1 +1,3 @@ +# Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + add_subdirectory(SDFG) diff --git a/test/SDFG/CMakeLists.txt b/test/SDFG/CMakeLists.txt index 7ec91009c..cf6e01f2c 100644 --- a/test/SDFG/CMakeLists.txt +++ b/test/SDFG/CMakeLists.txt @@ -1,3 +1,5 @@ +# Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + add_subdirectory(Dialect) add_subdirectory(Translate) add_subdirectory(Converter) diff --git a/test/SDFG/Converter/CMakeLists.txt b/test/SDFG/Converter/CMakeLists.txt index 2ca8590a6..39ec50590 100644 --- a/test/SDFG/Converter/CMakeLists.txt +++ b/test/SDFG/Converter/CMakeLists.txt @@ -1,17 +1,15 @@ +# Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + configure_lit_site_cfg( - ${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in - ${CMAKE_CURRENT_BINARY_DIR}/lit.site.cfg.py - MAIN_CONFIG - ${CMAKE_CURRENT_SOURCE_DIR}/lit.cfg.py -) + ${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in + ${CMAKE_CURRENT_BINARY_DIR}/lit.site.cfg.py MAIN_CONFIG + ${CMAKE_CURRENT_SOURCE_DIR}/lit.cfg.py) set(SDFG_TEST_DEPENDS FileCheck count not sdfg-opt) -add_lit_testsuite( - check-sdfg-converter "Running the sdfg conversion tests" - ${CMAKE_CURRENT_BINARY_DIR} - DEPENDS ${SDFG_TEST_DEPENDS} -) +add_lit_testsuite(check-sdfg-converter "Running the sdfg conversion tests" + ${CMAKE_CURRENT_BINARY_DIR} DEPENDS ${SDFG_TEST_DEPENDS}) set_target_properties(check-sdfg-converter PROPERTIES FOLDER "Tests") -add_lit_testsuites(SDFG-CONVERTER ${CMAKE_CURRENT_SOURCE_DIR} DEPENDS ${SDFG_TEST_DEPENDS}) +add_lit_testsuites(SDFG-CONVERTER ${CMAKE_CURRENT_SOURCE_DIR} DEPENDS + ${SDFG_TEST_DEPENDS}) diff --git a/test/SDFG/Converter/converter.mlir b/test/SDFG/Converter/converter.mlir index b371ba7c2..f61662b31 100644 --- a/test/SDFG/Converter/converter.mlir +++ b/test/SDFG/Converter/converter.mlir @@ -1,3 +1,4 @@ // RUN: sdfg-opt --help | FileCheck %s // CHECK: --convert-to-sdfg // CHECK: --linalg-to-sdfg +// CHECK: --lower-sdfg diff --git a/test/SDFG/Converter/fromSDFG/map/add.mlir b/test/SDFG/Converter/fromSDFG/map/add.mlir new file mode 100644 index 000000000..0e92ad0a5 --- /dev/null +++ b/test/SDFG/Converter/fromSDFG/map/add.mlir @@ -0,0 +1,21 @@ +// RUN: sdfg-opt --lower-sdfg %s + +sdfg.sdfg () -> (%r: !sdfg.array) { + %A = sdfg.alloc() : !sdfg.array<2x6xi32> + %B = sdfg.alloc() : !sdfg.array<2x6xi32> + %C = sdfg.alloc() : !sdfg.array<2x6xi32> + + sdfg.state @state_0 { + sdfg.map (%i, %j) = (0, 0) to (2, 2) step (1, 1) { + %a_ij = sdfg.load %A[%i, %j] : !sdfg.array<2x6xi32> -> i32 + %b_ij = sdfg.load %B[%i, %j] : !sdfg.array<2x6xi32> -> i32 + + %res = sdfg.tasklet(%a_ij: i32, %b_ij: i32) -> (i32) { + %z = arith.addi %a_ij, %b_ij : i32 + sdfg.return %z : i32 + } + + sdfg.store %res, %C[%i, %j] : i32 -> !sdfg.array<2x6xi32> + } + } +} diff --git a/test/SDFG/Converter/fromSDFG/map/add_3dim.mlir b/test/SDFG/Converter/fromSDFG/map/add_3dim.mlir new file mode 100644 index 000000000..3628711b8 --- /dev/null +++ b/test/SDFG/Converter/fromSDFG/map/add_3dim.mlir @@ -0,0 +1,21 @@ +// RUN: sdfg-opt --lower-sdfg %s + +sdfg.sdfg () -> (%r: !sdfg.array) { + %A = sdfg.alloc() : !sdfg.array<2x6x8xi32> + %B = sdfg.alloc() : !sdfg.array<2x6x8xi32> + %C = sdfg.alloc() : !sdfg.array<2x6x8xi32> + + sdfg.state @state_0 { + sdfg.map (%i, %j, %g) = (0, 0, 0) to (2, 2, 2) step (1, 1, 1) { + %a_ijg = sdfg.load %A[%i, %j, %g] : !sdfg.array<2x6x8xi32> -> i32 + %b_ijg = sdfg.load %B[%i, %j, %g] : !sdfg.array<2x6x8xi32> -> i32 + + %res = sdfg.tasklet(%a_ijg: i32, %b_ijg: i32) -> (i32) { + %z = arith.addi %a_ijg, %b_ijg : i32 + sdfg.return %z : i32 + } + + sdfg.store %res, %C[%i, %j, %g] : i32 -> !sdfg.array<2x6x8xi32> + } + } +} diff --git a/test/SDFG/Converter/fromSDFG/map/add_nested.mlir b/test/SDFG/Converter/fromSDFG/map/add_nested.mlir new file mode 100644 index 000000000..80b927ec3 --- /dev/null +++ b/test/SDFG/Converter/fromSDFG/map/add_nested.mlir @@ -0,0 +1,23 @@ +// RUN: sdfg-opt --lower-sdfg %s + +sdfg.sdfg () -> (%r: !sdfg.array) { + %A = sdfg.alloc() : !sdfg.array<2x6xi32> + %B = sdfg.alloc() : !sdfg.array<2x6xi32> + %C = sdfg.alloc() : !sdfg.array<2x6xi32> + + sdfg.state @state_0 { + sdfg.map (%i) = (0) to (2) step (1) { + sdfg.map (%j) = (0) to (2) step (1) { + %a_ij = sdfg.load %A[%i, %j] : !sdfg.array<2x6xi32> -> i32 + %b_ij = sdfg.load %B[%i, %j] : !sdfg.array<2x6xi32> -> i32 + + %res = sdfg.tasklet(%a_ij: i32, %b_ij: i32) -> (i32) { + %z = arith.addi %a_ij, %b_ij : i32 + sdfg.return %z : i32 + } + + sdfg.store %res, %C[%i, %j] : i32 -> !sdfg.array<2x6xi32> + } + } + } +} diff --git a/test/SDFG/Converter/fromSDFG/map/consume_map.mlir b/test/SDFG/Converter/fromSDFG/map/consume_map.mlir new file mode 100644 index 000000000..f53e1fdc6 --- /dev/null +++ b/test/SDFG/Converter/fromSDFG/map/consume_map.mlir @@ -0,0 +1,19 @@ +// XFAIL: * +// RUN: sdfg-opt --lower-sdfg %s + +sdfg.sdfg () -> () { + %A = sdfg.alloc() : !sdfg.stream + %C = sdfg.alloc() : !sdfg.array + %B = sdfg.alloc() : !sdfg.array + + sdfg.state @state_0 { + sdfg.consume{num_pes=5} (%A : !sdfg.stream) -> (pe: %p, elem: %e) { + sdfg.map (%i) = (0) to (2) step (1) { + %res = sdfg.tasklet(%e: i32) -> (i32) { + sdfg.return %e : i32 + } + sdfg.store{wcr="add"} %res, %C[] : i32 -> !sdfg.array + } + } + } +} diff --git a/test/SDFG/Converter/fromSDFG/map/dynamic.mlir b/test/SDFG/Converter/fromSDFG/map/dynamic.mlir new file mode 100644 index 000000000..1ba6a1246 --- /dev/null +++ b/test/SDFG/Converter/fromSDFG/map/dynamic.mlir @@ -0,0 +1,23 @@ +// RUN: sdfg-opt --lower-sdfg %s + +sdfg.sdfg () -> (%r: !sdfg.array) { + %A = sdfg.alloc() : !sdfg.array<2x6xi32> + %B = sdfg.alloc() : !sdfg.array<2x6xi32> + %C = sdfg.alloc() : !sdfg.array<2x6xi32> + + sdfg.state @state_0 { + %c = sdfg.load %r[] : !sdfg.array -> index + + sdfg.map (%i, %j) = (%c, %c) to (%c, %c) step (1, 1) { + %a_ij = sdfg.load %A[0, 0] : !sdfg.array<2x6xi32> -> i32 + %b_ij = sdfg.load %B[0, 0] : !sdfg.array<2x6xi32> -> i32 + + %res = sdfg.tasklet(%a_ij: i32, %b_ij: i32) -> (i32) { + %z = arith.addi %a_ij, %b_ij : i32 + sdfg.return %z : i32 + } + + sdfg.store %res, %C[0, 0] : i32 -> !sdfg.array<2x6xi32> + } + } +} diff --git a/test/SDFG/Converter/fromSDFG/map/empty.mlir b/test/SDFG/Converter/fromSDFG/map/empty.mlir new file mode 100644 index 000000000..675a77232 --- /dev/null +++ b/test/SDFG/Converter/fromSDFG/map/empty.mlir @@ -0,0 +1,10 @@ +// RUN: sdfg-opt --lower-sdfg %s + +sdfg.sdfg () -> (%r: !sdfg.array) { + %A = sdfg.alloc() : !sdfg.array<2x6xi32> + + sdfg.state @state_0 { + sdfg.map (%i, %j) = (0, 0) to (2, 2) step (1, 1) { + } + } +} diff --git a/test/SDFG/Converter/fromSDFG/map/using_index_tasklet.mlir b/test/SDFG/Converter/fromSDFG/map/using_index_tasklet.mlir new file mode 100644 index 000000000..52d1f7d70 --- /dev/null +++ b/test/SDFG/Converter/fromSDFG/map/using_index_tasklet.mlir @@ -0,0 +1,18 @@ +// RUN: sdfg-opt --lower-sdfg %s + +sdfg.sdfg () -> () { + %C = sdfg.alloc() : !sdfg.array<2xi32> + + sdfg.state @state_0 { + sdfg.map (%i) = (0) to (2) step (1) { + %a = sdfg.load %C[%i] : !sdfg.array<2xi32> -> i32 + + %res = sdfg.tasklet(%a: i32) -> (index) { + %0 = arith.index_cast %a : i32 to index + sdfg.return %0 : index + } + + sdfg.store{wcr="add"} %a, %C[%res] : i32 -> !sdfg.array<2xi32> + } + } +} diff --git a/test/SDFG/Converter/fromSDFG/memlet/alloc.mlir b/test/SDFG/Converter/fromSDFG/memlet/alloc.mlir new file mode 100644 index 000000000..5cce88b71 --- /dev/null +++ b/test/SDFG/Converter/fromSDFG/memlet/alloc.mlir @@ -0,0 +1,8 @@ +// RUN: sdfg-opt --lower-sdfg %s + +sdfg.sdfg () -> (%r: !sdfg.array) { + %a = sdfg.alloc() : !sdfg.array + + sdfg.state @state_0{ + } +} diff --git a/test/SDFG/Converter/fromSDFG/memlet/alloc_in_state.mlir b/test/SDFG/Converter/fromSDFG/memlet/alloc_in_state.mlir new file mode 100644 index 000000000..274bc5540 --- /dev/null +++ b/test/SDFG/Converter/fromSDFG/memlet/alloc_in_state.mlir @@ -0,0 +1,7 @@ +// RUN: sdfg-opt --lower-sdfg %s + +sdfg.sdfg () -> (%r: !sdfg.array) { + sdfg.state @state_0 { + %a = sdfg.alloc() : !sdfg.array + } +} diff --git a/test/SDFG/Converter/fromSDFG/memlet/alloc_param.mlir b/test/SDFG/Converter/fromSDFG/memlet/alloc_param.mlir new file mode 100644 index 000000000..9e2c86e5e --- /dev/null +++ b/test/SDFG/Converter/fromSDFG/memlet/alloc_param.mlir @@ -0,0 +1,17 @@ +// RUN: sdfg-opt --lower-sdfg %s + +sdfg.sdfg () -> (%r: !sdfg.array) { + sdfg.state @state_0{ + %n = sdfg.tasklet() -> (index) { + %5 = arith.constant 5 : index + sdfg.return %5 : index + } + + %m = sdfg.tasklet() -> (index) { + %20 = arith.constant 20 : index + sdfg.return %20 : index + } + + %a = sdfg.alloc(%n, %m) : !sdfg.array + } +} diff --git a/test/SDFG/Converter/fromSDFG/memlet/alloc_param_mixed.mlir b/test/SDFG/Converter/fromSDFG/memlet/alloc_param_mixed.mlir new file mode 100644 index 000000000..029f889ba --- /dev/null +++ b/test/SDFG/Converter/fromSDFG/memlet/alloc_param_mixed.mlir @@ -0,0 +1,17 @@ +// RUN: sdfg-opt --lower-sdfg %s + +sdfg.sdfg () -> (%r: !sdfg.array) { + sdfg.state @state_0{ + %n = sdfg.tasklet() -> (index) { + %5 = arith.constant 5 : index + sdfg.return %5 : index + } + + %m = sdfg.tasklet() -> (index) { + %20 = arith.constant 20 : index + sdfg.return %20 : index + } + + %a = sdfg.alloc(%n, %m) : !sdfg.array + } +} diff --git a/test/SDFG/Converter/fromSDFG/memlet/alloc_shaped.mlir b/test/SDFG/Converter/fromSDFG/memlet/alloc_shaped.mlir new file mode 100644 index 000000000..9aa5d8d53 --- /dev/null +++ b/test/SDFG/Converter/fromSDFG/memlet/alloc_shaped.mlir @@ -0,0 +1,8 @@ +// RUN: sdfg-opt --lower-sdfg %s + +sdfg.sdfg () -> (%r: !sdfg.array) { + %a = sdfg.alloc() : !sdfg.array<23x45x123xi32> + + sdfg.state @state_0{ + } +} diff --git a/test/SDFG/Converter/fromSDFG/memlet/alloc_transient.mlir b/test/SDFG/Converter/fromSDFG/memlet/alloc_transient.mlir new file mode 100644 index 000000000..6c934c3f5 --- /dev/null +++ b/test/SDFG/Converter/fromSDFG/memlet/alloc_transient.mlir @@ -0,0 +1,8 @@ +// RUN: sdfg-opt --lower-sdfg %s + +sdfg.sdfg () -> (%r: !sdfg.array) { + sdfg.state @state_0 { + %A = sdfg.alloc{transient}() : !sdfg.array + } +} + diff --git a/test/SDFG/Converter/fromSDFG/memlet/alloc_transient_param.mlir b/test/SDFG/Converter/fromSDFG/memlet/alloc_transient_param.mlir new file mode 100644 index 000000000..802e42417 --- /dev/null +++ b/test/SDFG/Converter/fromSDFG/memlet/alloc_transient_param.mlir @@ -0,0 +1,17 @@ +// RUN: sdfg-opt --lower-sdfg %s + +sdfg.sdfg () -> (%r: !sdfg.array) { + sdfg.state @state_0{ + %n = sdfg.tasklet() -> (index) { + %5 = arith.constant 5 : index + sdfg.return %5 : index + } + + %m = sdfg.tasklet() -> (index) { + %20 = arith.constant 20 : index + sdfg.return %20 : index + } + + %a = sdfg.alloc{transient}(%n, %m) : !sdfg.array + } +} diff --git a/test/SDFG/Converter/fromSDFG/memlet/alloc_transient_param_mixed.mlir b/test/SDFG/Converter/fromSDFG/memlet/alloc_transient_param_mixed.mlir new file mode 100644 index 000000000..ca5a9d66e --- /dev/null +++ b/test/SDFG/Converter/fromSDFG/memlet/alloc_transient_param_mixed.mlir @@ -0,0 +1,17 @@ +// RUN: sdfg-opt --lower-sdfg %s + +sdfg.sdfg () -> (%r: !sdfg.array) { + sdfg.state @state_0{ + %n = sdfg.tasklet() -> (index) { + %5 = arith.constant 5 : index + sdfg.return %5 : index + } + + %m = sdfg.tasklet() -> (index) { + %20 = arith.constant 20 : index + sdfg.return %20 : index + } + + %a = sdfg.alloc{transient}(%n, %m) : !sdfg.array + } +} diff --git a/test/SDFG/Converter/fromSDFG/memlet/constant.mlir b/test/SDFG/Converter/fromSDFG/memlet/constant.mlir new file mode 100644 index 000000000..dad86786a --- /dev/null +++ b/test/SDFG/Converter/fromSDFG/memlet/constant.mlir @@ -0,0 +1,13 @@ +// RUN: sdfg-opt --lower-sdfg %s + +sdfg.sdfg{entry = @state_1} () -> (%arg1: !sdfg.array) { + sdfg.state @state_1 { + %0 = sdfg.tasklet() -> (i32) { + %c0 = arith.constant 0 : i32 + sdfg.return %c0 : i32 + } + + sdfg.store %0, %arg1[0, 0] : i32 -> !sdfg.array + } +} + diff --git a/test/SDFG/Converter/fromSDFG/memlet/copy.mlir b/test/SDFG/Converter/fromSDFG/memlet/copy.mlir new file mode 100644 index 000000000..22d48addf --- /dev/null +++ b/test/SDFG/Converter/fromSDFG/memlet/copy.mlir @@ -0,0 +1,9 @@ +// RUN: sdfg-opt --lower-sdfg %s + +sdfg.sdfg () -> (%r: !sdfg.array) { + %A = sdfg.alloc() : !sdfg.array + %B = sdfg.alloc() : !sdfg.array + sdfg.state @state_0 { + sdfg.copy %A -> %B : !sdfg.array + } +} diff --git a/test/SDFG/Converter/fromSDFG/memlet/indirect.mlir b/test/SDFG/Converter/fromSDFG/memlet/indirect.mlir new file mode 100644 index 000000000..294b4c403 --- /dev/null +++ b/test/SDFG/Converter/fromSDFG/memlet/indirect.mlir @@ -0,0 +1,15 @@ +// RUN: sdfg-opt --lower-sdfg %s + +sdfg.sdfg{entry = @state_1} (%arg0: index) -> (%arg1: !sdfg.array){ + sdfg.state @state_1 { + %n = sdfg.load %arg1[%arg0, %arg0] : !sdfg.array -> i32 + + %0 = sdfg.tasklet(%n: i32) -> (i32) { + %c0 = arith.addi %n, %n : i32 + sdfg.return %c0 : i32 + } + + sdfg.store %0, %arg1[%arg0, %arg0] : i32 -> !sdfg.array + } +} + diff --git a/test/SDFG/Converter/fromSDFG/memlet/load.mlir b/test/SDFG/Converter/fromSDFG/memlet/load.mlir new file mode 100644 index 000000000..bc42690fd --- /dev/null +++ b/test/SDFG/Converter/fromSDFG/memlet/load.mlir @@ -0,0 +1,10 @@ +// RUN: sdfg-opt --lower-sdfg %s + +sdfg.sdfg () -> (%r: !sdfg.array) { + %A = sdfg.alloc() : !sdfg.array + + sdfg.state @state_0 { + %a_1 = sdfg.load %A[] : !sdfg.array -> i32 + sdfg.store %a_1, %A[] : i32 -> !sdfg.array + } +} diff --git a/test/SDFG/Converter/fromSDFG/memlet/load_multiple_indices.mlir b/test/SDFG/Converter/fromSDFG/memlet/load_multiple_indices.mlir new file mode 100644 index 000000000..7a5b08d79 --- /dev/null +++ b/test/SDFG/Converter/fromSDFG/memlet/load_multiple_indices.mlir @@ -0,0 +1,16 @@ +// RUN: sdfg-opt --lower-sdfg %s + +sdfg.sdfg () -> (%r: !sdfg.array) { + %A = sdfg.alloc() : !sdfg.array<12x45xi32> + + sdfg.state @state_0 { + %a_1 = sdfg.load %A[6, 12] : !sdfg.array<12x45xi32> -> i32 + + %res = sdfg.tasklet(%a_1: i32) -> (i32) { + %z = arith.addi %a_1, %a_1 : i32 + sdfg.return %z : i32 + } + + sdfg.store %res, %A[6, 12] : i32 -> !sdfg.array<12x45xi32> + } +} diff --git a/test/SDFG/Converter/fromSDFG/memlet/single_dim.mlir b/test/SDFG/Converter/fromSDFG/memlet/single_dim.mlir new file mode 100644 index 000000000..7abca3f12 --- /dev/null +++ b/test/SDFG/Converter/fromSDFG/memlet/single_dim.mlir @@ -0,0 +1,35 @@ +// RUN: sdfg-opt --lower-sdfg %s +module { + sdfg.sdfg {entry = @init_0} () -> (%arg0: !sdfg.array){ + %0 = sdfg.alloc {name = "_load_tmp_5", transient} () : !sdfg.array + %1 = sdfg.alloc {name = "_alloc_tmp_3", transient} () : !sdfg.array<1xi32> + %2 = sdfg.alloc {name = "_constant_tmp_2", transient} () : !sdfg.array + sdfg.state @init_0{ + } + sdfg.state @constant_1{ + %3 = sdfg.tasklet () -> (index){ + %c0 = arith.constant 0 : index + sdfg.return %c0 : index + } + sdfg.store %3, %2[] : index -> !sdfg.array + %4 = sdfg.load %2[] : !sdfg.array -> index + } + sdfg.state @alloc_init_4{ + } + sdfg.state @load_6{ + %3 = sdfg.load %2[] : !sdfg.array -> index + %4 = sdfg.load %1[%3] : !sdfg.array<1xi32> -> i32 + sdfg.store %4, %0[] : i32 -> !sdfg.array + %5 = sdfg.load %0[] : !sdfg.array -> i32 + } + sdfg.state @return_7{ + %3 = sdfg.load %0[] : !sdfg.array -> i32 + sdfg.store %3, %arg0[] : i32 -> !sdfg.array + } + sdfg.edge {assign = [], condition = "1"} @init_0 -> @constant_1 + sdfg.edge {assign = [], condition = "1"} @constant_1 -> @alloc_init_4 + sdfg.edge {assign = [], condition = "1"} @alloc_init_4 -> @load_6 + sdfg.edge {assign = [], condition = "1"} @load_6 -> @return_7 + } +} + diff --git a/test/SDFG/Converter/fromSDFG/memlet/store.mlir b/test/SDFG/Converter/fromSDFG/memlet/store.mlir new file mode 100644 index 000000000..a896b3c02 --- /dev/null +++ b/test/SDFG/Converter/fromSDFG/memlet/store.mlir @@ -0,0 +1,14 @@ +// RUN: sdfg-opt --lower-sdfg %s + +sdfg.sdfg () -> (%r: !sdfg.array) { + %A = sdfg.alloc() : !sdfg.array + + sdfg.state @state_0 { + %1 = sdfg.tasklet() -> (i32) { + %1 = arith.constant 1 : i32 + sdfg.return %1 : i32 + } + + sdfg.store %1, %A[] : i32 -> !sdfg.array + } +} diff --git a/test/SDFG/Converter/fromSDFG/memlet/store_multiple_indices.mlir b/test/SDFG/Converter/fromSDFG/memlet/store_multiple_indices.mlir new file mode 100644 index 000000000..fe4f44189 --- /dev/null +++ b/test/SDFG/Converter/fromSDFG/memlet/store_multiple_indices.mlir @@ -0,0 +1,19 @@ +// RUN: sdfg-opt --lower-sdfg %s + +sdfg.sdfg () -> (%r: !sdfg.array) { + %A = sdfg.alloc() : !sdfg.array<56x45xi32> + + sdfg.state @state_0 { + %0 = sdfg.tasklet() -> (index) { + %0 = arith.constant 0 : index + sdfg.return %0 : index + } + + %1 = sdfg.tasklet() -> (i32) { + %1 = arith.constant 1 : i32 + sdfg.return %1 : i32 + } + + sdfg.store %1, %A[%0, %0] : i32 -> !sdfg.array<56x45xi32> + } +} diff --git a/test/SDFG/Converter/fromSDFG/memlet/subview.mlir b/test/SDFG/Converter/fromSDFG/memlet/subview.mlir new file mode 100644 index 000000000..3b3f01018 --- /dev/null +++ b/test/SDFG/Converter/fromSDFG/memlet/subview.mlir @@ -0,0 +1,10 @@ +// XFAIL: * +// RUN: sdfg-opt --lower-sdfg %s + +sdfg.sdfg () -> (%r: !sdfg.array) { + %A = sdfg.alloc() : !sdfg.array<8x16x4xi32> + + sdfg.state @state_0 { + %a_s = sdfg.subview %A[3, 4, 2][1, 6, 3][1, 1, 1] : !sdfg.array<8x16x4xi32> -> !sdfg.array<6x3xi32> + } +} diff --git a/test/SDFG/Converter/fromSDFG/memlet/view_cast.mlir b/test/SDFG/Converter/fromSDFG/memlet/view_cast.mlir new file mode 100644 index 000000000..ba1d52a6b --- /dev/null +++ b/test/SDFG/Converter/fromSDFG/memlet/view_cast.mlir @@ -0,0 +1,10 @@ +// XFAIL: * +// RUN: sdfg-opt --lower-sdfg %s + +sdfg.sdfg () -> (%r: !sdfg.array) { + %A = sdfg.alloc() : !sdfg.array<2x12xi32> + + sdfg.state @state_0 { + %b = sdfg.view_cast %A : !sdfg.array<2x12xi32> -> !sdfg.array<2x12xi32> + } +} diff --git a/test/SDFG/Converter/fromSDFG/sdfg/args.mlir b/test/SDFG/Converter/fromSDFG/sdfg/args.mlir new file mode 100644 index 000000000..11eabc4e2 --- /dev/null +++ b/test/SDFG/Converter/fromSDFG/sdfg/args.mlir @@ -0,0 +1,14 @@ +// RUN: sdfg-opt --lower-sdfg %s + +module { + sdfg.sdfg () -> (%arg0: !sdfg.array) { + sdfg.state @state_0 { + %0 = sdfg.tasklet() -> (i32) { + %c42_i32 = arith.constant 42 : i32 + sdfg.return %c42_i32 : i32 + } + + sdfg.store %0, %arg0[] : i32 -> !sdfg.array + } + } +} diff --git a/test/SDFG/Converter/fromSDFG/sdfg/assign_list.mlir b/test/SDFG/Converter/fromSDFG/sdfg/assign_list.mlir new file mode 100644 index 000000000..4817b91a5 --- /dev/null +++ b/test/SDFG/Converter/fromSDFG/sdfg/assign_list.mlir @@ -0,0 +1,7 @@ +// RUN: sdfg-opt --lower-sdfg %s + +sdfg.sdfg () -> (%r: !sdfg.array) { + sdfg.state @state_0{} + sdfg.state @state_1{} + sdfg.edge{assign=["i: 1", "j: 5"]} @state_0 -> @state_1 +} diff --git a/test/SDFG/Converter/fromSDFG/sdfg/call.mlir b/test/SDFG/Converter/fromSDFG/sdfg/call.mlir new file mode 100644 index 000000000..6936181e1 --- /dev/null +++ b/test/SDFG/Converter/fromSDFG/sdfg/call.mlir @@ -0,0 +1,12 @@ +// RUN: sdfg-opt --lower-sdfg %s + +sdfg.sdfg () -> (%r: !sdfg.array) { + sdfg.state @state_0 { + %N = sdfg.alloc() : !sdfg.array + %M = sdfg.alloc() : !sdfg.array + + sdfg.nested_sdfg{entry=@state_1} (%N: !sdfg.array) -> (%M: !sdfg.array) { + sdfg.state @state_1 {} + } + } +} diff --git a/test/SDFG/Converter/fromSDFG/sdfg/edge_ref.mlir b/test/SDFG/Converter/fromSDFG/sdfg/edge_ref.mlir new file mode 100644 index 000000000..694fe0232 --- /dev/null +++ b/test/SDFG/Converter/fromSDFG/sdfg/edge_ref.mlir @@ -0,0 +1,13 @@ +// RUN: sdfg-opt --lower-sdfg %s + +sdfg.sdfg (%arg0: index) -> (%r: !sdfg.array){ + sdfg.state @state_1 {} + sdfg.state @state_2 {} + sdfg.state @state_3 {} + sdfg.state @state_4 {} + + sdfg.edge {assign = ["idx: ref"]} (ref: %arg0: index) @state_1 -> @state_2 + sdfg.edge {condition = "idx < ref"} (ref: %arg0: index) @state_2 -> @state_3 + sdfg.edge {condition = "not(idx < ref)"} (ref: %arg0: index) @state_2 -> @state_4 +} + diff --git a/test/SDFG/Converter/fromSDFG/sdfg/missing_assign_condition.mlir b/test/SDFG/Converter/fromSDFG/sdfg/missing_assign_condition.mlir new file mode 100644 index 000000000..4b1c67a2d --- /dev/null +++ b/test/SDFG/Converter/fromSDFG/sdfg/missing_assign_condition.mlir @@ -0,0 +1,7 @@ +// RUN: sdfg-opt --lower-sdfg %s + +sdfg.sdfg () -> (%r: !sdfg.array) { + sdfg.state @state_0 {} + sdfg.state @state_1 {} + sdfg.edge @state_0 -> @state_1 +} diff --git a/test/SDFG/Converter/fromSDFG/sdfg/missing_condition.mlir b/test/SDFG/Converter/fromSDFG/sdfg/missing_condition.mlir new file mode 100644 index 000000000..acbda5dde --- /dev/null +++ b/test/SDFG/Converter/fromSDFG/sdfg/missing_condition.mlir @@ -0,0 +1,7 @@ +// RUN: sdfg-opt --lower-sdfg %s + +sdfg.sdfg () -> (%r: !sdfg.array) { + sdfg.state @state_0 {} + sdfg.state @state_1 {} + sdfg.edge{assign=["i: 1"]} @state_0 -> @state_1 +} diff --git a/test/SDFG/Converter/fromSDFG/sdfg/nested.mlir b/test/SDFG/Converter/fromSDFG/sdfg/nested.mlir new file mode 100644 index 000000000..e1f9a37f8 --- /dev/null +++ b/test/SDFG/Converter/fromSDFG/sdfg/nested.mlir @@ -0,0 +1,10 @@ +// RUN: sdfg-opt --lower-sdfg %s + +sdfg.sdfg () -> (%r: !sdfg.array) { + sdfg.state @state_0{ + sdfg.nested_sdfg () -> (%r: !sdfg.array) { + sdfg.state @state_1{ + } + } + } +} diff --git a/test/SDFG/Converter/fromSDFG/sdfg/nested_alloc.mlir b/test/SDFG/Converter/fromSDFG/sdfg/nested_alloc.mlir new file mode 100644 index 000000000..e2be0f000 --- /dev/null +++ b/test/SDFG/Converter/fromSDFG/sdfg/nested_alloc.mlir @@ -0,0 +1,21 @@ +// RUN: sdfg-opt --lower-sdfg %s + +sdfg.sdfg () -> (%r: !sdfg.array) { + sdfg.state @state_0{ + sdfg.nested_sdfg () -> (%r: !sdfg.array) { + %3 = sdfg.alloc{transient}() : !sdfg.array + + sdfg.state @state_1{ + %0 = sdfg.load %r[] : !sdfg.array -> i32 + sdfg.store %0, %3[] : i32 -> !sdfg.array + } + + sdfg.state @state_2{ + %0 = sdfg.load %3[] : !sdfg.array -> i32 + sdfg.store %0, %r[] : i32 -> !sdfg.array + } + + sdfg.edge {assign = [], condition = "1"} @state_1 -> @state_2 + } + } +} diff --git a/test/SDFG/Converter/fromSDFG/sdfg/nested_args.mlir b/test/SDFG/Converter/fromSDFG/sdfg/nested_args.mlir new file mode 100644 index 000000000..cf1955c3f --- /dev/null +++ b/test/SDFG/Converter/fromSDFG/sdfg/nested_args.mlir @@ -0,0 +1,12 @@ +// RUN: sdfg-opt --lower-sdfg %s + +sdfg.sdfg () -> (%r: !sdfg.array) { + sdfg.state @state_0{ + sdfg.nested_sdfg () -> (%r: !sdfg.array) { + sdfg.state @state_1{ + %0 = sdfg.load %r[] : !sdfg.array -> i32 + sdfg.store %0, %r[] : i32 -> !sdfg.array + } + } + } +} diff --git a/test/SDFG/Converter/fromSDFG/sdfg/two_states.mlir b/test/SDFG/Converter/fromSDFG/sdfg/two_states.mlir new file mode 100644 index 000000000..e228ad935 --- /dev/null +++ b/test/SDFG/Converter/fromSDFG/sdfg/two_states.mlir @@ -0,0 +1,8 @@ +// RUN: sdfg-opt --lower-sdfg %s + +sdfg.sdfg{entry=@state_0} () -> (%r: !sdfg.array) { + sdfg.state @state_1 {} + sdfg.state @state_0 {} + + sdfg.edge{assign=["i: 1"]} @state_0 -> @state_1 +} diff --git a/test/SDFG/Converter/fromSDFG/state/attributes.mlir b/test/SDFG/Converter/fromSDFG/state/attributes.mlir new file mode 100644 index 000000000..a9e73a3ce --- /dev/null +++ b/test/SDFG/Converter/fromSDFG/state/attributes.mlir @@ -0,0 +1,5 @@ +// RUN: sdfg-opt --lower-sdfg %s + +sdfg.sdfg () -> (%r: !sdfg.array) { + sdfg.state {nosync=false, instrument="No_Instrumentation"} @state_0 {} +} diff --git a/test/SDFG/Converter/fromSDFG/stream/alloc_stream.mlir b/test/SDFG/Converter/fromSDFG/stream/alloc_stream.mlir new file mode 100644 index 000000000..f5cc5b185 --- /dev/null +++ b/test/SDFG/Converter/fromSDFG/stream/alloc_stream.mlir @@ -0,0 +1,7 @@ +// XFAIL: * +// RUN: sdfg-opt --lower-sdfg %s + +sdfg.sdfg () -> (%r: !sdfg.array) { + %a = sdfg.alloc() : !sdfg.stream + sdfg.state @state_0 {} +} diff --git a/test/SDFG/Converter/fromSDFG/stream/alloc_stream_shaped.mlir b/test/SDFG/Converter/fromSDFG/stream/alloc_stream_shaped.mlir new file mode 100644 index 000000000..810ed6a13 --- /dev/null +++ b/test/SDFG/Converter/fromSDFG/stream/alloc_stream_shaped.mlir @@ -0,0 +1,7 @@ +// XFAIL: * +// RUN: sdfg-opt --lower-sdfg %s + +sdfg.sdfg () -> (%r: !sdfg.array) { + %a = sdfg.alloc() : !sdfg.stream<67x45xi32> + sdfg.state @state_0 {} +} diff --git a/test/SDFG/Converter/fromSDFG/stream/alloc_transient_stream.mlir b/test/SDFG/Converter/fromSDFG/stream/alloc_transient_stream.mlir new file mode 100644 index 000000000..314022430 --- /dev/null +++ b/test/SDFG/Converter/fromSDFG/stream/alloc_transient_stream.mlir @@ -0,0 +1,9 @@ +// XFAIL: * +// RUN: sdfg-opt --lower-sdfg %s + +sdfg.sdfg () -> (%r: !sdfg.array) { + sdfg.state @state_0 { + %A = sdfg.alloc{transient}() : !sdfg.stream + } +} + diff --git a/test/SDFG/Converter/fromSDFG/stream/stream_pop_push.mlir b/test/SDFG/Converter/fromSDFG/stream/stream_pop_push.mlir new file mode 100644 index 000000000..37e87818c --- /dev/null +++ b/test/SDFG/Converter/fromSDFG/stream/stream_pop_push.mlir @@ -0,0 +1,18 @@ +// XFAIL: * +// RUN: sdfg-opt --lower-sdfg %s + +sdfg.sdfg () -> (%r: !sdfg.array) { + %A = sdfg.alloc() : !sdfg.stream + + sdfg.state @state_0 { + %1 = sdfg.tasklet() -> (i32) { + %1 = arith.constant 1 : i32 + sdfg.return %1 : i32 + } + + sdfg.stream_push %1, %A : i32 -> !sdfg.stream + %a_1 = sdfg.stream_pop %A : !sdfg.stream -> i32 + + sdfg.store %a_1, %r[] : i32 -> !sdfg.array + } +} diff --git a/test/SDFG/Converter/fromSDFG/symbol/alloc.mlir b/test/SDFG/Converter/fromSDFG/symbol/alloc.mlir new file mode 100644 index 000000000..ffbe1f5bc --- /dev/null +++ b/test/SDFG/Converter/fromSDFG/symbol/alloc.mlir @@ -0,0 +1,7 @@ +// RUN: sdfg-opt --lower-sdfg %s + +sdfg.sdfg () -> (%r: !sdfg.array) { + sdfg.state @state_0 { + sdfg.alloc_symbol("N4") + } +} diff --git a/test/SDFG/Converter/fromSDFG/symbol/array.mlir b/test/SDFG/Converter/fromSDFG/symbol/array.mlir new file mode 100644 index 000000000..d5cecf4ad --- /dev/null +++ b/test/SDFG/Converter/fromSDFG/symbol/array.mlir @@ -0,0 +1,8 @@ +// RUN: sdfg-opt --lower-sdfg %s + +sdfg.sdfg () -> (%r: !sdfg.array) { + sdfg.state @state_0 { + sdfg.alloc_symbol("N") + %a = sdfg.alloc() : !sdfg.array + } +} diff --git a/test/SDFG/Converter/fromSDFG/symbol/eval_expr.mlir b/test/SDFG/Converter/fromSDFG/symbol/eval_expr.mlir new file mode 100644 index 000000000..514f01493 --- /dev/null +++ b/test/SDFG/Converter/fromSDFG/symbol/eval_expr.mlir @@ -0,0 +1,12 @@ +// RUN: sdfg-opt --lower-sdfg %s + +module { + sdfg.sdfg {entry = @state_0} () -> (%arg1: !sdfg.array) { + sdfg.alloc_symbol("idx") + + sdfg.state @state_0 { + %2 = sdfg.sym("idx") : index + sdfg.store %2, %arg1[] : index -> !sdfg.array + } + } +} diff --git a/test/SDFG/Converter/fromSDFG/symbol/load.mlir b/test/SDFG/Converter/fromSDFG/symbol/load.mlir new file mode 100644 index 000000000..870410208 --- /dev/null +++ b/test/SDFG/Converter/fromSDFG/symbol/load.mlir @@ -0,0 +1,10 @@ +// RUN: sdfg-opt --lower-sdfg %s + +sdfg.sdfg () -> (%r: !sdfg.array) { + %A = sdfg.alloc() : !sdfg.array<5x3xi32> + + sdfg.state @state_0 { + sdfg.alloc_symbol("N") + %a_1 = sdfg.load %A[sym("N"), 0] : !sdfg.array<5x3xi32> -> i32 + } +} diff --git a/test/SDFG/Converter/fromSDFG/symbol/map.mlir b/test/SDFG/Converter/fromSDFG/symbol/map.mlir new file mode 100644 index 000000000..c5ea9192b --- /dev/null +++ b/test/SDFG/Converter/fromSDFG/symbol/map.mlir @@ -0,0 +1,23 @@ +// RUN: sdfg-opt --lower-sdfg %s + +sdfg.sdfg () -> (%r: !sdfg.array) { + sdfg.state @state_0 { + sdfg.alloc_symbol("N") + %A = sdfg.alloc() : !sdfg.array<2x6xi32> + %B = sdfg.alloc() : !sdfg.array<2x6xi32> + %C = sdfg.alloc() : !sdfg.array<2x6xi32> + + + sdfg.map (%i, %j) = (0, 0) to (2, sym("N")) step (0, sym("N+2")) { + %a_ij = sdfg.load %A[%i, %j] : !sdfg.array<2x6xi32> -> i32 + %b_ij = sdfg.load %B[%i, %j] : !sdfg.array<2x6xi32> -> i32 + + %res = sdfg.tasklet(%a_ij: i32, %b_ij: i32) -> (i32) { + %z = arith.addi %a_ij, %b_ij : i32 + sdfg.return %z : i32 + } + + sdfg.store %res, %C[%i, %j] : i32 -> !sdfg.array<2x6xi32> + } + } +} diff --git a/test/SDFG/Converter/fromSDFG/symbol/propagation.mlir b/test/SDFG/Converter/fromSDFG/symbol/propagation.mlir new file mode 100644 index 000000000..2be30d821 --- /dev/null +++ b/test/SDFG/Converter/fromSDFG/symbol/propagation.mlir @@ -0,0 +1,12 @@ +// RUN: sdfg-opt --lower-sdfg %s + +module { + sdfg.sdfg () -> (%arg9: !sdfg.array) { + sdfg.state @state_0 { + %arg0 = sdfg.alloc() : !sdfg.array + sdfg.nested_sdfg (%arg0: !sdfg.array) -> (%arg9: !sdfg.array) { + sdfg.state @init_4 {} + } + } + } +} diff --git a/test/SDFG/Converter/fromSDFG/symbol/propagation_2.mlir b/test/SDFG/Converter/fromSDFG/symbol/propagation_2.mlir new file mode 100644 index 000000000..9b47ad6c4 --- /dev/null +++ b/test/SDFG/Converter/fromSDFG/symbol/propagation_2.mlir @@ -0,0 +1,17 @@ +// RUN: sdfg-opt --lower-sdfg %s + +module { + sdfg.sdfg () -> () { + sdfg.alloc_symbol("N") + + sdfg.state @state_0 { + + sdfg.nested_sdfg () -> () { + sdfg.state @init_4 { + %res = sdfg.sym("3*N+2") : i64 + } + } + + } + } +} diff --git a/test/SDFG/Converter/fromSDFG/symbol/store.mlir b/test/SDFG/Converter/fromSDFG/symbol/store.mlir new file mode 100644 index 000000000..9f69cb8e4 --- /dev/null +++ b/test/SDFG/Converter/fromSDFG/symbol/store.mlir @@ -0,0 +1,16 @@ +// RUN: sdfg-opt --lower-sdfg %s + +sdfg.sdfg () -> (%r: !sdfg.array) { + %A = sdfg.alloc() : !sdfg.array<12x12xi32> + + sdfg.state @state_0 { + sdfg.alloc_symbol("N") + + %1 = sdfg.tasklet () -> (i32) { + %1 = arith.constant 1 : i32 + sdfg.return %1 : i32 + } + + sdfg.store %1, %A[0, sym("N")] : i32 -> !sdfg.array<12x12xi32> + } +} diff --git a/test/SDFG/Converter/fromSDFG/symbol/stream.mlir b/test/SDFG/Converter/fromSDFG/symbol/stream.mlir new file mode 100644 index 000000000..08fc885e9 --- /dev/null +++ b/test/SDFG/Converter/fromSDFG/symbol/stream.mlir @@ -0,0 +1,9 @@ +// XFAIL: * +// RUN: sdfg-opt --lower-sdfg %s + +sdfg.sdfg () -> (%r: !sdfg.array) { + sdfg.state @state_0{ + sdfg.alloc_symbol("N") + %a = sdfg.alloc() : !sdfg.stream + } +} diff --git a/test/SDFG/Converter/fromSDFG/symbol/unused.mlir b/test/SDFG/Converter/fromSDFG/symbol/unused.mlir new file mode 100644 index 000000000..1fda6d3ce --- /dev/null +++ b/test/SDFG/Converter/fromSDFG/symbol/unused.mlir @@ -0,0 +1,8 @@ +// RUN: sdfg-opt --lower-sdfg %s + +sdfg.sdfg () -> (%r: !sdfg.array) { + sdfg.state @state_0{ + sdfg.alloc_symbol("N") + %res = sdfg.sym("3*N+2") : i64 + } +} diff --git a/test/SDFG/Converter/fromSDFG/tasklet/call.mlir b/test/SDFG/Converter/fromSDFG/tasklet/call.mlir new file mode 100644 index 000000000..4d8768566 --- /dev/null +++ b/test/SDFG/Converter/fromSDFG/tasklet/call.mlir @@ -0,0 +1,19 @@ +// RUN: sdfg-opt --lower-sdfg %s + +sdfg.sdfg () -> (%r: !sdfg.array) { + %A = sdfg.alloc() : !sdfg.array + + sdfg.state @state_0{ + %1 = sdfg.tasklet() -> (i32) { + %1 = arith.constant 1 : i32 + sdfg.return %1 : i32 + } + + %c = sdfg.tasklet(%1: i32) -> (i32) { + %c = arith.addi %1, %1 : i32 + sdfg.return %c : i32 + } + + sdfg.store %c, %A[] : i32 -> !sdfg.array + } +} diff --git a/test/SDFG/Converter/fromSDFG/tasklet/libcall.mlir b/test/SDFG/Converter/fromSDFG/tasklet/libcall.mlir new file mode 100644 index 000000000..bd2864df9 --- /dev/null +++ b/test/SDFG/Converter/fromSDFG/tasklet/libcall.mlir @@ -0,0 +1,12 @@ +// XFAIL: * +// RUN: sdfg-opt --lower-sdfg %s + +sdfg.sdfg () -> (%r: !sdfg.array<2x2xi32>) { + %A = sdfg.alloc() : !sdfg.array<2x2xi32> + %B = sdfg.alloc() : !sdfg.array<2x2xi32> + + sdfg.state @state_0{ + %c = sdfg.libcall{inputs=["_a", "_b"], outputs=["_c"]} "dace.libraries.blas.nodes.MatMul" (%A, %B) : (!sdfg.array<2x2xi32>, !sdfg.array<2x2xi32>) -> !sdfg.array<2x2xi32> + sdfg.copy %c -> %r : !sdfg.array<2x2xi32> + } +} diff --git a/test/SDFG/Converter/fromSDFG/tasklet/multi_return.mlir b/test/SDFG/Converter/fromSDFG/tasklet/multi_return.mlir new file mode 100644 index 000000000..aa74a12bd --- /dev/null +++ b/test/SDFG/Converter/fromSDFG/tasklet/multi_return.mlir @@ -0,0 +1,14 @@ +// RUN: sdfg-opt --lower-sdfg %s + +sdfg.sdfg () -> (%arg0: !sdfg.array, %arg1: !sdfg.array) { + sdfg.state @state_0{ + %n:2 = sdfg.tasklet() -> (i32, i32) { + %1 = arith.constant 1 : i32 + %5 = arith.constant 5 : i32 + sdfg.return %1, %5 : i32, i32 + } + + sdfg.store %n#0, %arg0[] : i32 -> !sdfg.array + sdfg.store %n#1, %arg1[] : i32 -> !sdfg.array + } +} diff --git a/test/SDFG/Converter/fromSDFG/tasklet/multi_return_interdependent.mlir b/test/SDFG/Converter/fromSDFG/tasklet/multi_return_interdependent.mlir new file mode 100644 index 000000000..9bce3cc0f --- /dev/null +++ b/test/SDFG/Converter/fromSDFG/tasklet/multi_return_interdependent.mlir @@ -0,0 +1,14 @@ +// RUN: sdfg-opt --lower-sdfg %s + +sdfg.sdfg () -> (%arg0: !sdfg.array, %arg1: !sdfg.array) { + sdfg.state @state_0{ + %n:2 = sdfg.tasklet() -> (i32, i32) { + %1 = arith.constant 1 : i32 + %2 = arith.addi %1, %1 : i32 + sdfg.return %1, %2 : i32, i32 + } + + sdfg.store %n#0, %arg0[] : i32 -> !sdfg.array + sdfg.store %n#1, %arg1[] : i32 -> !sdfg.array + } +} diff --git a/test/SDFG/Converter/lit.cfg.py b/test/SDFG/Converter/lit.cfg.py index 2e3331131..1a4e23515 100644 --- a/test/SDFG/Converter/lit.cfg.py +++ b/test/SDFG/Converter/lit.cfg.py @@ -1,3 +1,5 @@ +# Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + # -*- Python -*- import os diff --git a/test/SDFG/Converter/lit.site.cfg.py.in b/test/SDFG/Converter/lit.site.cfg.py.in index 9a1a6982c..3e375dc46 100644 --- a/test/SDFG/Converter/lit.site.cfg.py.in +++ b/test/SDFG/Converter/lit.site.cfg.py.in @@ -1,3 +1,5 @@ +# Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + @LIT_SITE_CFG_IN_HEADER@ import sys diff --git a/test/SDFG/Converter/arith/add.mlir b/test/SDFG/Converter/toSDFG/arith/add.mlir similarity index 100% rename from test/SDFG/Converter/arith/add.mlir rename to test/SDFG/Converter/toSDFG/arith/add.mlir diff --git a/test/SDFG/Converter/arith/args.mlir b/test/SDFG/Converter/toSDFG/arith/args.mlir similarity index 100% rename from test/SDFG/Converter/arith/args.mlir rename to test/SDFG/Converter/toSDFG/arith/args.mlir diff --git a/test/SDFG/Converter/arith/binop.mlir b/test/SDFG/Converter/toSDFG/arith/binop.mlir similarity index 100% rename from test/SDFG/Converter/arith/binop.mlir rename to test/SDFG/Converter/toSDFG/arith/binop.mlir diff --git a/test/SDFG/Converter/arith/constant.mlir b/test/SDFG/Converter/toSDFG/arith/constant.mlir similarity index 100% rename from test/SDFG/Converter/arith/constant.mlir rename to test/SDFG/Converter/toSDFG/arith/constant.mlir diff --git a/test/SDFG/Converter/arith/dependency.mlir b/test/SDFG/Converter/toSDFG/arith/dependency.mlir similarity index 100% rename from test/SDFG/Converter/arith/dependency.mlir rename to test/SDFG/Converter/toSDFG/arith/dependency.mlir diff --git a/test/SDFG/Converter/arith/multi_return.mlir b/test/SDFG/Converter/toSDFG/arith/multi_return.mlir similarity index 100% rename from test/SDFG/Converter/arith/multi_return.mlir rename to test/SDFG/Converter/toSDFG/arith/multi_return.mlir diff --git a/test/SDFG/Converter/func/call.mlir b/test/SDFG/Converter/toSDFG/func/call.mlir similarity index 94% rename from test/SDFG/Converter/func/call.mlir rename to test/SDFG/Converter/toSDFG/func/call.mlir index 073d1c3ac..dfcb942e5 100644 --- a/test/SDFG/Converter/func/call.mlir +++ b/test/SDFG/Converter/toSDFG/func/call.mlir @@ -1,4 +1,3 @@ -// XFAIL: * // RUN: sdfg-opt --convert-to-sdfg %s | sdfg-opt func.func private @ex(i32, i32) diff --git a/test/SDFG/Converter/func/call_return.mlir b/test/SDFG/Converter/toSDFG/func/call_return.mlir similarity index 100% rename from test/SDFG/Converter/func/call_return.mlir rename to test/SDFG/Converter/toSDFG/func/call_return.mlir diff --git a/test/SDFG/Converter/linalg/matmul.mlir b/test/SDFG/Converter/toSDFG/linalg/matmul.mlir similarity index 100% rename from test/SDFG/Converter/linalg/matmul.mlir rename to test/SDFG/Converter/toSDFG/linalg/matmul.mlir diff --git a/test/SDFG/Converter/llvm/alloca.mlir b/test/SDFG/Converter/toSDFG/llvm/alloca.mlir similarity index 100% rename from test/SDFG/Converter/llvm/alloca.mlir rename to test/SDFG/Converter/toSDFG/llvm/alloca.mlir diff --git a/test/SDFG/Converter/llvm/bitcast.mlir b/test/SDFG/Converter/toSDFG/llvm/bitcast.mlir similarity index 100% rename from test/SDFG/Converter/llvm/bitcast.mlir rename to test/SDFG/Converter/toSDFG/llvm/bitcast.mlir diff --git a/test/SDFG/Converter/llvm/load.mlir b/test/SDFG/Converter/toSDFG/llvm/load.mlir similarity index 100% rename from test/SDFG/Converter/llvm/load.mlir rename to test/SDFG/Converter/toSDFG/llvm/load.mlir diff --git a/test/SDFG/Converter/llvm/store.mlir b/test/SDFG/Converter/toSDFG/llvm/store.mlir similarity index 100% rename from test/SDFG/Converter/llvm/store.mlir rename to test/SDFG/Converter/toSDFG/llvm/store.mlir diff --git a/test/SDFG/Converter/memref/alloc.mlir b/test/SDFG/Converter/toSDFG/memref/alloc.mlir similarity index 100% rename from test/SDFG/Converter/memref/alloc.mlir rename to test/SDFG/Converter/toSDFG/memref/alloc.mlir diff --git a/test/SDFG/Converter/memref/alloc_multi_param.mlir b/test/SDFG/Converter/toSDFG/memref/alloc_multi_param.mlir similarity index 100% rename from test/SDFG/Converter/memref/alloc_multi_param.mlir rename to test/SDFG/Converter/toSDFG/memref/alloc_multi_param.mlir diff --git a/test/SDFG/Converter/memref/alloc_param.mlir b/test/SDFG/Converter/toSDFG/memref/alloc_param.mlir similarity index 100% rename from test/SDFG/Converter/memref/alloc_param.mlir rename to test/SDFG/Converter/toSDFG/memref/alloc_param.mlir diff --git a/test/SDFG/Converter/memref/alloca.mlir b/test/SDFG/Converter/toSDFG/memref/alloca.mlir similarity index 100% rename from test/SDFG/Converter/memref/alloca.mlir rename to test/SDFG/Converter/toSDFG/memref/alloca.mlir diff --git a/test/SDFG/Converter/memref/args.mlir b/test/SDFG/Converter/toSDFG/memref/args.mlir similarity index 100% rename from test/SDFG/Converter/memref/args.mlir rename to test/SDFG/Converter/toSDFG/memref/args.mlir diff --git a/test/SDFG/Converter/memref/cast.mlir b/test/SDFG/Converter/toSDFG/memref/cast.mlir similarity index 100% rename from test/SDFG/Converter/memref/cast.mlir rename to test/SDFG/Converter/toSDFG/memref/cast.mlir diff --git a/test/SDFG/Converter/memref/global.mlir b/test/SDFG/Converter/toSDFG/memref/global.mlir similarity index 100% rename from test/SDFG/Converter/memref/global.mlir rename to test/SDFG/Converter/toSDFG/memref/global.mlir diff --git a/test/SDFG/Converter/module/empty.mlir b/test/SDFG/Converter/toSDFG/module/empty.mlir similarity index 100% rename from test/SDFG/Converter/module/empty.mlir rename to test/SDFG/Converter/toSDFG/module/empty.mlir diff --git a/test/SDFG/Converter/module/module.mlir b/test/SDFG/Converter/toSDFG/module/module.mlir similarity index 100% rename from test/SDFG/Converter/module/module.mlir rename to test/SDFG/Converter/toSDFG/module/module.mlir diff --git a/test/SDFG/Converter/module/other.mlir b/test/SDFG/Converter/toSDFG/module/other.mlir similarity index 100% rename from test/SDFG/Converter/module/other.mlir rename to test/SDFG/Converter/toSDFG/module/other.mlir diff --git a/test/SDFG/Converter/module/sdfg.mlir b/test/SDFG/Converter/toSDFG/module/sdfg.mlir similarity index 100% rename from test/SDFG/Converter/module/sdfg.mlir rename to test/SDFG/Converter/toSDFG/module/sdfg.mlir diff --git a/test/SDFG/Converter/polybench/2mm.mlir b/test/SDFG/Converter/toSDFG/polybench/2mm.mlir similarity index 100% rename from test/SDFG/Converter/polybench/2mm.mlir rename to test/SDFG/Converter/toSDFG/polybench/2mm.mlir diff --git a/test/SDFG/Converter/scf/for_arith_args.mlir b/test/SDFG/Converter/toSDFG/scf/for_arith_args.mlir similarity index 100% rename from test/SDFG/Converter/scf/for_arith_args.mlir rename to test/SDFG/Converter/toSDFG/scf/for_arith_args.mlir diff --git a/test/SDFG/Converter/scf/for_arith_const.mlir b/test/SDFG/Converter/toSDFG/scf/for_arith_const.mlir similarity index 100% rename from test/SDFG/Converter/scf/for_arith_const.mlir rename to test/SDFG/Converter/toSDFG/scf/for_arith_const.mlir diff --git a/test/SDFG/Converter/scf/for_arith_iter.mlir b/test/SDFG/Converter/toSDFG/scf/for_arith_iter.mlir similarity index 100% rename from test/SDFG/Converter/scf/for_arith_iter.mlir rename to test/SDFG/Converter/toSDFG/scf/for_arith_iter.mlir diff --git a/test/SDFG/Converter/scf/for_double_memref.mlir b/test/SDFG/Converter/toSDFG/scf/for_double_memref.mlir similarity index 100% rename from test/SDFG/Converter/scf/for_double_memref.mlir rename to test/SDFG/Converter/toSDFG/scf/for_double_memref.mlir diff --git a/test/SDFG/Converter/scf/for_multi_yield.mlir b/test/SDFG/Converter/toSDFG/scf/for_multi_yield.mlir similarity index 100% rename from test/SDFG/Converter/scf/for_multi_yield.mlir rename to test/SDFG/Converter/toSDFG/scf/for_multi_yield.mlir diff --git a/test/SDFG/Converter/scf/for_nested.mlir b/test/SDFG/Converter/toSDFG/scf/for_nested.mlir similarity index 100% rename from test/SDFG/Converter/scf/for_nested.mlir rename to test/SDFG/Converter/toSDFG/scf/for_nested.mlir diff --git a/test/SDFG/Converter/scf/for_nested_between.mlir b/test/SDFG/Converter/toSDFG/scf/for_nested_between.mlir similarity index 100% rename from test/SDFG/Converter/scf/for_nested_between.mlir rename to test/SDFG/Converter/toSDFG/scf/for_nested_between.mlir diff --git a/test/SDFG/Converter/scf/for_triple_nested.mlir b/test/SDFG/Converter/toSDFG/scf/for_triple_nested.mlir similarity index 100% rename from test/SDFG/Converter/scf/for_triple_nested.mlir rename to test/SDFG/Converter/toSDFG/scf/for_triple_nested.mlir diff --git a/test/SDFG/Converter/scf/for_yield.mlir b/test/SDFG/Converter/toSDFG/scf/for_yield.mlir similarity index 100% rename from test/SDFG/Converter/scf/for_yield.mlir rename to test/SDFG/Converter/toSDFG/scf/for_yield.mlir diff --git a/test/SDFG/Converter/scf/for_yield_dependent.mlir b/test/SDFG/Converter/toSDFG/scf/for_yield_dependent.mlir similarity index 100% rename from test/SDFG/Converter/scf/for_yield_dependent.mlir rename to test/SDFG/Converter/toSDFG/scf/for_yield_dependent.mlir diff --git a/test/SDFG/Converter/scf/if_arith_const.mlir b/test/SDFG/Converter/toSDFG/scf/if_arith_const.mlir similarity index 100% rename from test/SDFG/Converter/scf/if_arith_const.mlir rename to test/SDFG/Converter/toSDFG/scf/if_arith_const.mlir diff --git a/test/SDFG/Converter/scf/if_else_empty.mlir b/test/SDFG/Converter/toSDFG/scf/if_else_empty.mlir similarity index 100% rename from test/SDFG/Converter/scf/if_else_empty.mlir rename to test/SDFG/Converter/toSDFG/scf/if_else_empty.mlir diff --git a/test/SDFG/Converter/scf/if_empty.mlir b/test/SDFG/Converter/toSDFG/scf/if_empty.mlir similarity index 100% rename from test/SDFG/Converter/scf/if_empty.mlir rename to test/SDFG/Converter/toSDFG/scf/if_empty.mlir diff --git a/test/SDFG/Converter/scf/if_multi_yield.mlir b/test/SDFG/Converter/toSDFG/scf/if_multi_yield.mlir similarity index 100% rename from test/SDFG/Converter/scf/if_multi_yield.mlir rename to test/SDFG/Converter/toSDFG/scf/if_multi_yield.mlir diff --git a/test/SDFG/Converter/scf/if_nested.mlir b/test/SDFG/Converter/toSDFG/scf/if_nested.mlir similarity index 100% rename from test/SDFG/Converter/scf/if_nested.mlir rename to test/SDFG/Converter/toSDFG/scf/if_nested.mlir diff --git a/test/SDFG/Converter/scf/if_yield.mlir b/test/SDFG/Converter/toSDFG/scf/if_yield.mlir similarity index 100% rename from test/SDFG/Converter/scf/if_yield.mlir rename to test/SDFG/Converter/toSDFG/scf/if_yield.mlir diff --git a/test/SDFG/Converter/scf/while.mlir b/test/SDFG/Converter/toSDFG/scf/while.mlir similarity index 100% rename from test/SDFG/Converter/scf/while.mlir rename to test/SDFG/Converter/toSDFG/scf/while.mlir diff --git a/test/SDFG/Converter/scf/while_multi_iter.mlir b/test/SDFG/Converter/toSDFG/scf/while_multi_iter.mlir similarity index 100% rename from test/SDFG/Converter/scf/while_multi_iter.mlir rename to test/SDFG/Converter/toSDFG/scf/while_multi_iter.mlir diff --git a/test/SDFG/Converter/scf/while_multi_res.mlir b/test/SDFG/Converter/toSDFG/scf/while_multi_res.mlir similarity index 100% rename from test/SDFG/Converter/scf/while_multi_res.mlir rename to test/SDFG/Converter/toSDFG/scf/while_multi_res.mlir diff --git a/test/SDFG/Converter/scf/while_res.mlir b/test/SDFG/Converter/toSDFG/scf/while_res.mlir similarity index 100% rename from test/SDFG/Converter/scf/while_res.mlir rename to test/SDFG/Converter/toSDFG/scf/while_res.mlir diff --git a/test/SDFG/Dialect/CMakeLists.txt b/test/SDFG/Dialect/CMakeLists.txt index a9e3723fc..b9ab7cb54 100644 --- a/test/SDFG/Dialect/CMakeLists.txt +++ b/test/SDFG/Dialect/CMakeLists.txt @@ -1,17 +1,15 @@ +# Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + configure_lit_site_cfg( - ${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in - ${CMAKE_CURRENT_BINARY_DIR}/lit.site.cfg.py - MAIN_CONFIG - ${CMAKE_CURRENT_SOURCE_DIR}/lit.cfg.py -) + ${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in + ${CMAKE_CURRENT_BINARY_DIR}/lit.site.cfg.py MAIN_CONFIG + ${CMAKE_CURRENT_SOURCE_DIR}/lit.cfg.py) set(SDFG_TEST_DEPENDS FileCheck count not sdfg-opt) -add_lit_testsuite( - check-sdfg-opt "Running the sdfg regression tests" - ${CMAKE_CURRENT_BINARY_DIR} - DEPENDS ${SDFG_TEST_DEPENDS} -) +add_lit_testsuite(check-sdfg-opt "Running the sdfg regression tests" + ${CMAKE_CURRENT_BINARY_DIR} DEPENDS ${SDFG_TEST_DEPENDS}) set_target_properties(check-sdfg-opt PROPERTIES FOLDER "Tests") -add_lit_testsuites(SDFG-OPT ${CMAKE_CURRENT_SOURCE_DIR} DEPENDS ${SDFG_TEST_DEPENDS}) +add_lit_testsuites(SDFG-OPT ${CMAKE_CURRENT_SOURCE_DIR} DEPENDS + ${SDFG_TEST_DEPENDS}) diff --git a/test/SDFG/Dialect/lit.cfg.py b/test/SDFG/Dialect/lit.cfg.py index 72592c2ce..9bbf5216e 100644 --- a/test/SDFG/Dialect/lit.cfg.py +++ b/test/SDFG/Dialect/lit.cfg.py @@ -1,3 +1,5 @@ +# Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + # -*- Python -*- import os diff --git a/test/SDFG/Dialect/lit.site.cfg.py.in b/test/SDFG/Dialect/lit.site.cfg.py.in index 1629e4018..1c7b575ae 100644 --- a/test/SDFG/Dialect/lit.site.cfg.py.in +++ b/test/SDFG/Dialect/lit.site.cfg.py.in @@ -1,3 +1,5 @@ +# Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + @LIT_SITE_CFG_IN_HEADER@ import sys diff --git a/test/SDFG/Integration/CMakeLists.txt b/test/SDFG/Integration/CMakeLists.txt index f9614000c..5c37d7e79 100644 --- a/test/SDFG/Integration/CMakeLists.txt +++ b/test/SDFG/Integration/CMakeLists.txt @@ -1,17 +1,15 @@ +# Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + configure_lit_site_cfg( - ${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in - ${CMAKE_CURRENT_BINARY_DIR}/lit.site.cfg.py - MAIN_CONFIG - ${CMAKE_CURRENT_SOURCE_DIR}/lit.cfg.py -) + ${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in + ${CMAKE_CURRENT_BINARY_DIR}/lit.site.cfg.py MAIN_CONFIG + ${CMAKE_CURRENT_SOURCE_DIR}/lit.cfg.py) set(SDFG_TEST_DEPENDS FileCheck count not sdfg-opt sdfg-translate) -add_lit_testsuite( - check-sdfg-integration "Running the sdfg integration tests" - ${CMAKE_CURRENT_BINARY_DIR} - DEPENDS ${SDFG_TEST_DEPENDS} -) +add_lit_testsuite(check-sdfg-integration "Running the sdfg integration tests" + ${CMAKE_CURRENT_BINARY_DIR} DEPENDS ${SDFG_TEST_DEPENDS}) set_target_properties(check-sdfg-integration PROPERTIES FOLDER "Tests") -add_lit_testsuites(SDFG-INTEGRATION ${CMAKE_CURRENT_SOURCE_DIR} DEPENDS ${SDFG_TEST_DEPENDS}) +add_lit_testsuites(SDFG-INTEGRATION ${CMAKE_CURRENT_SOURCE_DIR} DEPENDS + ${SDFG_TEST_DEPENDS}) diff --git a/test/SDFG/Integration/execute_sdfg.py b/test/SDFG/Integration/execute_sdfg.py index 9e98a7a1a..28d908740 100644 --- a/test/SDFG/Integration/execute_sdfg.py +++ b/test/SDFG/Integration/execute_sdfg.py @@ -1,3 +1,5 @@ +# Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + # Executes given SDFG with three-filled arrays and prints all output to stdout import json diff --git a/test/SDFG/Integration/import_translation_test.py b/test/SDFG/Integration/import_translation_test.py index bd13bcd7d..0ad70e397 100644 --- a/test/SDFG/Integration/import_translation_test.py +++ b/test/SDFG/Integration/import_translation_test.py @@ -1,3 +1,5 @@ +# Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + import json import sys from dace import SDFG diff --git a/test/SDFG/Integration/lit.cfg.py b/test/SDFG/Integration/lit.cfg.py index ef8b132c6..3be2dda7e 100644 --- a/test/SDFG/Integration/lit.cfg.py +++ b/test/SDFG/Integration/lit.cfg.py @@ -1,3 +1,5 @@ +# Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + # -*- Python -*- import os diff --git a/test/SDFG/Integration/lit.site.cfg.py.in b/test/SDFG/Integration/lit.site.cfg.py.in index 7ae13bd8b..20c78c9e7 100644 --- a/test/SDFG/Integration/lit.site.cfg.py.in +++ b/test/SDFG/Integration/lit.site.cfg.py.in @@ -1,3 +1,5 @@ +# Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + @LIT_SITE_CFG_IN_HEADER@ import sys diff --git a/test/SDFG/Translate/CMakeLists.txt b/test/SDFG/Translate/CMakeLists.txt index a48031779..2e694ef65 100644 --- a/test/SDFG/Translate/CMakeLists.txt +++ b/test/SDFG/Translate/CMakeLists.txt @@ -1,17 +1,15 @@ +# Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + configure_lit_site_cfg( - ${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in - ${CMAKE_CURRENT_BINARY_DIR}/lit.site.cfg.py - MAIN_CONFIG - ${CMAKE_CURRENT_SOURCE_DIR}/lit.cfg.py -) + ${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in + ${CMAKE_CURRENT_BINARY_DIR}/lit.site.cfg.py MAIN_CONFIG + ${CMAKE_CURRENT_SOURCE_DIR}/lit.cfg.py) set(SDFG_TEST_DEPENDS FileCheck count not sdfg-translate) -add_lit_testsuite( - check-sdfg-translate "Running the sdfg translation tests" - ${CMAKE_CURRENT_BINARY_DIR} - DEPENDS ${SDFG_TEST_DEPENDS} -) +add_lit_testsuite(check-sdfg-translate "Running the sdfg translation tests" + ${CMAKE_CURRENT_BINARY_DIR} DEPENDS ${SDFG_TEST_DEPENDS}) set_target_properties(check-sdfg-translate PROPERTIES FOLDER "Tests") -add_lit_testsuites(SDFG-TRANSLATE ${CMAKE_CURRENT_SOURCE_DIR} DEPENDS ${SDFG_TEST_DEPENDS}) +add_lit_testsuites(SDFG-TRANSLATE ${CMAKE_CURRENT_SOURCE_DIR} DEPENDS + ${SDFG_TEST_DEPENDS}) diff --git a/test/SDFG/Translate/execute_sdfg.py b/test/SDFG/Translate/execute_sdfg.py index 31ea1d08e..07ad7ddf4 100644 --- a/test/SDFG/Translate/execute_sdfg.py +++ b/test/SDFG/Translate/execute_sdfg.py @@ -1,3 +1,5 @@ +# Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + # Executes given SDFG with three-filled arrays and prints all output to stdout import json diff --git a/test/SDFG/Translate/import_translation_test.py b/test/SDFG/Translate/import_translation_test.py index bd13bcd7d..0ad70e397 100644 --- a/test/SDFG/Translate/import_translation_test.py +++ b/test/SDFG/Translate/import_translation_test.py @@ -1,3 +1,5 @@ +# Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + import json import sys from dace import SDFG diff --git a/test/SDFG/Translate/lit.cfg.py b/test/SDFG/Translate/lit.cfg.py index 5c1e35ee2..1e6ed823a 100644 --- a/test/SDFG/Translate/lit.cfg.py +++ b/test/SDFG/Translate/lit.cfg.py @@ -1,3 +1,5 @@ +# Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + # -*- Python -*- import os diff --git a/test/SDFG/Translate/lit.site.cfg.py.in b/test/SDFG/Translate/lit.site.cfg.py.in index 7bdd506d2..8d9b1fb20 100644 --- a/test/SDFG/Translate/lit.site.cfg.py.in +++ b/test/SDFG/Translate/lit.site.cfg.py.in @@ -1,3 +1,5 @@ +# Copyright (c) 2021-2023, Scalable Parallel Computing Lab, ETH Zurich + @LIT_SITE_CFG_IN_HEADER@ import sys