Adding stream.tensor.dispatch op. #19817

Open · wants to merge 1 commit into main
@@ -782,17 +782,11 @@ struct ConvertDispatchOp
       IREE::Flow::DispatchOp op, OneToNOpAdaptor adaptor,
       IREE::Stream::AffinityAttr executionAffinityAttr,
       ConversionPatternRewriter &rewriter) const override {
-    // Zero is going to be used for each operand to start.
-    auto zeroOffset = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
-
     // Query and resolve all operands and their sizes.
-    SmallVector<Value> dispatchOperands;
-    SmallVector<Value> dispatchOperandSizes;
-    SmallVector<Value> dispatchOperandOffsets;
-    SmallVector<Value> dispatchOperandEnds;
-    SmallVector<Value> dispatchOperandLengths;
+    SmallVector<Value> operands;
     SmallVector<Value> operandSizes;
-
+    SmallVector<Value> allOperandSizes;
+    SmallVector<Type> operandEncodings;
     for (auto [oldOperand, convertedOperands] :
          llvm::zip_equal(op.getArguments(), adaptor.getArguments())) {
       Value newOperand;
@@ -801,50 +795,58 @@ struct ConvertDispatchOp
             transferTensorOperands(op.getLoc(), oldOperand, convertedOperands,
                                    executionAffinityAttr, rewriter);
         newOperand = newOperandCast.resource;
-        dispatchOperandSizes.push_back(newOperandCast.resourceSize);
         operandSizes.push_back(newOperandCast.resourceSize);
-        dispatchOperandOffsets.push_back(zeroOffset);
-        dispatchOperandEnds.push_back(newOperandCast.resourceSize);
-        dispatchOperandLengths.push_back(newOperandCast.resourceSize);
+        allOperandSizes.push_back(newOperandCast.resourceSize);
+        operandEncodings.push_back(oldOperand.getType());
       } else {
-        operandSizes.push_back({});
+        allOperandSizes.push_back({});
+        operandEncodings.push_back(rewriter.getType<IREE::Util::UnusedType>());
         newOperand = convertedOperands.front();
       }
-      dispatchOperands.push_back(newOperand);
+      operands.push_back(newOperand);
     }

     // Construct result sizes or reuse tied operand sizes from above.
     SmallVector<Value> resultSizes;
     SmallVector<Type> resultTypes;
+    SmallVector<Type> resultEncodings;
     auto unknownType = rewriter.getType<IREE::Stream::ResourceType>();
     auto tiedOperandBase = op.getTiedOperandsIndexAndLength().first;
     for (auto result : llvm::enumerate(op.getResults())) {
       auto oldResultType = result.value().getType();
       if (!llvm::isa<ShapedType>(oldResultType)) {
         resultTypes.push_back(getTypeConverter()->convertType(oldResultType));
+        resultEncodings.push_back(rewriter.getType<IREE::Util::UnusedType>());
         continue;
       }
       auto tiedOperand = op.getTiedResultOperandIndex(result.index());
       if (tiedOperand.has_value()) {
         auto operandIndex = tiedOperand.value() - tiedOperandBase;
-        resultSizes.push_back(operandSizes[operandIndex]);
-        resultTypes.push_back(dispatchOperands[operandIndex].getType());
+        resultSizes.push_back(allOperandSizes[operandIndex]);
+        resultTypes.push_back(operands[operandIndex].getType());
+        resultEncodings.push_back(operandEncodings[operandIndex]);
       } else {
         auto resultDynamicDims = IREE::Util::buildDynamicDimsForValue(
             op.getLoc(), result.value(), rewriter);
         resultSizes.push_back(
             buildResultSizeOf(op.getLoc(), result.value(), resultDynamicDims,
                               executionAffinityAttr, rewriter));
         resultTypes.push_back(unknownType);
+        resultEncodings.push_back(oldResultType);
       }
     }

-    auto newOp = rewriter.create<IREE::Stream::AsyncDispatchOp>(
+    auto newOp = rewriter.create<IREE::Stream::TensorDispatchOp>(
         op.getLoc(), resultTypes, flattenValues(adaptor.getWorkload()),
-        adaptor.getEntryPointsAttr(), dispatchOperands, dispatchOperandSizes,
-        dispatchOperandOffsets, dispatchOperandEnds, dispatchOperandLengths,
-        resultSizes, adaptor.getTiedOperandsAttr(), executionAffinityAttr);
-    newOp->setDialectAttrs(op->getDialectAttrs());
+        adaptor.getEntryPointsAttr(), operands, operandSizes,
+        rewriter.getTypeArrayAttr(operandEncodings), op.getArgumentDims(),
+        resultSizes, rewriter.getTypeArrayAttr(resultEncodings),
+        op.getResultDims(), adaptor.getTiedOperandsAttr(),
+        executionAffinityAttr);
+    newOp->setDialectAttrs(
+        llvm::make_filter_range(op->getDialectAttrs(), [](NamedAttribute attr) {
+          return attr.getName() != "stream.affinity";
+        }));
     SmallVector<SmallVector<Value>> replacementsVec = llvm::map_to_vector(
         llvm::zip_equal(newOp->getResults(), resultSizes), [](auto it) {
           return SmallVector<Value>{std::get<0>(it), std::get<1>(it)};
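For orientation, the net effect of the pattern change as a hand-written IR sketch (the SSA names %input_res, %input_size, and %result_size are illustrative only; the printed forms are inferred from the tests below, not authoritative):

// A flow-level dispatch such as:
%0 = flow.dispatch @ex::@entry(%input) : (tensor<7x?x24x?xf32>{%dim1, %dim3}) -> tensor<?x?x1024xf32>{%dim1, %dim3}

// previously lowered to stream.async.dispatch, with a synthesized zero offset
// and a full-length range on every resource operand:
%r = stream.async.dispatch @ex::@entry(%input_res[%c0 to %input_size for %input_size]) : (!stream.resource<*>{%input_size}) -> !stream.resource<*>{%result_size}

// and now lowers to stream.tensor.dispatch, which carries the tensor encoding
// and dynamic dims of each operand and result alongside the resource sizes:
%r = stream.tensor.dispatch @ex::@entry(%input_res) : (tensor<7x?x24x?xf32>{%dim1, %dim3} in !stream.resource<*>{%input_size}) -> tensor<?x?x1024xf32>{%dim1, %dim3} in !stream.resource<*>{%result_size}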
@@ -3,28 +3,29 @@
 // CHECK-LABEL: @dispatchNoWorkload
 // CHECK-SAME: (%[[INPUT:.+]]: !stream.resource<*>, %[[INPUT_SIZE:.+]]: index, %[[DIM1:.+]]: index, %[[DIM3:.+]]: index)
 util.func public @dispatchNoWorkload(%input: tensor<7x?x24x?xf32>, %dim1: index, %dim3: index) -> tensor<?x?x1024xf32> {
-  // CHECK: %[[RESULT_SIZE:.+]] = stream.tensor.sizeof tensor<?x?x1024xf32>{%[[DIM1]], %[[DIM3]]}
-  // CHECK: %[[RESULT:.+]] = stream.async.dispatch @ex::@entry(%[[INPUT]][%c0 to %[[INPUT_SIZE]] for %[[INPUT_SIZE]]]) :
-  // CHECK-SAME: (!stream.resource<*>{%[[INPUT_SIZE]]}) -> !stream.resource<*>{%[[RESULT_SIZE]]}
+  // CHECK: %[[RESULT_SIZE:.+]] = stream.tensor.sizeof tensor<?x?x1024xf32>{%[[DIM1]], %[[DIM3]]}
+  // CHECK: %[[RESULT:.+]] = stream.tensor.dispatch @ex::@entry(%[[INPUT]]) :
+  // CHECK-SAME: (tensor<7x?x24x?xf32>{%[[DIM1]], %[[DIM3]]} in !stream.resource<*>{%[[INPUT_SIZE]]}) -> tensor<?x?x1024xf32>{%[[DIM1]], %[[DIM3]]} in !stream.resource<*>{%[[RESULT_SIZE]]}
   %0 = flow.dispatch @ex::@entry(%input) : (tensor<7x?x24x?xf32>{%dim1, %dim3}) -> tensor<?x?x1024xf32>{%dim1, %dim3}
-  // return %[[RESULT]], %[[RESULT_SIZE]] : !stream.resource<*>, index
+  // CHECK: util.return %[[RESULT]], %[[RESULT_SIZE]] : !stream.resource<*>, index
   util.return %0 : tensor<?x?x1024xf32>
 }

 // -----

 // CHECK-LABEL: @dispatch
 // CHECK-SAME: (%[[INPUT:.+]]: !stream.resource<*>, %[[INPUT_SIZE:.+]]: index, %[[DIM1:.+]]: index, %[[DIM3:.+]]: index)
-util.func public @dispatch(%input: tensor<7x?x24x?xf32>, %dim1: index, %dim3: index) -> tensor<?x?x1024xf32> {
+util.func public @dispatch(%input: tensor<7x?x24x?xf32>, %dim1: index, %dim3: index) -> (tensor<?x?x1024xf32>, tensor<1024x?x?xf32>) {
   %c1 = arith.constant 1 : index
   %c2 = arith.constant 2 : index
   %c3 = arith.constant 3 : index
-  // CHECK: %[[RESULT_SIZE:.+]] = stream.tensor.sizeof tensor<?x?x1024xf32>{%[[DIM1]], %[[DIM3]]}
-  // CHECK: %[[RESULT:.+]] = stream.async.dispatch @ex::@entry[%c1, %c2, %c3](%[[INPUT]][%c0 to %[[INPUT_SIZE]] for %[[INPUT_SIZE]]]) :
-  // CHECK-SAME: (!stream.resource<*>{%[[INPUT_SIZE]]}) -> !stream.resource<*>{%[[RESULT_SIZE]]}
-  %0 = flow.dispatch @ex::@entry[%c1, %c2, %c3](%input) : (tensor<7x?x24x?xf32>{%dim1, %dim3}) -> tensor<?x?x1024xf32>{%dim1, %dim3}
-  // return %[[RESULT]], %[[RESULT_SIZE]] : !stream.resource<*>, index
-  util.return %0 : tensor<?x?x1024xf32>
+  // CHECK: %[[RESULT0_SIZE:.+]] = stream.tensor.sizeof tensor<?x?x1024xf32>{%[[DIM1]], %[[DIM3]]}
+  // CHECK: %[[RESULT1_SIZE:.+]] = stream.tensor.sizeof tensor<1024x?x?xf32>{%[[DIM3]], %[[DIM1]]}
+  // CHECK: %[[RESULTS:.+]]:2 = stream.tensor.dispatch @ex::@entry[%c1, %c2, %c3](%[[INPUT]]) :
+  // CHECK-SAME: (tensor<7x?x24x?xf32>{%[[DIM1]], %[[DIM3]]} in !stream.resource<*>{%[[INPUT_SIZE]]}) -> (tensor<?x?x1024xf32>{%[[DIM1]], %[[DIM3]]} in !stream.resource<*>{%[[RESULT0_SIZE]]}, tensor<1024x?x?xf32>{%[[DIM3]], %[[DIM1]]} in !stream.resource<*>{%[[RESULT1_SIZE]]})
+  %results:2 = flow.dispatch @ex::@entry[%c1, %c2, %c3](%input) : (tensor<7x?x24x?xf32>{%dim1, %dim3}) -> (tensor<?x?x1024xf32>{%dim1, %dim3}, tensor<1024x?x?xf32>{%dim3, %dim1})
+  // CHECK: util.return %[[RESULTS]]#0, %[[RESULT0_SIZE]], %[[RESULTS]]#1, %[[RESULT1_SIZE]] : !stream.resource<*>, index, !stream.resource<*>, index
+  util.return %results#0, %results#1 : tensor<?x?x1024xf32>, tensor<1024x?x?xf32>
 }

 // -----
@@ -36,9 +37,11 @@ util.func public @tiedDispatch(%input0: tensor<i32>, %input1: tensor<2x3xi32>) -
   %c2 = arith.constant 2 : index
   %c3 = arith.constant 3 : index
   // CHECK: %[[T_SIZE:.+]] = stream.tensor.sizeof tensor<3x9xi32> : index
-  // CHECK: %[[T:.+]] = stream.async.dispatch @ex::@entry0[%c1, %c2, %c3](%[[INPUT0]][%c0 to %[[INPUT0_SIZE]] for %[[INPUT0_SIZE]]]) : (!stream.resource<*>{%[[INPUT0_SIZE]]}) -> !stream.resource<*>{%[[T_SIZE]]}
+  // CHECK: %[[T:.+]] = stream.tensor.dispatch @ex::@entry0[%c1, %c2, %c3](%[[INPUT0]]) :
+  // CHECK-SAME: (tensor<i32> in !stream.resource<*>{%[[INPUT0_SIZE]]}) -> tensor<3x9xi32> in !stream.resource<*>{%[[T_SIZE]]}
   %0 = flow.dispatch @ex::@entry0[%c1, %c2, %c3](%input0) : (tensor<i32>) -> tensor<3x9xi32>
-  // CHECK: %[[RESULT:.+]] = stream.async.dispatch @ex::@entry1[%c1, %c2, %c3](%[[INPUT1]][%c0 to %[[INPUT1_SIZE]] for %[[INPUT1_SIZE]]], %[[T]][%c0 to %[[T_SIZE]] for %[[T_SIZE]]]) : (!stream.resource<*>{%[[INPUT1_SIZE]]}, !stream.resource<*>{%[[T_SIZE]]}) -> %[[T]]{%[[T_SIZE]]}
+  // CHECK: %[[RESULT:.+]] = stream.tensor.dispatch @ex::@entry1[%c1, %c2, %c3](%[[INPUT1]], %[[T]]) :
+  // CHECK-SAME: (tensor<2x3xi32> in !stream.resource<*>{%[[INPUT1_SIZE]]}, tensor<3x9xi32> in !stream.resource<*>{%[[T_SIZE]]}) -> tensor<3x9xi32> in %[[T]]{%[[T_SIZE]]}
   %1 = flow.dispatch @ex::@entry1[%c1, %c2, %c3](%input1, %0) : (tensor<2x3xi32>, tensor<3x9xi32>) -> %0
   // CHECK: util.return %[[RESULT]], %[[T_SIZE]] : !stream.resource<*>, index
   util.return %1 : tensor<3x9xi32>
@@ -52,18 +55,20 @@ util.global private @device_b : !hal.device
 // CHECK-LABEL: @dispatchAffinity
 // CHECK-SAME: (%[[INPUT:.+]]: !stream.resource<*>, %[[INPUT_SIZE:.+]]: index, %[[DIM1:.+]]: index, %[[DIM3:.+]]: index)
 util.func public @dispatchAffinity(%input: tensor<7x?x24x?xf32>, %dim1: index, %dim3: index) -> (tensor<?x?x1024xf32>, tensor<?x?x1024xf32>) {
-  // CHECK: %[[INPUT_A:.+]] = stream.async.transfer %[[INPUT]] : !stream.resource<*>{%[[INPUT_SIZE]]} -> to(#hal.device.affinity<@device_a>) !stream.resource<*>{%[[INPUT_SIZE]]}
-  // CHECK: %[[RESULT0_SIZE:.+]] = stream.tensor.sizeof on(#hal.device.affinity<@device_a>) tensor<?x?x1024xf32>{%[[DIM1]], %[[DIM3]]}
-  // CHECK: %[[RESULT0:.+]] = stream.async.dispatch on(#hal.device.affinity<@device_a>) @ex::@entry0(%[[INPUT_A]][%c0 to %[[INPUT_SIZE]] for %[[INPUT_SIZE]]])
+  // CHECK: %[[INPUT_A:.+]] = stream.async.transfer %[[INPUT]] : !stream.resource<*>{%[[INPUT_SIZE]]} -> to(#hal.device.affinity<@device_a>) !stream.resource<*>{%[[INPUT_SIZE]]}
+  // CHECK: %[[RESULT0_SIZE:.+]] = stream.tensor.sizeof on(#hal.device.affinity<@device_a>) tensor<?x?x1024xf32>{%[[DIM1]], %[[DIM3]]}
+  // CHECK: %[[RESULT0:.+]] = stream.tensor.dispatch on(#hal.device.affinity<@device_a>) @ex::@entry0(%[[INPUT_A]])
+  // CHECK-SAME: (tensor<7x?x24x?xf32>{%[[DIM1]], %[[DIM3]]} in !stream.resource<*>{%[[INPUT_SIZE]]}) -> tensor<?x?x1024xf32>{%[[DIM1]], %[[DIM3]]} in !stream.resource<*>{%[[RESULT0_SIZE]]}
   %0 = flow.dispatch @ex::@entry0(%input) {
     stream.affinity = #hal.device.affinity<@device_a>
   } : (tensor<7x?x24x?xf32>{%dim1, %dim3}) -> tensor<?x?x1024xf32>{%dim1, %dim3}
-  // CHECK: %[[INPUT_B:.+]] = stream.async.transfer %[[INPUT]] : !stream.resource<*>{%[[INPUT_SIZE]]} -> to(#hal.device.affinity<@device_b>) !stream.resource<*>{%[[INPUT_SIZE]]}
-  // CHECK: %[[RESULT1_SIZE:.+]] = stream.tensor.sizeof on(#hal.device.affinity<@device_b>) tensor<?x?x1024xf32>{%[[DIM3]], %[[DIM1]]}
-  // CHECK: %[[RESULT1:.+]] = stream.async.dispatch on(#hal.device.affinity<@device_b>) @ex::@entry1(%[[INPUT_B]][%c0 to %[[INPUT_SIZE]] for %[[INPUT_SIZE]]])
+  // CHECK: %[[INPUT_B:.+]] = stream.async.transfer %[[INPUT]] : !stream.resource<*>{%[[INPUT_SIZE]]} -> to(#hal.device.affinity<@device_b>) !stream.resource<*>{%[[INPUT_SIZE]]}
+  // CHECK: %[[RESULT1_SIZE:.+]] = stream.tensor.sizeof on(#hal.device.affinity<@device_b>) tensor<?x?x1024xf32>{%[[DIM3]], %[[DIM1]]}
+  // CHECK: %[[RESULT1:.+]] = stream.tensor.dispatch on(#hal.device.affinity<@device_b>) @ex::@entry1(%[[INPUT_B]])
+  // CHECK-SAME: (tensor<7x?x24x?xf32>{%[[DIM1]], %[[DIM3]]} in !stream.resource<*>{%[[INPUT_SIZE]]}) -> tensor<?x?x1024xf32>{%[[DIM3]], %[[DIM1]]} in !stream.resource<*>{%[[RESULT1_SIZE]]}
   %1 = flow.dispatch @ex::@entry1(%input) {
     stream.affinity = #hal.device.affinity<@device_b>
   } : (tensor<7x?x24x?xf32>{%dim1, %dim3}) -> tensor<?x?x1024xf32>{%dim3, %dim1}
-  // return %[[RESULT0]], %[[RESULT0_SIZE]], %[[RESULT1]], %[[RESULT1_SIZE]]
+  // CHECK: return %[[RESULT0]], %[[RESULT0_SIZE]], %[[RESULT1]], %[[RESULT1_SIZE]]
   util.return %0, %1 : tensor<?x?x1024xf32>, tensor<?x?x1024xf32>
 }
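One behavioral detail from the pattern, visible in the @dispatchAffinity test above: the stream.affinity dialect attribute that picked the device on the flow op is consumed by the conversion (note the llvm::make_filter_range in the C++ hunk) rather than copied onto the new op. A minimal sketch of the two forms, where %input_res, %input_size, and %result0_size are illustrative names only:

// Source op carries the placement as a dialect attribute:
%0 = flow.dispatch @ex::@entry0(%input) {
  stream.affinity = #hal.device.affinity<@device_a>
} : (tensor<7x?x24x?xf32>{%dim1, %dim3}) -> tensor<?x?x1024xf32>{%dim1, %dim3}

// The converted op encodes placement in its on(...) clause instead, so the
// dialect attribute would be redundant and is filtered out:
%1 = stream.tensor.dispatch on(#hal.device.affinity<@device_a>) @ex::@entry0(%input_res) : (tensor<7x?x24x?xf32>{%dim1, %dim3} in !stream.resource<*>{%input_size}) -> tensor<?x?x1024xf32>{%dim1, %dim3} in !stream.resource<*>{%result0_size}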
@@ -141,16 +141,12 @@ util.global private @device : !hal.device
 // CHECK-LABEL: @tensorBarrierDispatch
 // CHECK-SAME: (%[[INPUT:.+]]: !stream.resource<*>, %[[DIM0:.+]]: index, %[[DIM1:.+]]: index)
 util.func public @tensorBarrierDispatch(%input: tensor<?x128xi8>, %dim0: index) -> tensor<?x128xi8> {
-  %c0 = arith.constant 0 : index
-  %barrier = flow.tensor.barrier %input : tensor<?x128xi8>{%dim0} on #hal.device.affinity<@device>
-  %0 = flow.dispatch @ex::@entry[%c0](%barrier) : (tensor<?x128xi8>{%dim0}) -> tensor<?x128xi8>{%dim0}
-
-  // CHECK: %[[C0:.+]] = arith.constant 0 : index
   // CHECK: %[[BARRIER:.+]] = stream.async.barrier %[[INPUT]] : !stream.resource<*>{%[[DIM0]]} -> !stream.resource<*>
-  // CHECK: %[[C0_2:.+]] = arith.constant 0 : index
+  %barrier = flow.tensor.barrier %input : tensor<?x128xi8>{%dim0} on #hal.device.affinity<@device>
   // CHECK: %[[SIZE:.+]] = stream.tensor.sizeof on(#hal.device.affinity<@device>) tensor<?x128xi8>{%arg2} : index
-  // CHECK: %[[DISP:.+]] = stream.async.dispatch on(#hal.device.affinity<@device>) @ex::@entry[%[[C0]]](%[[BARRIER]][%[[C0_2]] to %[[DIM0]] for %[[DIM0]]])
-  // CHECK: util.return %[[DISP]], %[[SIZE]]
+  // CHECK: %[[RESULT:.+]] = stream.tensor.dispatch on(#hal.device.affinity<@device>) @ex::@entry(%[[BARRIER]])
+  %0 = flow.dispatch @ex::@entry(%barrier) : (tensor<?x128xi8>{%dim0}) -> tensor<?x128xi8>{%dim0}
+  // CHECK: util.return %[[RESULT]], %[[SIZE]]
   util.return %0 : tensor<?x128xi8>
 }
@@ -1363,6 +1363,36 @@ void TensorStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
   // TODO(benvanik): combine multiple stores to the same target if contiguous.
 }

+//===----------------------------------------------------------------------===//
+// stream.tensor.dispatch
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+struct DeduplicateTensorDispatchEntryRefs final
+    : public OpRewritePattern<TensorDispatchOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(TensorDispatchOp dispatchOp,
+                                PatternRewriter &rewriter) const override {
+    auto originalAttr = dispatchOp.getEntryPointsAttr();
+    auto newAttr = deduplicateArrayElements(originalAttr);
+    if (newAttr == originalAttr)
+      return failure();
+    rewriter.modifyOpInPlace(
+        dispatchOp, [&]() { dispatchOp.setEntryPointsAttr(newAttr); });
+    return success();
+  }
+};
+
+} // namespace
+
+void TensorDispatchOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                   MLIRContext *context) {
+  // TODO(benvanik): maybe tied type/lifetime updates?
+  results.insert<ElideUnusedOp<TensorDispatchOp>>(context);
+  results.insert<DeduplicateTensorDispatchEntryRefs>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // stream.async.alloca
 //===----------------------------------------------------------------------===//
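The DeduplicateTensorDispatchEntryRefs pattern above only rewrites the op's entry_points array attribute in place. A hypothetical before/after sketch, assuming stream.tensor.dispatch reuses the brace-delimited multi-entry-point syntax of stream.async.dispatch (the executable/entry names and %size are made up for illustration):

// Before canonicalization: @ex::@entry is referenced twice.
%0 = stream.tensor.dispatch {@ex::@entry, @ex::@entry}(%arg0) : (tensor<4xf32> in !stream.resource<*>{%size}) -> tensor<4xf32> in !stream.resource<*>{%size}

// After canonicalization: the duplicates collapse to a single reference.
%0 = stream.tensor.dispatch @ex::@entry(%arg0) : (tensor<4xf32> in !stream.resource<*>{%size}) -> tensor<4xf32> in !stream.resource<*>{%size}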