diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/ConversionTarget.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/ConversionTarget.cpp index 02785f5855af..c9fd8df92ac3 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/ConversionTarget.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/ConversionTarget.cpp @@ -9,8 +9,6 @@ #include "iree/compiler/Dialect/HAL/Conversion/TypeConverter.h" #include "iree/compiler/Dialect/HAL/IR/HALOps.h" #include "iree/compiler/Dialect/Util/IR/UtilTypes.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" @@ -43,39 +41,5 @@ HALConversionTarget::HALConversionTarget(MLIRContext *context, }); } -// static -LogicalResult HALConversionTarget::applyDefaultBufferRewrite( - Operation *srcOp, ValueRange operands, StringRef dstOpName, - TypeConverter &typeConverter, ConversionPatternRewriter &rewriter) { - OperationState state{srcOp->getLoc(), dstOpName}; - state.addAttributes(srcOp->getAttrs()); - - for (auto [srcOperand, dstOperand] : - llvm::zip_equal(srcOp->getOperands(), operands)) { - // Check that any type that should have been mapped to buffer view was. - // This is just to catch conflicts in type conversions that may sneak in - // during development. - assert( - (!HALTypeConverter::shouldConvertToBufferView(srcOperand.getType()) || - dstOperand.getType().isa()) && - "expect that tensors have been mapped to buffer views"); - state.addOperands({dstOperand}); - } - for (auto resultType : srcOp->getResultTypes()) { - if (HALTypeConverter::shouldConvertToBufferView(resultType)) { - state.addTypes(IREE::HAL::BufferViewType::get(rewriter.getContext())); - } else { - // Normal pass-through result. - if (failed(typeConverter.convertType(resultType, state.types))) { - return failure(); - } - } - } - - auto *dstOp = rewriter.create(state); - rewriter.replaceOp(srcOp, dstOp->getResults()); - return success(); -} - } // namespace iree_compiler } // namespace mlir diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/ConversionTarget.h b/compiler/src/iree/compiler/Dialect/HAL/Conversion/ConversionTarget.h index b41dd1f1a5e6..fd3d4898682d 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/ConversionTarget.h +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/ConversionTarget.h @@ -8,7 +8,6 @@ #define IREE_COMPILER_DIALECT_HAL_CONVERSION_CONVERSIONTARGET_H_ #include "iree/compiler/Dialect/HAL/IR/HALTypes.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/MLIRContext.h" @@ -22,47 +21,6 @@ namespace iree_compiler { class HALConversionTarget : public ConversionTarget { public: HALConversionTarget(MLIRContext *context, TypeConverter &typeConverter); - - // Attempts to rewrite an op that may use tensor values into an op using HAL - // buffers. See HALOpConversion for more information. - static LogicalResult - applyDefaultBufferRewrite(Operation *srcOp, ValueRange operands, - StringRef dstOpName, TypeConverter &typeConverter, - ConversionPatternRewriter &rewriter); -}; - -// HAL tensor-to-buffer conversion utility. -// This can be used by dialects to model custom op conversion from a dialect -// that uses the MLIR tensor type to the IREE HAL buffer type. At this point -// during conversion the source values will be TensorType and the target values -// will be IREE::HAL::BufferTypes. Any static information available about the -// tensor (such as static dimensions, element type, layout, etc) are extracted -// here and lowered as expanded values. -// -// The ABI is currently very basic and will change with the introduction of more -// dynamic shape logic. -// -// Source: -// my.tensor_op(%arg0 : tensor<2x4xf32>) -// Target: -// %arg0_view = hal.buffer_view.create %arg0, ... -// my.buffer_op(%arg0_view : !hal.buffer_view) -template -class HALOpConversion : public OpConversionPattern { -public: - HALOpConversion(MLIRContext *context, TypeConverter &typeConverter) - : OpConversionPattern(context), typeConverter(typeConverter) {} - - LogicalResult - matchAndRewrite(SRC srcOp, typename SRC::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - return HALConversionTarget::applyDefaultBufferRewrite( - srcOp, adaptor.getOperands(), DST::getOperationName(), typeConverter, - rewriter); - } - -protected: - TypeConverter &typeConverter; }; } // namespace iree_compiler diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp index 8d31a0282297..16e1b468c7eb 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp @@ -14,6 +14,7 @@ #include "iree/compiler/Dialect/Stream/IR/StreamOps.h" #include "iree/compiler/Dialect/Stream/IR/StreamTypes.h" #include "iree/compiler/Dialect/Util/IR/UtilOps.h" +#include "llvm/Support/CommandLine.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SCF/IR/SCF.h" @@ -23,6 +24,14 @@ namespace mlir { namespace iree_compiler { +static llvm::cl::opt clExternalResourcesMappable( + "iree-stream-external-resources-mappable", + llvm::cl::desc("Allocates external resources as host-visible and mappable. " + "This can degrade performance and introduce allocation " + "overhead and staging buffers for readback on the host " + "should be managed by the calling application instead."), + llvm::cl::init(false)); + namespace { static Value lookupDeviceFor(Operation *op, OpBuilder &builder) { @@ -263,17 +272,21 @@ deriveAllowedResourceBufferBits(Location loc, default: break; case IREE::Stream::Lifetime::External: - // #yolo; these come from/go to outside the program. - // Today we assume they are device-local|host-visible just for - // practical purposes but that does not have to be true. We really - // want this to be something we analyze and handle on the edges - // (transferring devices/etc if needed). - memoryTypes = memoryTypes | IREE::HAL::MemoryTypeBitfield::DeviceLocal | - IREE::HAL::MemoryTypeBitfield::HostVisible; - // NOTE: we may not map it but users may after they get them back. - // Another reason we should annotate this - having a buffer be - // mappable is potentially expensive (may get a 2nd copy in memory!). - bufferUsage = bufferUsage | IREE::HAL::BufferUsageBitfield::Mapping; + if (clExternalResourcesMappable) { + // #yolo; these come from/go to outside the program. + // Today we assume they are device-local|host-visible just for + // practical purposes but that does not have to be true. We really + // want this to be something we analyze and handle on the edges + // (transferring devices/etc if needed). + memoryTypes = memoryTypes | IREE::HAL::MemoryTypeBitfield::DeviceLocal | + IREE::HAL::MemoryTypeBitfield::HostVisible; + // NOTE: we may not map it but users may after they get them back. + // Another reason we should annotate this - having a buffer be + // mappable is potentially expensive (may get a 2nd copy in memory!). + bufferUsage = bufferUsage | IREE::HAL::BufferUsageBitfield::Mapping; + } else { + memoryTypes = memoryTypes | IREE::HAL::MemoryTypeBitfield::DeviceLocal; + } break; } return success(); diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir index 2aca6c590a14..d45978d6c273 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir @@ -80,8 +80,8 @@ module attributes {hal.device.targets = [#device_target_cpu]} { %arg1_resource = stream.tensor.import %arg1 : !hal.buffer_view -> tensor<4xf32> in !stream.resource{%c16} // CHECK: %[[RESULT_BUFFER:.+]] = hal.allocator.allocate<%[[ALLOCATOR]] : !hal.allocator> - // CHECK-SAME: type("HostVisible|DeviceVisible|DeviceLocal") - // CHECK-SAME: usage("{{.+}}Transfer{{.+}}Dispatch{{.+}}Mapping{{.+}}") + // CHECK-SAME: type("DeviceVisible|DeviceLocal") + // CHECK-SAME: usage("{{.+}}Transfer{{.+}}Dispatch{{.+}}") // CHECK-SAME: : !hal.buffer{%c16} %result_resource = stream.resource.alloc uninitialized : !stream.resource{%c16} diff --git a/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp b/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp index e7e2bedbf4b6..7254cee0bf99 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp @@ -307,7 +307,27 @@ class ValueResourceUsage : public AbstractResourceUsage { getState() ^= targetUsage.getState(); }) .Case([&](IREE::Stream::TensorImportOp op) { - removeAssumedBits(NOT_MUTATED | NOT_EXTERNAL); + auto targetType = + llvm::cast(op.getResult().getType()); + switch (targetType.getLifetime()) { + default: + case IREE::Stream::Lifetime::External: + removeAssumedBits(NOT_MUTATED | NOT_EXTERNAL); + break; + case IREE::Stream::Lifetime::Staging: + removeAssumedBits(NOT_MUTATED | NOT_STAGING_READ | + NOT_STAGING_WRITE); + break; + case IREE::Stream::Lifetime::Transient: + removeAssumedBits(NOT_MUTATED); + break; + case IREE::Stream::Lifetime::Variable: + removeAssumedBits(NOT_MUTATED | NOT_GLOBAL_READ | NOT_GLOBAL_WRITE); + break; + case IREE::Stream::Lifetime::Constant: + removeAssumedBits(NOT_CONSTANT); + break; + } auto &resultUsage = solver.getElementFor( *this, Position::forValue(op.getResult()), DFX::Resolution::REQUIRED); @@ -497,7 +517,6 @@ class ValueResourceUsage : public AbstractResourceUsage { *this, Position::forValue(op->getOperand(operandIdx)), DFX::Resolution::REQUIRED); getState() ^= operandUsage.getState(); - auto &beforeUsage = solver.getElementFor( *this, Position::forValue(op.getBeforeBody()->getArgument(operandIdx)), @@ -510,13 +529,11 @@ class ValueResourceUsage : public AbstractResourceUsage { *this, Position::forValue(op->getOperand(operandIdx)), DFX::Resolution::REQUIRED); getState() ^= operandUsage.getState(); - auto &parentUsage = solver.getElementFor( *this, Position::forValue(op->getParentOp()->getResult(operandIdx - 1)), DFX::Resolution::REQUIRED); getState() ^= parentUsage.getState(); - if (auto whileOp = dyn_cast_or_null(op->getParentOp())) { auto value = Position::forValue( @@ -532,14 +549,12 @@ class ValueResourceUsage : public AbstractResourceUsage { *this, Position::forValue(op->getOperand(operandIdx)), DFX::Resolution::REQUIRED); getState() ^= operandUsage.getState(); - auto &parentUsage = solver.getElementFor( *this, Position::forValue(op->getParentOp()->getResult(operandIdx)), DFX::Resolution::REQUIRED); getState() ^= parentUsage.getState(); } - if (auto whileOp = dyn_cast_or_null(op->getParentOp())) { auto value = @@ -589,7 +604,33 @@ class ValueResourceUsage : public AbstractResourceUsage { removeAssumedBits(NOT_INDIRECT | NOT_GLOBAL_WRITE); }) .Case([&](IREE::Stream::TensorExportOp op) { - removeAssumedBits(NOT_MUTATED | NOT_EXTERNAL); + auto sourceType = + llvm::cast(op.getSource().getType()); + switch (sourceType.getLifetime()) { + default: + case IREE::Stream::Lifetime::External: + removeAssumedBits(NOT_MUTATED | NOT_EXTERNAL); + break; + case IREE::Stream::Lifetime::Staging: + removeAssumedBits(NOT_MUTATED | NOT_STAGING_READ | + NOT_STAGING_WRITE | NOT_TRANSFER_READ | + NOT_TRANSFER_WRITE); + break; + case IREE::Stream::Lifetime::Transient: + removeAssumedBits(NOT_MUTATED | NOT_TRANSFER_READ | + NOT_TRANSFER_WRITE | NOT_DISPATCH_READ | + NOT_DISPATCH_WRITE); + break; + case IREE::Stream::Lifetime::Variable: + removeAssumedBits(NOT_MUTATED | NOT_TRANSFER_READ | + NOT_TRANSFER_WRITE | NOT_DISPATCH_READ | + NOT_DISPATCH_WRITE); + break; + case IREE::Stream::Lifetime::Constant: + removeAssumedBits(NOT_CONSTANT | NOT_TRANSFER_READ | + NOT_DISPATCH_READ); + break; + } }) .Case([&](IREE::Stream::AsyncCloneOp op) { removeAssumedBits(NOT_TRANSFER_READ); diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp index 0ed495546185..b2a6abf26f66 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/Patterns.cpp @@ -249,12 +249,12 @@ struct ConvertTensorTraceOp llvm::zip_equal(op.getOperands(), adaptor.getOperands())) { auto source = consumeTensorOperand(op.getLoc(), resourceOperand, rewriter); - auto externalType = rewriter.getType( - IREE::Stream::Lifetime::External); + auto stagingType = rewriter.getType( + IREE::Stream::Lifetime::Staging); auto exportSource = resourceOperand; - if (source.resource.getType() != externalType) { + if (source.resource.getType() != stagingType) { exportSource = rewriter.create( - op.getLoc(), externalType, source.resource, source.resourceSize, + op.getLoc(), stagingType, source.resource, source.resourceSize, source.resourceSize, /*source_affinity=*/getAffinityFor(op), /*result_affinity=*/nullptr); diff --git a/compiler/src/iree/compiler/Modules/Check/Conversion/BUILD.bazel b/compiler/src/iree/compiler/Modules/Check/Conversion/BUILD.bazel index 0644bdaac80f..4dcde8c579a9 100644 --- a/compiler/src/iree/compiler/Modules/Check/Conversion/BUILD.bazel +++ b/compiler/src/iree/compiler/Modules/Check/Conversion/BUILD.bazel @@ -22,6 +22,7 @@ iree_compiler_cc_library( ], deps = [ "//compiler/src/iree/compiler/Dialect/HAL/Conversion", + "//compiler/src/iree/compiler/Dialect/HAL/IR", "//compiler/src/iree/compiler/Dialect/VM/Conversion", "//compiler/src/iree/compiler/Modules/Check/IR", "@llvm-project//mlir:Pass", diff --git a/compiler/src/iree/compiler/Modules/Check/Conversion/CMakeLists.txt b/compiler/src/iree/compiler/Modules/Check/Conversion/CMakeLists.txt index 582a6ada281e..c55d7713656e 100644 --- a/compiler/src/iree/compiler/Modules/Check/Conversion/CMakeLists.txt +++ b/compiler/src/iree/compiler/Modules/Check/Conversion/CMakeLists.txt @@ -21,6 +21,7 @@ iree_cc_library( MLIRPass MLIRTransforms iree::compiler::Dialect::HAL::Conversion + iree::compiler::Dialect::HAL::IR iree::compiler::Dialect::VM::Conversion iree::compiler::Modules::Check::IR PUBLIC diff --git a/compiler/src/iree/compiler/Modules/Check/Conversion/ConversionPatterns.cpp b/compiler/src/iree/compiler/Modules/Check/Conversion/ConversionPatterns.cpp index 82da66bd7200..10cdbb35684a 100644 --- a/compiler/src/iree/compiler/Modules/Check/Conversion/ConversionPatterns.cpp +++ b/compiler/src/iree/compiler/Modules/Check/Conversion/ConversionPatterns.cpp @@ -7,6 +7,8 @@ #include "iree/compiler/Modules/Check/Conversion/ConversionPatterns.h" #include "iree/compiler/Dialect/HAL/Conversion/ConversionTarget.h" +#include "iree/compiler/Dialect/HAL/Conversion/TypeConverter.h" +#include "iree/compiler/Dialect/HAL/IR/HALOps.h" #include "iree/compiler/Dialect/VM/Conversion/ImportUtils.h" #include "iree/compiler/Modules/Check/IR/CheckOps.h" #include "mlir/Pass/Pass.h" @@ -60,17 +62,90 @@ void populateCheckToVMPatterns(MLIRContext *context, SymbolTable &importSymbols, context, importSymbols, typeConverter, "check.expect_almost_eq"); } +// Attempts to rewrite an op that may use tensor values into an op using HAL +// buffers. +static LogicalResult applyDefaultCheckBufferRewrite( + Operation *srcOp, ValueRange operands, StringRef dstOpName, + TypeConverter &typeConverter, ConversionPatternRewriter &rewriter) { + OperationState state{srcOp->getLoc(), dstOpName}; + state.addAttributes(srcOp->getAttrs()); + + // Add device argument. + Value device = rewriter.create(srcOp->getLoc()); + state.addOperands({device}); + + for (auto [srcOperand, dstOperand] : + llvm::zip_equal(srcOp->getOperands(), operands)) { + // Check that any type that should have been mapped to buffer view was. + // This is just to catch conflicts in type conversions that may sneak in + // during development. + assert( + (!HALTypeConverter::shouldConvertToBufferView(srcOperand.getType()) || + dstOperand.getType().isa()) && + "expect that tensors have been mapped to buffer views"); + state.addOperands({dstOperand}); + } + for (auto resultType : srcOp->getResultTypes()) { + if (HALTypeConverter::shouldConvertToBufferView(resultType)) { + state.addTypes(IREE::HAL::BufferViewType::get(rewriter.getContext())); + } else { + // Normal pass-through result. + if (failed(typeConverter.convertType(resultType, state.types))) { + return failure(); + } + } + } + + auto *dstOp = rewriter.create(state); + rewriter.replaceOp(srcOp, dstOp->getResults()); + return success(); +} + +// HAL tensor-to-buffer conversion utility. +// This can be used by dialects to model custom op conversion from a dialect +// that uses the MLIR tensor type to the IREE HAL buffer type. At this point +// during conversion the source values will be TensorType and the target values +// will be IREE::HAL::BufferTypes. Any static information available about the +// tensor (such as static dimensions, element type, layout, etc) are extracted +// here and lowered as expanded values. +// +// The ABI is currently very basic and will change with the introduction of more +// dynamic shape logic. +// +// Source: +// my.tensor_op(%arg0 : tensor<2x4xf32>) +// Target: +// %arg0_view = hal.buffer_view.create %arg0, ... +// my.buffer_op(%arg0_view : !hal.buffer_view) +template +class HALCheckOpConversion : public OpConversionPattern { +public: + HALCheckOpConversion(MLIRContext *context, TypeConverter &typeConverter) + : OpConversionPattern(context), typeConverter(typeConverter) {} + + LogicalResult + matchAndRewrite(SRC srcOp, typename SRC::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + return applyDefaultCheckBufferRewrite(srcOp, adaptor.getOperands(), + DST::getOperationName(), + typeConverter, rewriter); + } + +protected: + TypeConverter &typeConverter; +}; + void populateCheckToHALPatterns(MLIRContext *context, RewritePatternSet &patterns, TypeConverter &typeConverter) { // The same op handles both tensors and buffer views. - patterns - .insert, - HALOpConversion, - HALOpConversion>(context, - typeConverter); + patterns.insert< + HALCheckOpConversion, + HALCheckOpConversion, + HALCheckOpConversion>(context, + typeConverter); } } // namespace Check diff --git a/compiler/src/iree/compiler/Modules/Check/IR/BUILD.bazel b/compiler/src/iree/compiler/Modules/Check/IR/BUILD.bazel index dff0294ea0b4..e55f3d25bd89 100644 --- a/compiler/src/iree/compiler/Modules/Check/IR/BUILD.bazel +++ b/compiler/src/iree/compiler/Modules/Check/IR/BUILD.bazel @@ -57,6 +57,7 @@ iree_compiler_cc_library( ":IR", ":check_ops_gen", "//compiler/src/iree/compiler/Dialect/HAL/Conversion", + "//compiler/src/iree/compiler/Dialect/HAL/IR:HALDialect", "//compiler/src/iree/compiler/Dialect/VM/Conversion", "//compiler/src/iree/compiler/Modules/Check:check_imports", "//compiler/src/iree/compiler/Modules/Check/Conversion", diff --git a/compiler/src/iree/compiler/Modules/Check/IR/CMakeLists.txt b/compiler/src/iree/compiler/Modules/Check/IR/CMakeLists.txt index c3a85740d27e..b0928ce62b9f 100644 --- a/compiler/src/iree/compiler/Modules/Check/IR/CMakeLists.txt +++ b/compiler/src/iree/compiler/Modules/Check/IR/CMakeLists.txt @@ -42,6 +42,7 @@ iree_cc_library( MLIRParser MLIRTransforms iree::compiler::Dialect::HAL::Conversion + iree::compiler::Dialect::HAL::IR::HALDialect iree::compiler::Dialect::VM::Conversion iree::compiler::Modules::Check::Conversion iree::compiler::Modules::Check::check_imports diff --git a/compiler/src/iree/compiler/Modules/Check/IR/CheckDialect.cpp b/compiler/src/iree/compiler/Modules/Check/IR/CheckDialect.cpp index dbdb4e19f837..554baa6084e9 100644 --- a/compiler/src/iree/compiler/Modules/Check/IR/CheckDialect.cpp +++ b/compiler/src/iree/compiler/Modules/Check/IR/CheckDialect.cpp @@ -7,6 +7,7 @@ #include "iree/compiler/Modules/Check/IR/CheckDialect.h" #include "iree/compiler/Dialect/HAL/Conversion/ConversionDialectInterface.h" +#include "iree/compiler/Dialect/HAL/IR/HALDialect.h" #include "iree/compiler/Dialect/VM/Conversion/ConversionDialectInterface.h" #include "iree/compiler/Modules/Check/Conversion/ConversionPatterns.h" #include "iree/compiler/Modules/Check/IR/CheckOps.h" @@ -57,6 +58,8 @@ class CheckToHalConversionInterface : public HALConversionDialectInterface { CheckDialect::CheckDialect(MLIRContext *context) : Dialect(getDialectNamespace(), context, TypeID::get()) { + context->loadDialect(); + addInterfaces(); addInterfaces(); #define GET_OP_LIST diff --git a/compiler/src/iree/compiler/Modules/Check/IR/CheckOps.cpp b/compiler/src/iree/compiler/Modules/Check/IR/CheckOps.cpp index a651bfedcad6..69cfda7f104b 100644 --- a/compiler/src/iree/compiler/Modules/Check/IR/CheckOps.cpp +++ b/compiler/src/iree/compiler/Modules/Check/IR/CheckOps.cpp @@ -24,7 +24,7 @@ struct ExpandAttributeToConst : public OpRewritePattern { LogicalResult matchAndRewrite(SrcOp op, PatternRewriter &rewriter) const override { auto rhs = rewriter.create(op.getLoc(), op.getValue()); - rewriter.replaceOpWithNewOp(op, op.getLhs(), rhs); + rewriter.replaceOpWithNewOp(op, op.getDevice(), op.getLhs(), rhs); return success(); } }; diff --git a/compiler/src/iree/compiler/Modules/Check/IR/CheckOps.td b/compiler/src/iree/compiler/Modules/Check/IR/CheckOps.td index 9d0b1b335c6c..59c2236c2ada 100644 --- a/compiler/src/iree/compiler/Modules/Check/IR/CheckOps.td +++ b/compiler/src/iree/compiler/Modules/Check/IR/CheckOps.td @@ -36,7 +36,6 @@ def CHECK_ExpectTrueOp : Op { let assemblyFormat = "`(` $operand `)` attr-dict `:` type($operand)"; } - def CHECK_ExpectFalseOp : Op { let summary = [{Checks that the operand is false}]; let description = [{ @@ -64,18 +63,24 @@ def CHECK_ExpectAllTrueOp : Op { Issues a non-fatal failure if the verification fails. ```mlir - check.expect_all_true(%arg0) : !hal.buffer_view + check.expect_all_true<%device>(%arg0) : !hal.buffer_view check.expect_all_true(%arg1) : tensor<2x2xi32> ``` }]; - let arguments = - (ins AnyTypeOf<[HAL_BufferView, TensorOf<[AnySignlessInteger]>]>:$operand); + let arguments = (ins + Optional:$device, + AnyTypeOf<[HAL_BufferView, TensorOf<[AnySignlessInteger]>]>:$operand + ); - let assemblyFormat = "`(` $operand `)` attr-dict `:` type($operand)"; + let assemblyFormat = [{ + (`` `<` $device^ `>`)? + `` `(` $operand `)` attr-dict `:` type($operand) + }]; } -def CHECK_ExpectEqOp : Op { +def CHECK_ExpectEqOp : + Op]> { let summary = [{Checks that the tensor or buffer view operands are equal}]; let description = [{ Verifies that the operands are exactly equal. @@ -88,11 +93,15 @@ def CHECK_ExpectEqOp : Op { }]; let arguments = (ins - AnyTypeOf<[HAL_BufferView, AnyTensor]>:$lhs, - AnyTypeOf<[HAL_BufferView, AnyTensor]>:$rhs + Optional:$device, + AnyTypeOf<[HAL_BufferView, AnyTensor]>:$lhs, + AnyTypeOf<[HAL_BufferView, AnyTensor]>:$rhs ); - let assemblyFormat = "`(` $lhs `,` $rhs `)` attr-dict `:` type($lhs)"; + let assemblyFormat = [{ + (`` `<` $device^ `>`)? + `` `(` $lhs `,` $rhs `)` attr-dict `:` type($lhs) + }]; } def CHECK_ExpectEqConstOp : @@ -111,17 +120,21 @@ def CHECK_ExpectEqConstOp : }]; let arguments = (ins + Optional:$device, AnyTensor:$lhs, ElementsAttr:$value ); let hasCanonicalizer = 1; - let assemblyFormat = "`(` $lhs `,` $value `)` attr-dict `:` type($lhs)"; + let assemblyFormat = [{ + (`` `<` $device^ `>`)? + `` `(` $lhs `,` $value `)` attr-dict `:` type($lhs) + }]; } def CHECK_ExpectAlmostEqOp : - Op { + Op]> { let summary = [{Checks that the operands are almost equal}]; let description = [{ Verifies that the buffer view or tensor operands with float elements are @@ -135,11 +148,15 @@ def CHECK_ExpectAlmostEqOp : }]; let arguments = (ins - AnyTypeOf<[HAL_BufferView, TensorOf<[AnyFloat]>]>:$lhs, - AnyTypeOf<[HAL_BufferView, TensorOf<[AnyFloat]>]>:$rhs + Optional:$device, + AnyTypeOf<[HAL_BufferView, TensorOf<[AnyFloat]>]>:$lhs, + AnyTypeOf<[HAL_BufferView, TensorOf<[AnyFloat]>]>:$rhs ); - let assemblyFormat = "`(` $lhs `,` $rhs `)` attr-dict `:` type($lhs)"; + let assemblyFormat = [{ + (`` `<` $device^ `>`)? + `` `(` $lhs `,` $rhs `)` attr-dict `:` type($lhs) + }]; } def CHECK_ExpectAlmostEqConstOp : @@ -160,13 +177,17 @@ def CHECK_ExpectAlmostEqConstOp : }]; let arguments = (ins + Optional:$device, TensorOf<[AnyFloat]>:$lhs, ElementsAttr:$value ); let hasCanonicalizer = 1; - let assemblyFormat = "`(` $lhs `,` $value `)` attr-dict `:` type($lhs)"; + let assemblyFormat = [{ + (`` `<` $device^ `>`)? + `` `(` $lhs `,` $value `)` attr-dict `:` type($lhs) + }]; } #endif // IREE_MODULES_CHECK_DIALECT_CHECK_OPS diff --git a/compiler/src/iree/compiler/Modules/Check/check.imports.mlir b/compiler/src/iree/compiler/Modules/Check/check.imports.mlir index 67bae93437b7..63b9d72392b6 100644 --- a/compiler/src/iree/compiler/Modules/Check/check.imports.mlir +++ b/compiler/src/iree/compiler/Modules/Check/check.imports.mlir @@ -15,15 +15,18 @@ vm.import private optional @expect_false( ) vm.import private optional @expect_all_true( + %device : !vm.ref, %operand : !vm.ref, ) vm.import private optional @expect_eq( + %device : !vm.ref, %lhs : !vm.ref, %rhs : !vm.ref ) vm.import private optional @expect_almost_eq( + %device : !vm.ref, %lhs : !vm.ref, %rhs : !vm.ref ) diff --git a/experimental/cuda2/cuda_device.c b/experimental/cuda2/cuda_device.c index 370b693ea0cf..404dd357a7cd 100644 --- a/experimental/cuda2/cuda_device.c +++ b/experimental/cuda2/cuda_device.c @@ -625,7 +625,7 @@ static iree_status_t iree_hal_cuda2_device_queue_alloca( // allocator is set on the device. iree_status_t status = iree_ok_status(); if (device->supports_memory_pools && - !iree_any_bit_set(params.access, IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) { + !iree_all_bits_set(params.type, IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) { status = iree_hal_cuda2_memory_pools_alloca( &device->memory_pools, device->dispatch_cu_stream, pool, params, allocation_size, out_buffer); diff --git a/runtime/src/iree/hal/drivers/cuda/cuda_device.c b/runtime/src/iree/hal/drivers/cuda/cuda_device.c index 2c5a91b6caf4..ea38514abdc7 100644 --- a/runtime/src/iree/hal/drivers/cuda/cuda_device.c +++ b/runtime/src/iree/hal/drivers/cuda/cuda_device.c @@ -570,7 +570,7 @@ static iree_status_t iree_hal_cuda_device_queue_alloca( // allocator is set on the device. iree_status_t status = iree_ok_status(); if (device->supports_memory_pools && - !iree_any_bit_set(params.access, IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) { + !iree_all_bits_set(params.type, IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) { status = iree_hal_cuda_memory_pools_alloca(&device->memory_pools, device->stream, pool, params, allocation_size, out_buffer); diff --git a/runtime/src/iree/modules/check/check_test.cc b/runtime/src/iree/modules/check/check_test.cc index 67f194718bc5..7623fb5df562 100644 --- a/runtime/src/iree/modules/check/check_test.cc +++ b/runtime/src/iree/modules/check/check_test.cc @@ -197,6 +197,9 @@ class CheckTest : public ::testing::Test { IREE_RETURN_IF_ERROR( iree_vm_list_create(iree_vm_make_undefined_type_def(), args.size(), iree_allocator_system(), &inputs_)); + iree_vm_ref_t device_ref = iree_hal_device_retain_ref(device_); + IREE_RETURN_IF_ERROR( + iree_vm_list_push_ref_move(inputs_.get(), &device_ref)); for (auto& arg : args) { iree_vm_ref_t arg_ref = iree_hal_buffer_view_move_ref(arg.get()); IREE_RETURN_IF_ERROR(iree_vm_list_push_ref_move(inputs_.get(), &arg_ref)); diff --git a/runtime/src/iree/modules/check/module.cc b/runtime/src/iree/modules/check/module.cc index b417eef5ac8c..edbb9fe99476 100644 --- a/runtime/src/iree/modules/check/module.cc +++ b/runtime/src/iree/modules/check/module.cc @@ -155,6 +155,100 @@ Status ExpectAllTrue(iree_byte_span_t bytes, "unsupported element type %s", element_type_str); } +static StatusOr>> +TransferBuffersToHost( + iree_hal_device_t* device, + const iree::span> source_views) { + IREE_TRACE_SCOPE(); + + // If all buffers are already host-accessible we can skip the transfer. + std::vector> target_views; + bool requires_transfer = false; + for (auto& source_view : source_views) { + iree_hal_buffer_t* buffer = iree_hal_buffer_view_buffer(source_view.get()); + if (!iree_all_bits_set(iree_hal_buffer_memory_type(buffer), + IREE_HAL_MEMORY_TYPE_HOST_VISIBLE) || + !iree_all_bits_set(iree_hal_buffer_allowed_usage(buffer), + IREE_HAL_BUFFER_USAGE_MAPPING_SCOPED)) { + requires_transfer = true; + } + } + if (!requires_transfer) { + for (auto& source_view : source_views) target_views.push_back(source_view); + return std::move(target_views); + } + + vm::ref command_buffer; + IREE_RETURN_IF_ERROR(iree_hal_command_buffer_create( + device, + IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT | + IREE_HAL_COMMAND_BUFFER_MODE_ALLOW_INLINE_EXECUTION, + IREE_HAL_COMMAND_CATEGORY_TRANSFER, IREE_HAL_QUEUE_AFFINITY_ANY, 0, + &command_buffer)); + IREE_RETURN_IF_ERROR(iree_hal_command_buffer_begin(command_buffer.get())); + + iree_hal_buffer_params_t target_params = { + /*.usage=*/IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING, + /*.access=*/IREE_HAL_MEMORY_ACCESS_ALL, + /*.type=*/ + IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE, + /*.queue_affinity=*/IREE_HAL_QUEUE_AFFINITY_ANY, + /*.min_alignment=*/0, + }; + for (size_t i = 0; i < source_views.size(); ++i) { + iree_hal_buffer_t* source_buffer = + iree_hal_buffer_view_buffer(source_views[i].get()); + iree_device_size_t buffer_length = + iree_hal_buffer_byte_length(source_buffer); + vm::ref target_buffer; + IREE_RETURN_IF_ERROR(iree_hal_allocator_allocate_buffer( + iree_hal_device_allocator(device), target_params, buffer_length, + &target_buffer)); + IREE_RETURN_IF_ERROR(iree_hal_command_buffer_copy_buffer( + command_buffer.get(), source_buffer, 0, target_buffer.get(), 0, + buffer_length)); + vm::ref target_view; + IREE_RETURN_IF_ERROR(iree_hal_buffer_view_create_like( + target_buffer.get(), source_views[i].get(), + iree_hal_device_host_allocator(device), &target_view)); + target_views.push_back(std::move(target_view)); + } + + IREE_RETURN_IF_ERROR(iree_hal_command_buffer_end(command_buffer.get())); + vm::ref semaphore; + IREE_RETURN_IF_ERROR(iree_hal_semaphore_create(device, 0ull, &semaphore)); + vm::ref fence; + IREE_RETURN_IF_ERROR(iree_hal_fence_create_at( + semaphore.get(), 1ull, iree_hal_device_host_allocator(device), &fence)); + IREE_RETURN_IF_ERROR(iree_hal_device_queue_execute( + device, IREE_HAL_QUEUE_AFFINITY_ANY, iree_hal_semaphore_list_empty(), + iree_hal_fence_semaphore_list(fence.get()), 1, &command_buffer)); + IREE_RETURN_IF_ERROR( + iree_hal_fence_wait(fence.get(), iree_infinite_timeout())); + return std::move(target_views); +} + +static Status TransferToHost(iree_hal_device_t* device, + vm::ref& buffer_view) { + IREE_TRACE_SCOPE(); + IREE_ASSIGN_OR_RETURN(auto target_views, + TransferBuffersToHost(device, {buffer_view})); + buffer_view = std::move(target_views[0]); + return OkStatus(); +} + +static Status TransferToHost(iree_hal_device_t* device, + vm::ref& buffer_view_a, + vm::ref& buffer_view_b) { + IREE_TRACE_SCOPE(); + IREE_ASSIGN_OR_RETURN( + auto target_views, + TransferBuffersToHost(device, {buffer_view_a, buffer_view_b})); + buffer_view_a = std::move(target_views[0]); + buffer_view_b = std::move(target_views[1]); + return OkStatus(); +} + // Per-context module state. // This can contain "globals" and other arbitrary state. // @@ -177,7 +271,9 @@ class CheckModuleState final { return OkStatus(); } - Status ExpectAllTrue(vm::ref operand) { + Status ExpectAllTrue(vm::ref device, + vm::ref operand) { + IREE_RETURN_IF_ERROR(TransferToHost(device.get(), operand)); auto* view = operand.get(); iree_hal_element_type_t element_type = iree_hal_buffer_view_element_type(view); @@ -193,8 +289,10 @@ class CheckModuleState final { return OkStatus(); } - Status ExpectEq(vm::ref lhs_ref, + Status ExpectEq(vm::ref device, + vm::ref lhs_ref, vm::ref rhs_ref) { + IREE_RETURN_IF_ERROR(TransferToHost(device.get(), lhs_ref, rhs_ref)); auto* lhs = lhs_ref.get(); auto* rhs = rhs_ref.get(); @@ -272,8 +370,10 @@ class CheckModuleState final { return OkStatus(); } - Status ExpectAlmostEq(vm::ref lhs_ref, + Status ExpectAlmostEq(vm::ref device, + vm::ref lhs_ref, vm::ref rhs_ref) { + IREE_RETURN_IF_ERROR(TransferToHost(device.get(), lhs_ref, rhs_ref)); auto* lhs = lhs_ref.get(); auto* rhs = rhs_ref.get(); diff --git a/runtime/src/iree/modules/check/test/success.mlir b/runtime/src/iree/modules/check/test/success.mlir index ff5aa8e8d599..40d8bc33e733 100644 --- a/runtime/src/iree/modules/check/test/success.mlir +++ b/runtime/src/iree/modules/check/test/success.mlir @@ -14,9 +14,10 @@ func.func @expect_false() { } func.func @expect_all_true() { + %device = hal.ex.shared_device : !hal.device %all_true = util.unfoldable_constant dense<1> : tensor<2x2xi32> %all_true_view = hal.tensor.export %all_true : tensor<2x2xi32> -> !hal.buffer_view - check.expect_all_true(%all_true_view) : !hal.buffer_view + check.expect_all_true<%device>(%all_true_view) : !hal.buffer_view return } diff --git a/runtime/src/iree/modules/hal/types.c b/runtime/src/iree/modules/hal/types.c index 0c7e0d7900f9..52ce5a281a52 100644 --- a/runtime/src/iree/modules/hal/types.c +++ b/runtime/src/iree/modules/hal/types.c @@ -205,7 +205,7 @@ IREE_API_EXPORT iree_hal_buffer_t* iree_vm_list_get_buffer_retain( IREE_API_EXPORT iree_status_t iree_vm_list_set_buffer_retain( iree_vm_list_t* list, iree_host_size_t i, iree_hal_buffer_t* value) { - iree_vm_ref_t value_ref; + iree_vm_ref_t value_ref = iree_vm_ref_null(); IREE_RETURN_IF_ERROR( iree_vm_ref_wrap_assign(value, iree_hal_buffer_type(), &value_ref)); return iree_vm_list_set_ref_retain(list, i, &value_ref); @@ -226,7 +226,7 @@ IREE_API_EXPORT iree_hal_buffer_view_t* iree_vm_list_get_buffer_view_retain( IREE_API_EXPORT iree_status_t iree_vm_list_set_buffer_view_retain( iree_vm_list_t* list, iree_host_size_t i, iree_hal_buffer_view_t* value) { - iree_vm_ref_t value_ref; + iree_vm_ref_t value_ref = iree_vm_ref_null(); IREE_RETURN_IF_ERROR( iree_vm_ref_wrap_assign(value, iree_hal_buffer_view_type(), &value_ref)); return iree_vm_list_set_ref_retain(list, i, &value_ref); @@ -247,7 +247,7 @@ IREE_API_EXPORT iree_hal_fence_t* iree_vm_list_get_fence_retain( IREE_API_EXPORT iree_status_t iree_vm_list_set_fence_retain( iree_vm_list_t* list, iree_host_size_t i, iree_hal_fence_t* value) { - iree_vm_ref_t value_ref; + iree_vm_ref_t value_ref = iree_vm_ref_null(); IREE_RETURN_IF_ERROR( iree_vm_ref_wrap_assign(value, iree_hal_fence_type(), &value_ref)); return iree_vm_list_set_ref_retain(list, i, &value_ref); diff --git a/runtime/src/iree/tooling/run_module.c b/runtime/src/iree/tooling/run_module.c index ad5e674fedb1..2af3db4f891b 100644 --- a/runtime/src/iree/tooling/run_module.c +++ b/runtime/src/iree/tooling/run_module.c @@ -246,6 +246,22 @@ static iree_status_t iree_tooling_run_function( "processing instrument data"); } + // Transfer outputs to the host so they can be processed. Only required when + // using full HAL device-based execution. + if (iree_status_is_ok(status) && device != NULL) { + iree_hal_buffer_params_t target_params = { + .usage = IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING, + .access = IREE_HAL_MEMORY_ACCESS_ALL, + .type = IREE_HAL_MEMORY_TYPE_HOST_LOCAL | + IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE, + .queue_affinity = IREE_HAL_QUEUE_AFFINITY_ANY, + .min_alignment = 0, + }; + status = iree_tooling_transfer_variant_list( + device, outputs, device_allocator, target_params, + /*wait_fence=*/NULL, /*signal_fence=*/NULL); + } + // Handle either printing/writing the outputs or checking them against // expected values (basic pass/fail testing). if (iree_status_is_ok(status)) { diff --git a/runtime/src/iree/tooling/vm_util.c b/runtime/src/iree/tooling/vm_util.c index b21eadaee78a..70e2e777bf97 100644 --- a/runtime/src/iree/tooling/vm_util.c +++ b/runtime/src/iree/tooling/vm_util.c @@ -324,6 +324,187 @@ iree_status_t iree_tooling_append_async_fence_inputs( return status; } +static bool iree_tooling_requires_buffer_transfer( + iree_hal_buffer_t* source_buffer, iree_hal_buffer_params_t target_params) { + return !iree_all_bits_set(iree_hal_buffer_memory_type(source_buffer), + target_params.type) || + !iree_all_bits_set(iree_hal_buffer_allowed_usage(source_buffer), + target_params.usage); +} + +static iree_status_t iree_tooling_setup_buffer_transfer( + iree_hal_command_buffer_t* command_buffer, iree_hal_buffer_t* source_buffer, + iree_hal_allocator_t* target_allocator, + iree_hal_buffer_params_t target_params, + iree_hal_buffer_t** out_target_buffer) { + IREE_ASSERT_ARGUMENT(command_buffer); + IREE_ASSERT_ARGUMENT(source_buffer); + IREE_ASSERT_ARGUMENT(target_allocator); + IREE_ASSERT_ARGUMENT(out_target_buffer); + *out_target_buffer = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_buffer_t* target_buffer = NULL; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_allocator_allocate_buffer( + target_allocator, target_params, + iree_hal_buffer_allocation_size(source_buffer), &target_buffer)); + + iree_status_t status = iree_hal_command_buffer_copy_buffer( + command_buffer, source_buffer, 0, target_buffer, 0, + iree_hal_buffer_byte_length(source_buffer)); + + if (iree_status_is_ok(status)) { + *out_target_buffer = target_buffer; + } else { + iree_hal_buffer_release(target_buffer); + } + IREE_TRACE_ZONE_END(z0); + return status; +} + +static iree_status_t iree_tooling_submit_transfer( + iree_hal_device_t* device, iree_hal_fence_t* wait_fence, + iree_hal_queue_affinity_t queue_affinity, + iree_hal_command_buffer_t* command_buffer, iree_hal_fence_t* signal_fence) { + IREE_TRACE_ZONE_BEGIN(z0); + + iree_status_t status = iree_ok_status(); + + bool needs_wait = signal_fence == NULL; + if (needs_wait) { + iree_hal_semaphore_t* semaphore = NULL; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_semaphore_create(device, 0ull, &semaphore)); + status = iree_hal_fence_create_at( + semaphore, 1ull, iree_hal_device_host_allocator(device), &signal_fence); + iree_hal_semaphore_release(semaphore); + } else { + iree_hal_fence_retain(signal_fence); + } + + if (iree_status_is_ok(status)) { + status = iree_hal_device_queue_execute( + device, queue_affinity, iree_hal_fence_semaphore_list(wait_fence), + iree_hal_fence_semaphore_list(signal_fence), 1, &command_buffer); + } + + if (iree_status_is_ok(status) && needs_wait) { + status = iree_hal_fence_wait(signal_fence, iree_infinite_timeout()); + } + + iree_hal_fence_release(signal_fence); + IREE_TRACE_ZONE_END(z0); + return status; +} + +iree_status_t iree_tooling_transfer_variant_list( + iree_hal_device_t* device, iree_vm_list_t* list, + iree_hal_allocator_t* target_allocator, + iree_hal_buffer_params_t target_params, iree_hal_fence_t* wait_fence, + iree_hal_fence_t* signal_fence) { + IREE_ASSERT_ARGUMENT(device); + IREE_ASSERT_ARGUMENT(list); + IREE_ASSERT_ARGUMENT(target_allocator); + IREE_TRACE_ZONE_BEGIN(z0); + + // If all buffers are already host-accessible we can skip the transfer. + bool requires_transfer = false; + for (iree_host_size_t i = 0; i < iree_vm_list_size(list); ++i) { + iree_vm_ref_t value = iree_vm_ref_null(); + IREE_IGNORE_ERROR(iree_vm_list_get_ref_assign(list, i, &value)); + if (iree_hal_buffer_isa(value)) { + iree_hal_buffer_t* source_buffer = iree_hal_buffer_deref(value); + if (iree_tooling_requires_buffer_transfer(source_buffer, target_params)) { + requires_transfer = true; + break; + } + } else if (iree_hal_buffer_view_isa(value)) { + iree_hal_buffer_view_t* source_view = iree_hal_buffer_view_deref(value); + iree_hal_buffer_t* source_buffer = + iree_hal_buffer_view_buffer(source_view); + if (iree_tooling_requires_buffer_transfer(source_buffer, target_params)) { + requires_transfer = true; + break; + } + } + } + if (!requires_transfer) { + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); + } + + iree_hal_command_buffer_t* command_buffer = NULL; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_command_buffer_create( + device, + IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT | + IREE_HAL_COMMAND_BUFFER_MODE_ALLOW_INLINE_EXECUTION, + IREE_HAL_COMMAND_CATEGORY_TRANSFER, target_params.queue_affinity, + /*binding_capacity=*/0, &command_buffer)); + + iree_status_t status = iree_hal_command_buffer_begin(command_buffer); + if (iree_status_is_ok(status)) { + for (iree_host_size_t i = 0; i < iree_vm_list_size(list); ++i) { + iree_vm_ref_t value = iree_vm_ref_null(); + IREE_IGNORE_ERROR(iree_vm_list_get_ref_assign(list, i, &value)); + if (iree_hal_buffer_isa(value)) { + iree_hal_buffer_t* source_buffer = iree_hal_buffer_deref(value); + if (!iree_tooling_requires_buffer_transfer(source_buffer, + target_params)) { + // Already ok. + continue; + } + iree_hal_buffer_t* target_buffer = NULL; + status = iree_tooling_setup_buffer_transfer( + command_buffer, source_buffer, target_allocator, target_params, + &target_buffer); + if (!iree_status_is_ok(status)) break; + status = iree_vm_list_set_buffer_retain(list, i, target_buffer); + iree_hal_buffer_release(target_buffer); + if (!iree_status_is_ok(status)) break; + } else if (iree_hal_buffer_view_isa(value)) { + iree_hal_buffer_view_t* source_view = iree_hal_buffer_view_deref(value); + iree_hal_buffer_t* source_buffer = + iree_hal_buffer_view_buffer(source_view); + if (!iree_tooling_requires_buffer_transfer(source_buffer, + target_params)) { + // Already ok. + continue; + } + iree_hal_buffer_t* target_buffer = NULL; + status = iree_tooling_setup_buffer_transfer( + command_buffer, source_buffer, target_allocator, target_params, + &target_buffer); + if (!iree_status_is_ok(status)) break; + iree_hal_buffer_view_t* target_view = NULL; + status = iree_hal_buffer_view_create_like( + target_buffer, source_view, + iree_hal_allocator_host_allocator(target_allocator), &target_view); + iree_hal_buffer_release(target_buffer); + if (!iree_status_is_ok(status)) break; + status = iree_vm_list_set_buffer_view_retain(list, i, target_view); + iree_hal_buffer_view_release(target_view); + if (!iree_status_is_ok(status)) break; + } + } + } + if (iree_status_is_ok(status)) { + status = iree_hal_command_buffer_end(command_buffer); + } + + if (iree_status_is_ok(status)) { + status = iree_tooling_submit_transfer(device, wait_fence, + target_params.queue_affinity, + command_buffer, signal_fence); + } + + iree_hal_command_buffer_release(command_buffer); + + IREE_TRACE_ZONE_END(z0); + return status; +} + #define IREE_PRINTVARIANT_CASE_I(SIZE, B, V) \ case IREE_VM_VALUE_TYPE_I##SIZE: \ return iree_string_builder_append_format( \ diff --git a/runtime/src/iree/tooling/vm_util.h b/runtime/src/iree/tooling/vm_util.h index e2a031192378..bc9ca008236b 100644 --- a/runtime/src/iree/tooling/vm_util.h +++ b/runtime/src/iree/tooling/vm_util.h @@ -54,6 +54,16 @@ iree_status_t iree_tooling_append_async_fence_inputs( iree_hal_device_t* device, iree_hal_fence_t* wait_fence, iree_hal_fence_t** out_signal_fence); +// Transfers all buffers in |list| to ones using |target_params|. +// If no |wait_fence| is provided then the transfer will begin immediately. +// If no |signal_fence| is provided then the call will block until the transfer +// completes. +iree_status_t iree_tooling_transfer_variant_list( + iree_hal_device_t* device, iree_vm_list_t* list, + iree_hal_allocator_t* target_allocator, + iree_hal_buffer_params_t target_params, iree_hal_fence_t* wait_fence, + iree_hal_fence_t* signal_fence); + // Appends a variant list of VM scalars and buffers to |builder|. // |list_name| will be printed alongside each element ordinal. // diff --git a/tools/BUILD.bazel b/tools/BUILD.bazel index 3b5925091e1d..2ea9aa24424c 100644 --- a/tools/BUILD.bazel +++ b/tools/BUILD.bazel @@ -210,6 +210,7 @@ iree_runtime_cc_binary( "//runtime/src/iree/modules/hal", "//runtime/src/iree/tooling:device_util", "//runtime/src/iree/tooling:trace_replay", + "//runtime/src/iree/tooling:vm_util", "//runtime/src/iree/tooling:yaml_util", "//runtime/src/iree/vm", "@com_github_yaml_libyaml//:yaml", diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt index 3cf3a0ab1214..2445774a92fa 100644 --- a/tools/CMakeLists.txt +++ b/tools/CMakeLists.txt @@ -215,6 +215,7 @@ iree_cc_binary( iree::modules::hal iree::tooling::device_util iree::tooling::trace_replay + iree::tooling::vm_util iree::tooling::yaml_util iree::vm yaml diff --git a/tools/iree-e2e-matmul-test.c b/tools/iree-e2e-matmul-test.c index d1082a216bb0..5730608682d2 100644 --- a/tools/iree-e2e-matmul-test.c +++ b/tools/iree-e2e-matmul-test.c @@ -19,6 +19,7 @@ #include "iree/modules/hal/module.h" #include "iree/tooling/device_util.h" #include "iree/tooling/trace_replay.h" +#include "iree/tooling/vm_util.h" #include "iree/tooling/yaml_util.h" #include "iree/vm/api.h" @@ -192,10 +193,8 @@ static iree_status_t map_host_local_row_major_data( iree_hal_buffer_view_t* buffer_view, enum iree_hal_memory_access_bits_t access, iree_hal_buffer_mapping_t* mapping) { - // Really validate host-local, not just host-visible: callers may rely on - // host-coherency. IREE_RETURN_IF_ERROR( - validate_memory_type(buffer_view, IREE_HAL_MEMORY_TYPE_HOST_LOCAL)); + validate_memory_type(buffer_view, IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)); if (iree_hal_buffer_view_encoding_type(buffer_view) != IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR) { return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, @@ -1014,42 +1013,46 @@ static iree_status_t do_matmul_and_check_results( replay->device, device_allocator, device_inputs, &host_inputs)); // Invoke the function to produce the actual result. - iree_vm_list_t* device_outputs = NULL; + iree_vm_list_t* outputs = NULL; IREE_CHECK_OK(iree_vm_list_create(iree_vm_make_undefined_type_def(), /*initial_capacity=*/8, - replay->host_allocator, &device_outputs)); + replay->host_allocator, &outputs)); IREE_CHECK_OK(iree_vm_invoke( replay->context, function, IREE_VM_INVOCATION_FLAG_NONE, - /*policy=*/NULL, device_inputs, device_outputs, replay->host_allocator)); + /*policy=*/NULL, device_inputs, outputs, replay->host_allocator)); iree_vm_list_release(device_inputs); - // Get the device_actual_result from the device_outputs. - iree_hal_buffer_view_t* device_actual_result; - IREE_CHECK_OK( - get_item_as_buffer_view(device_outputs, 0, &device_actual_result)); + // Transfer device buffers to host buffers. + iree_hal_buffer_params_t host_params = { + .usage = IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING, + .access = IREE_HAL_MEMORY_ACCESS_ALL, + .type = + IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE, + .queue_affinity = IREE_HAL_QUEUE_AFFINITY_ANY, + .min_alignment = 0, + }; + IREE_CHECK_OK(iree_tooling_transfer_variant_list( + replay->device, outputs, device_allocator, host_params, + /*wait_fence=*/NULL, /*signal_fence=*/NULL)); - // Copy the results to a host local buffer to be able to map it. - iree_hal_buffer_view_t* host_actual_result = NULL; - IREE_CHECK_OK(copy_device_buffer_view_to_host( - replay->device, device_allocator, device_actual_result, - &host_actual_result)); + // Get the actual result computed by the program. + iree_hal_buffer_view_t* actual_result; + IREE_CHECK_OK(get_item_as_buffer_view(outputs, 0, &actual_result)); - // Allocate host_expected_result with same shape as host_actual_result. + // Allocate host_expected_result with same shape as actual_result. iree_hal_buffer_view_t* host_expected_result = NULL; - IREE_CHECK_OK(allocate_host_buffer_view_like(replay->device, device_allocator, - host_actual_result, - &host_expected_result)); + IREE_CHECK_OK(allocate_host_buffer_view_like( + replay->device, device_allocator, actual_result, &host_expected_result)); // Use the reference matmul implementation to fill host_expected_result IREE_CHECK_OK(reference_matmul(host_inputs, host_expected_result)); - // Check that host_actual_result and host_expected_result agree. - iree_status_t status = check_matmul_results( - file, host_inputs, host_actual_result, host_expected_result); + // Check that actual_result and host_expected_result agree. + iree_status_t status = check_matmul_results(file, host_inputs, actual_result, + host_expected_result); - iree_vm_list_release(device_outputs); // releases device_actual_result + iree_vm_list_release(outputs); // releases actual_result iree_vm_list_release(host_inputs); - iree_hal_buffer_view_release(host_actual_result); iree_hal_buffer_view_release(host_expected_result); return status; } diff --git a/tools/iree-run-trace-main.c b/tools/iree-run-trace-main.c index fa46810c28df..b1b39dc6caaa 100644 --- a/tools/iree-run-trace-main.c +++ b/tools/iree-run-trace-main.c @@ -197,6 +197,21 @@ static iree_status_t iree_run_trace_file(iree_string_view_t root_path, yaml_parser_delete(&parser); + // Transfer outputs to the host so they can be processed. + if (iree_status_is_ok(status) && replay.device != NULL) { + iree_hal_buffer_params_t target_params = { + .usage = IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING, + .access = IREE_HAL_MEMORY_ACCESS_ALL, + .type = IREE_HAL_MEMORY_TYPE_HOST_LOCAL | + IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE, + .queue_affinity = IREE_HAL_QUEUE_AFFINITY_ANY, + .min_alignment = 0, + }; + status = iree_tooling_transfer_variant_list( + replay.device, replay.outputs, iree_hal_device_allocator(replay.device), + target_params, /*wait_fence=*/NULL, /*signal_fence=*/NULL); + } + // Optionally process outputs from the replay session. if (iree_status_is_ok(status)) { if (FLAG_output_list().count == 0) {