diff --git a/src/Compiler/CompilerOptions.cpp b/src/Compiler/CompilerOptions.cpp index e2b00b1306..c760c03452 100644 --- a/src/Compiler/CompilerOptions.cpp +++ b/src/Compiler/CompilerOptions.cpp @@ -210,6 +210,10 @@ llvm::cl::opt allowSorting("allowSorting", llvm::cl::desc("Perform topological sort on onnx graph"), llvm::cl::init(true), llvm::cl::cat(OnnxMlirOptions)); +llvm::cl::opt enableLinalg("enableLinalg", + llvm::cl::desc("Enable ONNX to Linalg conversion and related passes"), + llvm::cl::init(false), llvm::cl::cat(OnnxMlirOptions)); + // Configuration states associated with certain options. // For example, when maccel is specified, NNPA can register // dependent libdnn. diff --git a/src/Compiler/CompilerOptions.hpp b/src/Compiler/CompilerOptions.hpp index 79d955bd77..9f5e959524 100644 --- a/src/Compiler/CompilerOptions.hpp +++ b/src/Compiler/CompilerOptions.hpp @@ -71,6 +71,7 @@ extern llvm::cl::opt onnxOpTransformThreshold; extern llvm::cl::opt onnxOpTransformReport; extern llvm::cl::opt enableParallel; extern llvm::cl::opt enableSimdDataLayout; +extern llvm::cl::opt enableLinalg; // The customEnvFlags must be scanned before the normal options. bool parseCustomEnvFlagsCommandLineOption(int argc, const char *const *argv, diff --git a/src/Compiler/CompilerPasses.cpp b/src/Compiler/CompilerPasses.cpp index aef46dc34c..8e1317a064 100644 --- a/src/Compiler/CompilerPasses.cpp +++ b/src/Compiler/CompilerPasses.cpp @@ -23,6 +23,7 @@ #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" #include "mlir/Dialect/Bufferization/Transforms/Passes.h" +#include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Pass/Pass.h" #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "mlir/Transforms/Passes.h" @@ -121,10 +122,25 @@ void addONNXToKrnlPasses(mlir::PassManager &pm, int optLevel, bool enableCSE, if (enableInstrumentONNXSignature) pm.addNestedPass( onnx_mlir::createInstrumentONNXSignaturePass()); + if (enableLinalg) { + pm.addPass(onnx_mlir::createLowerONNXToLinalgPass()); + + // Convert tensor.EmptyOp to bufferization.alloc_tensor + // This pass has to come before Linalg Bufferize pass. + // Otherwise, the bufferization.alloc_tensor will not be lowered + pm.addNestedPass( + bufferization::createEmptyTensorToAllocTensorPass()); + + // Linalg bufferization can be before or after LowerToKrnlPass + pm.addNestedPass(createLinalgBufferizePass()); + } pm.addPass(onnx_mlir::createLowerToKrnlPass(optLevel, enableParallel)); // An additional pass of canonicalization is helpful because lowering // from ONNX dialect to Standard dialect exposes additional canonicalization // opportunities. + + // For Linalg and Krnl mixed IR: + // Canonicalization pass will clean up bufferization::to_tensor and to_memref pm.addPass(mlir::createCanonicalizerPass()); pm.addNestedPass( onnx_mlir::createDisconnectKrnlDimFromAllocPass()); @@ -132,6 +148,9 @@ void addONNXToKrnlPasses(mlir::PassManager &pm, int optLevel, bool enableCSE, } // namespace onnx_mlir void addKrnlToAffinePasses(mlir::PassManager &pm) { + if (enableLinalg) { + pm.addNestedPass(createConvertLinalgToAffineLoopsPass()); + } pm.addNestedPass( onnx_mlir::krnl::createConvertKrnlToAffinePass()); } diff --git a/src/Conversion/CMakeLists.txt b/src/Conversion/CMakeLists.txt index beb29368d2..2052858a68 100644 --- a/src/Conversion/CMakeLists.txt +++ b/src/Conversion/CMakeLists.txt @@ -5,6 +5,7 @@ add_subdirectory(KrnlToLLVM) add_subdirectory(KrnlToAffine) add_subdirectory(KrnlSeqToMemref) add_subdirectory(ONNXToTOSA) +add_subdirectory(ONNXToLinalg) if (ONNX_MLIR_ENABLE_MHLO) add_subdirectory(ONNXToMhlo) diff --git a/src/Conversion/ONNXToKrnl/CMakeLists.txt b/src/Conversion/ONNXToKrnl/CMakeLists.txt index 30a638532c..0deacd0f05 100644 --- a/src/Conversion/ONNXToKrnl/CMakeLists.txt +++ b/src/Conversion/ONNXToKrnl/CMakeLists.txt @@ -73,6 +73,7 @@ add_onnx_mlir_library(OMONNXToKrnl LINK_LIBS PUBLIC OMAccelerator + OMCompilerOptions OMConstPropHelper OMONNXOps OMSupport diff --git a/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp b/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp index a09a9f1901..4b05b7f030 100644 --- a/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp +++ b/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp @@ -13,6 +13,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "src/Compiler/CompilerOptions.hpp" @@ -350,6 +351,7 @@ void FrontendToKrnlLoweringPass::runOnOperation() { target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); // If `emitDealloc` is turned off, make sure we don't have buffer deallocation // at this level. Will use MLIR buffer-deallocation for this purpose instead. diff --git a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp index 8323080a44..f075045d58 100644 --- a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp +++ b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp @@ -13,8 +13,11 @@ // //===----------------------------------------------------------------------===// -#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" + #include "src/Accelerators/Accelerator.hpp" +#include "src/Compiler/CompilerOptions.hpp" +#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp" #include "src/Dialect/Krnl/DialectBuilder.hpp" #include "src/Dialect/Mlir/DialectBuilder.hpp" #include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp" @@ -724,8 +727,16 @@ KrnlTypeConverter::KrnlTypeConverter() { if (inputs.size() != 1) return llvm::None; - return builder.create(loc, resultType, inputs) - .getResult(0); + // Use ToTensorOp instead of UnrealizedConversionCastOp + // because Linalg use ToTensor, though they are the same in semantic + // Since UnrealizedConversionCastOp is used in other places and will not + // be replaced in this PR + if (enableLinalg) + return builder.create(loc, resultType, inputs) + .getResult(); + else + return builder.create(loc, resultType, inputs) + .getResult(0); }); addTargetMaterialization([&](OpBuilder &builder, Type resultType, @@ -734,8 +745,13 @@ KrnlTypeConverter::KrnlTypeConverter() { if (inputs.size() != 1) return llvm::None; - return builder.create(loc, resultType, inputs) - .getResult(0); + // Replace UnrealizedConversionCastOp + if (enableLinalg) + return builder.create(loc, resultType, inputs) + .getResult(); + else + return builder.create(loc, resultType, inputs) + .getResult(0); }); } diff --git a/src/Conversion/ONNXToLinalg/CMakeLists.txt b/src/Conversion/ONNXToLinalg/CMakeLists.txt new file mode 100644 index 0000000000..63488251e3 --- /dev/null +++ b/src/Conversion/ONNXToLinalg/CMakeLists.txt @@ -0,0 +1,16 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Please keep in alphabetical order. +add_onnx_mlir_library(OMONNXToLinalg + ConvertONNXToLinalg.cpp + ONNXToLinalgCommon.cpp + Math/MatMul.cpp + + LINK_LIBS PUBLIC + OMAccelerator + OMConstPropHelper + OMONNXOps + OMSupport + MLIRFuncDialect + MLIRFuncTransforms + ) diff --git a/src/Conversion/ONNXToLinalg/ConvertONNXToLinalg.cpp b/src/Conversion/ONNXToLinalg/ConvertONNXToLinalg.cpp new file mode 100644 index 0000000000..59b08172fa --- /dev/null +++ b/src/Conversion/ONNXToLinalg/ConvertONNXToLinalg.cpp @@ -0,0 +1,132 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//====------ ConvertONNXToLinalg.cpp - ONNX dialects to Krnl lowering -----===// +// +// Copyright 2019-2022 The IBM Research Authors. +// +// ============================================================================= +// +// This file implements the lowering of frontend operations to a combination of +// Krnl IR and standard operations. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Shape/IR/Shape.h" +#include "src/Compiler/CompilerOptions.hpp" + +#include "src/Accelerators/Accelerator.hpp" +#include "src/Conversion/ONNXToLinalg/ONNXToLinalgCommon.hpp" + +using namespace mlir; + +namespace onnx_mlir { + +void populateONNXToLinalgConversionPattern(RewritePatternSet &patterns, + TypeConverter &typeConverter, MLIRContext *ctx) { + + // Math + populateLoweringONNXMatMulOpLinalgPattern(patterns, typeConverter, ctx); +} + +//===----------------------------------------------------------------------===// +// ONNX to Krnl Dialect lowering pass +//===----------------------------------------------------------------------===// + +/// This is a partial lowering to Krnl loops of the ONNX operations. +struct ONNXToLinalgLoweringPass + : public PassWrapper> { + + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ONNXToLinalgLoweringPass) + + StringRef getArgument() const override { return "convert-onnx-to-linalg"; } + + StringRef getDescription() const override { + return "Lower ONNX ops to Linalg dialect."; + } + + // Make sure that we have a valid default constructor and copy + // constructor to make sure that the options are initialized properly. + ONNXToLinalgLoweringPass() = default; + ONNXToLinalgLoweringPass(const ONNXToLinalgLoweringPass &pass) + : PassWrapper>() {} + + void runOnOperation() final; +}; + +void ONNXToLinalgLoweringPass::runOnOperation() { + ModuleOp module = getOperation(); + + // The first thing to define is the conversion target. This will define the + // final target for this lowering. + ConversionTarget target(getContext()); + + // We define the specific operations, or dialects, that are legal targets for + // this lowering. + target.addLegalDialect(); + // Needed to support unsigned int computations. To be removed if we use a + // scheme that does not rely on the UnrealizedConversionCastOp. + target.addLegalOp<::mlir::UnrealizedConversionCastOp>(); + // Make ONNXNoneOp legal so that other ONNX ops can use it during the + // lowering. ONNXNoneOp will be dangling and removed by calling + // canonicalization after the lowering. + target.addLegalOp<::mlir::ONNXNoneOp>(); + target.addLegalOp(); + target.addLegalOp(); + + // The following requirements are from Krnl and they are kept if ONNXToKrnl + // is after this pass. + // If the Linalg is on tensor instead of memref, this lowering will not + // generate memref or Affine load/store. However, these requiremnts will may + // be an issue if Ops are lowered other than Krnl Use krnl.load/store instead + // of std.load/store and affine.load/store. krnl.load/store will be lowered to + // std.load/store and affine.load/store by `convert-krnl-to-affine` pass. + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + + target.addIllegalOp(); + + // TODO: add any other ops which are considered legal. + // Some operations can be marked as being still legal. + // Example: target.addLegalOp(); + + // For future: Handle the accelerator target. + // for (auto *accel : onnx_mlir::accel::Accelerator::getAccelerators()) + // accel->conversionTargetONNXToLinalg(target); + + // Now that the conversion target has been defined, we just need to provide + // the set of patterns that will lower the frontend operations. + RewritePatternSet patterns(&getContext()); + + // Convert types to legal types for the Krnl dialect. + LinalgTypeConverter linalgTypeConverter; + + // Define patterns. + populateONNXToLinalgConversionPattern( + patterns, linalgTypeConverter, &getContext()); + + // For future: Rewrite patterns for accelerators. + // for (auto *accel : onnx_mlir::accel::Accelerator::getAccelerators()) + // accel->rewritePatternONNXToLinalg(patterns, krnlTypeConverter, + // &getContext()); + + // With the target and rewrite patterns defined, we can now attempt the + // conversion. The conversion will signal failure if any of our `illegal` + // operations were not converted successfully. + if (failed(applyPartialConversion(module, target, std::move(patterns)))) { + signalPassFailure(); + } +} + +std::unique_ptr createLowerONNXToLinalgPass() { + return std::make_unique(); +} + +} // namespace onnx_mlir diff --git a/src/Conversion/ONNXToLinalg/Math/MatMul.cpp b/src/Conversion/ONNXToLinalg/Math/MatMul.cpp new file mode 100644 index 0000000000..a5120ac653 --- /dev/null +++ b/src/Conversion/ONNXToLinalg/Math/MatMul.cpp @@ -0,0 +1,63 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===----------------- Matmul.cpp - Lowering Matmul Op --------------------===// +// +// Copyright 2019-2022 The IBM Research Authors. +// +// ============================================================================= +// +// This file lowers the ONNX Matmul Operator to Linalg dialect. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Support/Debug.h" + +#include "src/Conversion/ONNXToLinalg/ONNXToLinalgCommon.hpp" +#include "src/Dialect/Mlir/DialectBuilder.hpp" + +#define DEBUG_TYPE "matmul" + +using namespace mlir; + +namespace onnx_mlir { + +struct ONNXMatMulOpLinalgLowering : public ConversionPattern { + ONNXMatMulOpLinalgLowering(TypeConverter &typeConverter, MLIRContext *ctx) + : ConversionPattern( + typeConverter, mlir::ONNXMatMulOp::getOperationName(), 1, ctx) {} + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + Location loc = op->getLoc(); + + auto outputType = op->getResult(0).getType().cast(); + + SmallVector dynamicDims; + if (outputType.isDynamicDim(0)) { + dynamicDims.emplace_back( + rewriter.create(loc, operands[0], 0)); + } + if (outputType.isDynamicDim(1)) { + dynamicDims.emplace_back( + rewriter.create(loc, operands[1], 1)); + } + + auto outV = rewriter.create( + loc, outputType.getShape(), outputType.getElementType(), dynamicDims); + + SmallVector outputs; + outputs.emplace_back(outV); + auto newOp = + rewriter.create(loc, outputType, operands, outputs); + rewriter.replaceOp(op, newOp.getResults()); + return success(); + } +}; // namespace onnx_mlir + +void populateLoweringONNXMatMulOpLinalgPattern(RewritePatternSet &patterns, + TypeConverter &typeConverter, MLIRContext *ctx) { + patterns.insert(typeConverter, ctx); +} + +} // namespace onnx_mlir diff --git a/src/Conversion/ONNXToLinalg/ONNXToLinalgCommon.cpp b/src/Conversion/ONNXToLinalg/ONNXToLinalgCommon.cpp new file mode 100644 index 0000000000..9834292046 --- /dev/null +++ b/src/Conversion/ONNXToLinalg/ONNXToLinalgCommon.cpp @@ -0,0 +1,82 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//====----- ONNXToLinalgCommon.cpp - ONNX dialects to Linalg lowering -----===// +// +// Copyright 2019-2022 The IBM Research Authors. +// +// ============================================================================= +// +// This file contains common code shared by the functions performing the +// lowering to the KRNL dialect. +// +//===----------------------------------------------------------------------===// + +#include "src/Conversion/ONNXToLinalg/ONNXToLinalgCommon.hpp" +#include "src/Accelerators/Accelerator.hpp" +#include "src/Dialect/Krnl/DialectBuilder.hpp" +#include "src/Dialect/Mlir/DialectBuilder.hpp" +#include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp" + +bool ONNXToLinalg_gEmitDealloc = false; + +using namespace mlir; + +namespace onnx_mlir { + +//===----------------------------------------------------------------------===// +// Type conversion from Onnx types to Linalg types. +//===----------------------------------------------------------------------===// + +LinalgTypeConverter::LinalgTypeConverter() { + // The order of type conversion is important: later ones are tried earlier. + addConversion([](Type type) { return type; }); + + addConversion([](ONNXStringType stringType) { + return krnl::StringType::get(stringType.getContext()); + }); + + addConversion([](TensorType tensorType) { + assert(tensorType.hasRank() && "expected only ranked shapes"); + return tensorType; + }); + + addConversion([](SeqType seqType) { + ShapedType seqElementType = seqType.getElementType(); + Type elementType = seqElementType.getElementType(); + Type seqElementConvertedType; + if (seqElementType.hasRank()) { + seqElementConvertedType = + MemRefType::get(seqElementType.getShape(), elementType); + } else { + seqElementConvertedType = UnrankedMemRefType::get(elementType, 0); + } + SmallVector dims; + dims.emplace_back(seqType.getLength()); + llvm::ArrayRef shape(dims.data(), dims.size()); + return MemRefType::get(shape, seqElementConvertedType); + }); + + addSourceMaterialization([&](OpBuilder &builder, Type resultType, + ValueRange inputs, + Location loc) -> Optional { + if (inputs.size() != 1) + return llvm::None; + + return builder.create(loc, resultType, inputs) + .getResult(0); + }); + + addTargetMaterialization([&](OpBuilder &builder, Type resultType, + ValueRange inputs, + Location loc) -> Optional { + if (inputs.size() != 1) + return llvm::None; + + return builder.create(loc, resultType, inputs) + .getResult(0); + }); +} + +} // namespace onnx_mlir diff --git a/src/Conversion/ONNXToLinalg/ONNXToLinalgCommon.hpp b/src/Conversion/ONNXToLinalg/ONNXToLinalgCommon.hpp new file mode 100644 index 0000000000..b9efa852df --- /dev/null +++ b/src/Conversion/ONNXToLinalg/ONNXToLinalgCommon.hpp @@ -0,0 +1,93 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//====------ ONNXToLinalgCommon.hpp - ONNX dialects to Krnl lowering -----===// +// +// Copyright 2019-2022 The IBM Research Authors. +// +// ============================================================================= +// +// This file contains common code shared by the functions performing the +// lowering to the KRNL dialect. +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Dialect/Func/Transforms/Passes.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/Sequence.h" +#include "llvm/ADT/TypeSwitch.h" + +#include "src/Dialect/Krnl/DialectBuilder.hpp" +#include "src/Dialect/Krnl/KrnlHelper.hpp" +#include "src/Dialect/Krnl/KrnlOps.hpp" +#include "src/Dialect/Mlir/IndexExpr.hpp" +#include "src/Dialect/ONNX/DialectBuilder.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" +#include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp" +#include "src/Pass/Passes.hpp" +#include "src/Support/KrnlSupport.hpp" +#include "src/Transform/ONNX/ConstPropHelper.hpp" + +// A global variable to indicate whether this pass will emit dealloc for +// allocated memrefs or not during the conversion of ONNX to Krnl. +extern bool ONNXToKrnl_gEmitDealloc; + +namespace onnx_mlir { + +//===----------------------------------------------------------------------===// +// Reuse the type converter for OnnxToKrnl now +// Type conversion from Onnx types to Krnl types: +// - from Tensor type to the Standard dialect MemRef type +// - from onnx.StringType to krnl.StringType +//===----------------------------------------------------------------------===// + +class LinalgTypeConverter : public mlir::TypeConverter { +public: + LinalgTypeConverter(); + + /// Return true if the inputs and outputs of the given function type are + /// legal. [Taken from MLIR and adapted to only check the legality of the + /// inputs. Once unranked results can be handled gracefully this + /// override needs to be removed in favour of the original MLIR one.] + bool isSignatureLegal(mlir::FunctionType funcType) { + return llvm::all_of(llvm::concat( + funcType.getInputs(), funcType.getResults()), + [this](mlir::Type type) { return isLegal(type); }); + } + + /// Return true if the operands/results of call have a legal type. + bool isSignatureLegal(mlir::func::CallOp call) { + auto f = [this](mlir::Type type) { return isLegal(type); }; + return llvm::all_of(call.getOperandTypes(), f) && + llvm::all_of(call.getResultTypes(), f); + } +}; + +//===----------------------------------------------------------------------===// +// Functions to add lowering patterns for frontend operations. +//===----------------------------------------------------------------------===// + +// For all ONNX operations. +void populateONNXToLinalgConversionPattern(mlir::RewritePatternSet &, + mlir::TypeConverter &, mlir::MLIRContext *, bool enableTiling); + +// `Math` directory methods: +void populateLoweringONNXMatMulOpLinalgPattern( + mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *); + +} // namespace onnx_mlir diff --git a/src/InitOMPasses.hpp b/src/InitOMPasses.hpp index b8c4d9168c..6cb1cd46fc 100644 --- a/src/InitOMPasses.hpp +++ b/src/InitOMPasses.hpp @@ -111,6 +111,10 @@ void initOMPasses(int optLevel) { mlir::registerPass( []() -> std::unique_ptr { return createLowerToMhloPass(); }); #endif + + mlir::registerPass([]() -> std::unique_ptr { + return createLowerONNXToLinalgPass(); + }); } } // namespace onnx_mlir diff --git a/src/Pass/Passes.hpp b/src/Pass/Passes.hpp index d4ab6444b4..f9aa7795b3 100644 --- a/src/Pass/Passes.hpp +++ b/src/Pass/Passes.hpp @@ -69,6 +69,8 @@ std::unique_ptr createLowerToKrnlPass( std::unique_ptr createLowerToMhloPass(); #endif +std::unique_ptr createLowerONNXToLinalgPass(); + /// Pass for lowering krnl.dim operations to standard dialect. std::unique_ptr createDisconnectKrnlDimFromAllocPass(); diff --git a/src/Tools/onnx-mlir-opt/onnx-mlir-opt.cpp b/src/Tools/onnx-mlir-opt/onnx-mlir-opt.cpp index 54e6ba2415..ef231a2fd3 100644 --- a/src/Tools/onnx-mlir-opt/onnx-mlir-opt.cpp +++ b/src/Tools/onnx-mlir-opt/onnx-mlir-opt.cpp @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -135,6 +136,7 @@ int main(int argc, char **argv) { registry.insert(); registry.insert(); registry.insert(); + registry.insert(); // Initialize accelerators if they exist. onnx_mlir::accel::initAccelerators(maccel);