Skip to content

Commit

Permalink
Switching external resources to be device-local only. (iree-org#14016)
Browse files Browse the repository at this point in the history
Previously all external resources (results returned by an invocation)
were made host-visible and mappable and this prevented the use of
queue-ordered allocations in CUDA as memory pools cannot service memory
with associated host pointers. Depending on device the host-visible
memory could also be much slower to access (or have more potential
pitfalls with page management) vs pinned device-local memory and this
got worse once we started doing more dispatches in-place on the results.

Now all external buffers are by default allocated as device-local. Users
will need to manually stage the buffers and otherwise they'll remain
on-device. For externalized state this is a good thing as it means we'll
keep state on device automatically. A temporary flag has been added to
revert to the old mappable behavior with
`--iree-stream-external-resources-mappable=true`. Note that some devices
(like CPU) will always allow mapping even if not requested and users can
avoid the copies by checking before performing the transfers.

GPT2 CUDA post-change with alloca and no caching allocator enabled
(~5us/invocation allocation overhead):

![image](https://github.com/openxla/iree/assets/75337/5f7f589d-b602-49b3-96c6-5c9dfa6578fe)
  • Loading branch information
benvanik authored Oct 19, 2023
1 parent 87c968c commit 63381a8
Show file tree
Hide file tree
Showing 27 changed files with 567 additions and 154 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
#include "iree/compiler/Dialect/HAL/Conversion/TypeConverter.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"

Expand Down Expand Up @@ -43,39 +41,5 @@ HALConversionTarget::HALConversionTarget(MLIRContext *context,
});
}

// static
LogicalResult HALConversionTarget::applyDefaultBufferRewrite(
Operation *srcOp, ValueRange operands, StringRef dstOpName,
TypeConverter &typeConverter, ConversionPatternRewriter &rewriter) {
OperationState state{srcOp->getLoc(), dstOpName};
state.addAttributes(srcOp->getAttrs());

for (auto [srcOperand, dstOperand] :
llvm::zip_equal(srcOp->getOperands(), operands)) {
// Check that any type that should have been mapped to buffer view was.
// This is just to catch conflicts in type conversions that may sneak in
// during development.
assert(
(!HALTypeConverter::shouldConvertToBufferView(srcOperand.getType()) ||
dstOperand.getType().isa<IREE::HAL::BufferViewType>()) &&
"expect that tensors have been mapped to buffer views");
state.addOperands({dstOperand});
}
for (auto resultType : srcOp->getResultTypes()) {
if (HALTypeConverter::shouldConvertToBufferView(resultType)) {
state.addTypes(IREE::HAL::BufferViewType::get(rewriter.getContext()));
} else {
// Normal pass-through result.
if (failed(typeConverter.convertType(resultType, state.types))) {
return failure();
}
}
}

auto *dstOp = rewriter.create(state);
rewriter.replaceOp(srcOp, dstOp->getResults());
return success();
}

} // namespace iree_compiler
} // namespace mlir
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
#define IREE_COMPILER_DIALECT_HAL_CONVERSION_CONVERSIONTARGET_H_

#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
Expand All @@ -22,47 +21,6 @@ namespace iree_compiler {
class HALConversionTarget : public ConversionTarget {
public:
HALConversionTarget(MLIRContext *context, TypeConverter &typeConverter);

// Attempts to rewrite an op that may use tensor values into an op using HAL
// buffers. See HALOpConversion for more information.
static LogicalResult
applyDefaultBufferRewrite(Operation *srcOp, ValueRange operands,
StringRef dstOpName, TypeConverter &typeConverter,
ConversionPatternRewriter &rewriter);
};

// HAL tensor-to-buffer conversion utility.
// This can be used by dialects to model custom op conversion from a dialect
// that uses the MLIR tensor type to the IREE HAL buffer type. At this point
// during conversion the source values will be TensorType and the target values
// will be IREE::HAL::BufferTypes. Any static information available about the
// tensor (such as static dimensions, element type, layout, etc) are extracted
// here and lowered as expanded values.
//
// The ABI is currently very basic and will change with the introduction of more
// dynamic shape logic.
//
// Source:
// my.tensor_op(%arg0 : tensor<2x4xf32>)
// Target:
// %arg0_view = hal.buffer_view.create %arg0, ...
// my.buffer_op(%arg0_view : !hal.buffer_view)
template <typename SRC, typename DST>
class HALOpConversion : public OpConversionPattern<SRC> {
public:
HALOpConversion(MLIRContext *context, TypeConverter &typeConverter)
: OpConversionPattern<SRC>(context), typeConverter(typeConverter) {}

LogicalResult
matchAndRewrite(SRC srcOp, typename SRC::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
return HALConversionTarget::applyDefaultBufferRewrite(
srcOp, adaptor.getOperands(), DST::getOperationName(), typeConverter,
rewriter);
}

protected:
TypeConverter &typeConverter;
};

} // namespace iree_compiler
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
Expand All @@ -23,6 +24,14 @@
namespace mlir {
namespace iree_compiler {

static llvm::cl::opt<bool> clExternalResourcesMappable(
"iree-stream-external-resources-mappable",
llvm::cl::desc("Allocates external resources as host-visible and mappable. "
"This can degrade performance and introduce allocation "
"overhead and staging buffers for readback on the host "
"should be managed by the calling application instead."),
llvm::cl::init(false));

namespace {

static Value lookupDeviceFor(Operation *op, OpBuilder &builder) {
Expand Down Expand Up @@ -263,17 +272,21 @@ deriveAllowedResourceBufferBits(Location loc,
default:
break;
case IREE::Stream::Lifetime::External:
// #yolo; these come from/go to outside the program.
// Today we assume they are device-local|host-visible just for
// practical purposes but that does not have to be true. We really
// want this to be something we analyze and handle on the edges
// (transferring devices/etc if needed).
memoryTypes = memoryTypes | IREE::HAL::MemoryTypeBitfield::DeviceLocal |
IREE::HAL::MemoryTypeBitfield::HostVisible;
// NOTE: we may not map it but users may after they get them back.
// Another reason we should annotate this - having a buffer be
// mappable is potentially expensive (may get a 2nd copy in memory!).
bufferUsage = bufferUsage | IREE::HAL::BufferUsageBitfield::Mapping;
if (clExternalResourcesMappable) {
// #yolo; these come from/go to outside the program.
// Today we assume they are device-local|host-visible just for
// practical purposes but that does not have to be true. We really
// want this to be something we analyze and handle on the edges
// (transferring devices/etc if needed).
memoryTypes = memoryTypes | IREE::HAL::MemoryTypeBitfield::DeviceLocal |
IREE::HAL::MemoryTypeBitfield::HostVisible;
// NOTE: we may not map it but users may after they get them back.
// Another reason we should annotate this - having a buffer be
// mappable is potentially expensive (may get a 2nd copy in memory!).
bufferUsage = bufferUsage | IREE::HAL::BufferUsageBitfield::Mapping;
} else {
memoryTypes = memoryTypes | IREE::HAL::MemoryTypeBitfield::DeviceLocal;
}
break;
}
return success();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ module attributes {hal.device.targets = [#device_target_cpu]} {
%arg1_resource = stream.tensor.import %arg1 : !hal.buffer_view -> tensor<4xf32> in !stream.resource<external>{%c16}

// CHECK: %[[RESULT_BUFFER:.+]] = hal.allocator.allocate<%[[ALLOCATOR]] : !hal.allocator>
// CHECK-SAME: type("HostVisible|DeviceVisible|DeviceLocal")
// CHECK-SAME: usage("{{.+}}Transfer{{.+}}Dispatch{{.+}}Mapping{{.+}}")
// CHECK-SAME: type("DeviceVisible|DeviceLocal")
// CHECK-SAME: usage("{{.+}}Transfer{{.+}}Dispatch{{.+}}")
// CHECK-SAME: : !hal.buffer{%c16}
%result_resource = stream.resource.alloc uninitialized : !stream.resource<external>{%c16}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,27 @@ class ValueResourceUsage : public AbstractResourceUsage<DFX::ValueElement> {
getState() ^= targetUsage.getState();
})
.Case([&](IREE::Stream::TensorImportOp op) {
removeAssumedBits(NOT_MUTATED | NOT_EXTERNAL);
auto targetType =
llvm::cast<IREE::Stream::ResourceType>(op.getResult().getType());
switch (targetType.getLifetime()) {
default:
case IREE::Stream::Lifetime::External:
removeAssumedBits(NOT_MUTATED | NOT_EXTERNAL);
break;
case IREE::Stream::Lifetime::Staging:
removeAssumedBits(NOT_MUTATED | NOT_STAGING_READ |
NOT_STAGING_WRITE);
break;
case IREE::Stream::Lifetime::Transient:
removeAssumedBits(NOT_MUTATED);
break;
case IREE::Stream::Lifetime::Variable:
removeAssumedBits(NOT_MUTATED | NOT_GLOBAL_READ | NOT_GLOBAL_WRITE);
break;
case IREE::Stream::Lifetime::Constant:
removeAssumedBits(NOT_CONSTANT);
break;
}
auto &resultUsage = solver.getElementFor<ValueResourceUsage>(
*this, Position::forValue(op.getResult()),
DFX::Resolution::REQUIRED);
Expand Down Expand Up @@ -497,7 +517,6 @@ class ValueResourceUsage : public AbstractResourceUsage<DFX::ValueElement> {
*this, Position::forValue(op->getOperand(operandIdx)),
DFX::Resolution::REQUIRED);
getState() ^= operandUsage.getState();

auto &beforeUsage = solver.getElementFor<ValueResourceUsage>(
*this,
Position::forValue(op.getBeforeBody()->getArgument(operandIdx)),
Expand All @@ -510,13 +529,11 @@ class ValueResourceUsage : public AbstractResourceUsage<DFX::ValueElement> {
*this, Position::forValue(op->getOperand(operandIdx)),
DFX::Resolution::REQUIRED);
getState() ^= operandUsage.getState();

auto &parentUsage = solver.getElementFor<ValueResourceUsage>(
*this,
Position::forValue(op->getParentOp()->getResult(operandIdx - 1)),
DFX::Resolution::REQUIRED);
getState() ^= parentUsage.getState();

if (auto whileOp =
dyn_cast_or_null<scf::WhileOp>(op->getParentOp())) {
auto value = Position::forValue(
Expand All @@ -532,14 +549,12 @@ class ValueResourceUsage : public AbstractResourceUsage<DFX::ValueElement> {
*this, Position::forValue(op->getOperand(operandIdx)),
DFX::Resolution::REQUIRED);
getState() ^= operandUsage.getState();

auto &parentUsage = solver.getElementFor<ValueResourceUsage>(
*this,
Position::forValue(op->getParentOp()->getResult(operandIdx)),
DFX::Resolution::REQUIRED);
getState() ^= parentUsage.getState();
}

if (auto whileOp =
dyn_cast_or_null<scf::WhileOp>(op->getParentOp())) {
auto value =
Expand Down Expand Up @@ -589,7 +604,33 @@ class ValueResourceUsage : public AbstractResourceUsage<DFX::ValueElement> {
removeAssumedBits(NOT_INDIRECT | NOT_GLOBAL_WRITE);
})
.Case([&](IREE::Stream::TensorExportOp op) {
removeAssumedBits(NOT_MUTATED | NOT_EXTERNAL);
auto sourceType =
llvm::cast<IREE::Stream::ResourceType>(op.getSource().getType());
switch (sourceType.getLifetime()) {
default:
case IREE::Stream::Lifetime::External:
removeAssumedBits(NOT_MUTATED | NOT_EXTERNAL);
break;
case IREE::Stream::Lifetime::Staging:
removeAssumedBits(NOT_MUTATED | NOT_STAGING_READ |
NOT_STAGING_WRITE | NOT_TRANSFER_READ |
NOT_TRANSFER_WRITE);
break;
case IREE::Stream::Lifetime::Transient:
removeAssumedBits(NOT_MUTATED | NOT_TRANSFER_READ |
NOT_TRANSFER_WRITE | NOT_DISPATCH_READ |
NOT_DISPATCH_WRITE);
break;
case IREE::Stream::Lifetime::Variable:
removeAssumedBits(NOT_MUTATED | NOT_TRANSFER_READ |
NOT_TRANSFER_WRITE | NOT_DISPATCH_READ |
NOT_DISPATCH_WRITE);
break;
case IREE::Stream::Lifetime::Constant:
removeAssumedBits(NOT_CONSTANT | NOT_TRANSFER_READ |
NOT_DISPATCH_READ);
break;
}
})
.Case([&](IREE::Stream::TensorTraceOp op) {
removeAssumedBits(NOT_STAGING_READ);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ iree_compiler_cc_library(
],
deps = [
"//compiler/src/iree/compiler/Dialect/HAL/Conversion",
"//compiler/src/iree/compiler/Dialect/HAL/IR",
"//compiler/src/iree/compiler/Dialect/VM/Conversion",
"//compiler/src/iree/compiler/Modules/Check/IR",
"@llvm-project//mlir:Pass",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ iree_cc_library(
MLIRPass
MLIRTransforms
iree::compiler::Dialect::HAL::Conversion
iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::VM::Conversion
iree::compiler::Modules::Check::IR
PUBLIC
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#include "iree/compiler/Modules/Check/Conversion/ConversionPatterns.h"

#include "iree/compiler/Dialect/HAL/Conversion/ConversionTarget.h"
#include "iree/compiler/Dialect/HAL/Conversion/TypeConverter.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/VM/Conversion/ImportUtils.h"
#include "iree/compiler/Modules/Check/IR/CheckOps.h"
#include "mlir/Pass/Pass.h"
Expand Down Expand Up @@ -60,17 +62,90 @@ void populateCheckToVMPatterns(MLIRContext *context, SymbolTable &importSymbols,
context, importSymbols, typeConverter, "check.expect_almost_eq");
}

// Attempts to rewrite an op that may use tensor values into an op using HAL
// buffers.
static LogicalResult applyDefaultCheckBufferRewrite(
Operation *srcOp, ValueRange operands, StringRef dstOpName,
TypeConverter &typeConverter, ConversionPatternRewriter &rewriter) {
OperationState state{srcOp->getLoc(), dstOpName};
state.addAttributes(srcOp->getAttrs());

// Add device argument.
Value device = rewriter.create<IREE::HAL::ExSharedDeviceOp>(srcOp->getLoc());
state.addOperands({device});

for (auto [srcOperand, dstOperand] :
llvm::zip_equal(srcOp->getOperands(), operands)) {
// Check that any type that should have been mapped to buffer view was.
// This is just to catch conflicts in type conversions that may sneak in
// during development.
assert(
(!HALTypeConverter::shouldConvertToBufferView(srcOperand.getType()) ||
dstOperand.getType().isa<IREE::HAL::BufferViewType>()) &&
"expect that tensors have been mapped to buffer views");
state.addOperands({dstOperand});
}
for (auto resultType : srcOp->getResultTypes()) {
if (HALTypeConverter::shouldConvertToBufferView(resultType)) {
state.addTypes(IREE::HAL::BufferViewType::get(rewriter.getContext()));
} else {
// Normal pass-through result.
if (failed(typeConverter.convertType(resultType, state.types))) {
return failure();
}
}
}

auto *dstOp = rewriter.create(state);
rewriter.replaceOp(srcOp, dstOp->getResults());
return success();
}

// HAL tensor-to-buffer conversion utility.
// This can be used by dialects to model custom op conversion from a dialect
// that uses the MLIR tensor type to the IREE HAL buffer type. At this point
// during conversion the source values will be TensorType and the target values
// will be IREE::HAL::BufferTypes. Any static information available about the
// tensor (such as static dimensions, element type, layout, etc) are extracted
// here and lowered as expanded values.
//
// The ABI is currently very basic and will change with the introduction of more
// dynamic shape logic.
//
// Source:
// my.tensor_op(%arg0 : tensor<2x4xf32>)
// Target:
// %arg0_view = hal.buffer_view.create %arg0, ...
// my.buffer_op(%arg0_view : !hal.buffer_view)
template <typename SRC, typename DST>
class HALCheckOpConversion : public OpConversionPattern<SRC> {
public:
HALCheckOpConversion(MLIRContext *context, TypeConverter &typeConverter)
: OpConversionPattern<SRC>(context), typeConverter(typeConverter) {}

LogicalResult
matchAndRewrite(SRC srcOp, typename SRC::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
return applyDefaultCheckBufferRewrite(srcOp, adaptor.getOperands(),
DST::getOperationName(),
typeConverter, rewriter);
}

protected:
TypeConverter &typeConverter;
};

void populateCheckToHALPatterns(MLIRContext *context,
RewritePatternSet &patterns,
TypeConverter &typeConverter) {
// The same op handles both tensors and buffer views.
patterns
.insert<HALOpConversion<IREE::Check::ExpectAllTrueOp,
IREE::Check::ExpectAllTrueOp>,
HALOpConversion<IREE::Check::ExpectEqOp, IREE::Check::ExpectEqOp>,
HALOpConversion<IREE::Check::ExpectAlmostEqOp,
IREE::Check::ExpectAlmostEqOp>>(context,
typeConverter);
patterns.insert<
HALCheckOpConversion<IREE::Check::ExpectAllTrueOp,
IREE::Check::ExpectAllTrueOp>,
HALCheckOpConversion<IREE::Check::ExpectEqOp, IREE::Check::ExpectEqOp>,
HALCheckOpConversion<IREE::Check::ExpectAlmostEqOp,
IREE::Check::ExpectAlmostEqOp>>(context,
typeConverter);
}

} // namespace Check
Expand Down
Loading

0 comments on commit 63381a8

Please sign in to comment.