Adding stream.tensor.dispatch op.
This allows us to carry tensor encodings from higher levels of the stack
down to where we now need them during the specialize encodings pass.
Once the encodings are consumed, the new op can be lowered back into a
`stream.async.dispatch` with the encodings erased.
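
For illustration, a minimal before/after sketch (adapted from the tests in
this change; `@ex::@entry` and the shapes/sizes are placeholders):

```mlir
// New op: carries the tensor types (and thus any encodings) alongside the
// backing resources and their sizes.
%result = stream.tensor.dispatch @ex::@entry(%input) :
    (tensor<7x?x24x?xf32>{%dim1, %dim3} in !stream.resource<*>{%input_size})
    -> tensor<?x?x1024xf32>{%dim1, %dim3} in !stream.resource<*>{%result_size}

// Once the encodings have been consumed it lowers back into the
// encoding-free async form that only sees opaque resource ranges.
%result = stream.async.dispatch @ex::@entry(%input[%c0 to %input_size for %input_size]) :
    (!stream.resource<*>{%input_size}) -> !stream.resource<*>{%result_size}
```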

Fixes #19806.
benvanik committed Jan 25, 2025
1 parent bbe7f5c commit 36531c6
Showing 17 changed files with 1,126 additions and 497 deletions.
@@ -782,17 +782,11 @@ struct ConvertDispatchOp
IREE::Flow::DispatchOp op, OneToNOpAdaptor adaptor,
IREE::Stream::AffinityAttr executionAffinityAttr,
ConversionPatternRewriter &rewriter) const override {
// Zero is going to be used for each operand to start.
auto zeroOffset = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);

// Query and resolve all operands and their sizes.
SmallVector<Value> dispatchOperands;
SmallVector<Value> dispatchOperandSizes;
SmallVector<Value> dispatchOperandOffsets;
SmallVector<Value> dispatchOperandEnds;
SmallVector<Value> dispatchOperandLengths;
SmallVector<Value> operands;
SmallVector<Value> operandSizes;

SmallVector<Value> allOperandSizes;
SmallVector<Type> operandEncodings;
for (auto [oldOperand, convertedOperands] :
llvm::zip_equal(op.getArguments(), adaptor.getArguments())) {
Value newOperand;
@@ -801,50 +795,58 @@ struct ConvertDispatchOp
transferTensorOperands(op.getLoc(), oldOperand, convertedOperands,
executionAffinityAttr, rewriter);
newOperand = newOperandCast.resource;
dispatchOperandSizes.push_back(newOperandCast.resourceSize);
operandSizes.push_back(newOperandCast.resourceSize);
dispatchOperandOffsets.push_back(zeroOffset);
dispatchOperandEnds.push_back(newOperandCast.resourceSize);
dispatchOperandLengths.push_back(newOperandCast.resourceSize);
allOperandSizes.push_back(newOperandCast.resourceSize);
operandEncodings.push_back(oldOperand.getType());
} else {
operandSizes.push_back({});
allOperandSizes.push_back({});
operandEncodings.push_back(rewriter.getType<IREE::Util::UnusedType>());
newOperand = convertedOperands.front();
}
dispatchOperands.push_back(newOperand);
operands.push_back(newOperand);
}

// Construct result sizes or reuse tied operand sizes from above.
SmallVector<Value> resultSizes;
SmallVector<Type> resultTypes;
SmallVector<Type> resultEncodings;
auto unknownType = rewriter.getType<IREE::Stream::ResourceType>();
auto tiedOperandBase = op.getTiedOperandsIndexAndLength().first;
for (auto result : llvm::enumerate(op.getResults())) {
auto oldResultType = result.value().getType();
if (!llvm::isa<ShapedType>(oldResultType)) {
resultTypes.push_back(getTypeConverter()->convertType(oldResultType));
resultEncodings.push_back(rewriter.getType<IREE::Util::UnusedType>());
continue;
}
auto tiedOperand = op.getTiedResultOperandIndex(result.index());
if (tiedOperand.has_value()) {
auto operandIndex = tiedOperand.value() - tiedOperandBase;
resultSizes.push_back(operandSizes[operandIndex]);
resultTypes.push_back(dispatchOperands[operandIndex].getType());
resultSizes.push_back(allOperandSizes[operandIndex]);
resultTypes.push_back(operands[operandIndex].getType());
resultEncodings.push_back(operandEncodings[operandIndex]);
} else {
auto resultDynamicDims = IREE::Util::buildDynamicDimsForValue(
op.getLoc(), result.value(), rewriter);
resultSizes.push_back(
buildResultSizeOf(op.getLoc(), result.value(), resultDynamicDims,
executionAffinityAttr, rewriter));
resultTypes.push_back(unknownType);
resultEncodings.push_back(oldResultType);
}
}

auto newOp = rewriter.create<IREE::Stream::AsyncDispatchOp>(
auto newOp = rewriter.create<IREE::Stream::TensorDispatchOp>(
op.getLoc(), resultTypes, flattenValues(adaptor.getWorkload()),
adaptor.getEntryPointsAttr(), dispatchOperands, dispatchOperandSizes,
dispatchOperandOffsets, dispatchOperandEnds, dispatchOperandLengths,
resultSizes, adaptor.getTiedOperandsAttr(), executionAffinityAttr);
newOp->setDialectAttrs(op->getDialectAttrs());
adaptor.getEntryPointsAttr(), operands, operandSizes,
rewriter.getTypeArrayAttr(operandEncodings), op.getArgumentDims(),
resultSizes, rewriter.getTypeArrayAttr(resultEncodings),
op.getResultDims(), adaptor.getTiedOperandsAttr(),
executionAffinityAttr);
newOp->setDialectAttrs(
llvm::make_filter_range(op->getDialectAttrs(), [](NamedAttribute attr) {
return attr.getName() != "stream.affinity";
}));
SmallVector<SmallVector<Value>> replacementsVec = llvm::map_to_vector(
llvm::zip_equal(newOp->getResults(), resultSizes), [](auto it) {
return SmallVector<Value>{std::get<0>(it), std::get<1>(it)};
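One detail worth noting in the hunk above: the `stream.affinity` dialect
attribute is filtered out when copying dialect attrs because the affinity now
travels on the op itself as an `on(...)` clause. A sketch of the before/after
IR (mirroring the `@dispatchAffinity` test in this change; names are
placeholders):

```mlir
// Before: affinity expressed as a dialect attribute on the flow op.
%0 = flow.dispatch @ex::@entry0(%input) {
  stream.affinity = #hal.device.affinity<@device_a>
} : (tensor<7x?x24x?xf32>{%dim1, %dim3}) -> tensor<?x?x1024xf32>{%dim1, %dim3}

// After: the affinity moves to the on(...) clause and the dialect attr is
// not copied across.
%1 = stream.tensor.dispatch on(#hal.device.affinity<@device_a>) @ex::@entry0(%input_a) :
    (tensor<7x?x24x?xf32>{%dim1, %dim3} in !stream.resource<*>{%input_size})
    -> tensor<?x?x1024xf32>{%dim1, %dim3} in !stream.resource<*>{%result0_size}
```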
@@ -3,28 +3,29 @@
// CHECK-LABEL: @dispatchNoWorkload
// CHECK-SAME: (%[[INPUT:.+]]: !stream.resource<*>, %[[INPUT_SIZE:.+]]: index, %[[DIM1:.+]]: index, %[[DIM3:.+]]: index)
util.func public @dispatchNoWorkload(%input: tensor<7x?x24x?xf32>, %dim1: index, %dim3: index) -> tensor<?x?x1024xf32> {
// CHECK: %[[RESULT_SIZE:.+]] = stream.tensor.sizeof tensor<?x?x1024xf32>{%[[DIM1]], %[[DIM3]]}
// CHECK: %[[RESULT:.+]] = stream.async.dispatch @ex::@entry(%[[INPUT]][%c0 to %[[INPUT_SIZE]] for %[[INPUT_SIZE]]]) :
// CHECK-SAME: (!stream.resource<*>{%[[INPUT_SIZE]]}) -> !stream.resource<*>{%[[RESULT_SIZE]]}
// CHECK: %[[RESULT_SIZE:.+]] = stream.tensor.sizeof tensor<?x?x1024xf32>{%[[DIM1]], %[[DIM3]]}
// CHECK: %[[RESULT:.+]] = stream.tensor.dispatch @ex::@entry(%[[INPUT]]) :
// CHECK-SAME: (tensor<7x?x24x?xf32>{%[[DIM1]], %[[DIM3]]} in !stream.resource<*>{%[[INPUT_SIZE]]}) -> tensor<?x?x1024xf32>{%[[DIM1]], %[[DIM3]]} in !stream.resource<*>{%[[RESULT_SIZE]]}
%0 = flow.dispatch @ex::@entry(%input) : (tensor<7x?x24x?xf32>{%dim1, %dim3}) -> tensor<?x?x1024xf32>{%dim1, %dim3}
// return %[[RESULT]], %[[RESULT_SIZE]] : !stream.resource<*>, index
// CHECK: util.return %[[RESULT]], %[[RESULT_SIZE]] : !stream.resource<*>, index
util.return %0 : tensor<?x?x1024xf32>
}

// -----

// CHECK-LABEL: @dispatch
// CHECK-SAME: (%[[INPUT:.+]]: !stream.resource<*>, %[[INPUT_SIZE:.+]]: index, %[[DIM1:.+]]: index, %[[DIM3:.+]]: index)
util.func public @dispatch(%input: tensor<7x?x24x?xf32>, %dim1: index, %dim3: index) -> tensor<?x?x1024xf32> {
util.func public @dispatch(%input: tensor<7x?x24x?xf32>, %dim1: index, %dim3: index) -> (tensor<?x?x1024xf32>, tensor<1024x?x?xf32>) {
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
// CHECK: %[[RESULT_SIZE:.+]] = stream.tensor.sizeof tensor<?x?x1024xf32>{%[[DIM1]], %[[DIM3]]}
// CHECK: %[[RESULT:.+]] = stream.async.dispatch @ex::@entry[%c1, %c2, %c3](%[[INPUT]][%c0 to %[[INPUT_SIZE]] for %[[INPUT_SIZE]]]) :
// CHECK-SAME: (!stream.resource<*>{%[[INPUT_SIZE]]}) -> !stream.resource<*>{%[[RESULT_SIZE]]}
%0 = flow.dispatch @ex::@entry[%c1, %c2, %c3](%input) : (tensor<7x?x24x?xf32>{%dim1, %dim3}) -> tensor<?x?x1024xf32>{%dim1, %dim3}
// return %[[RESULT]], %[[RESULT_SIZE]] : !stream.resource<*>, index
util.return %0 : tensor<?x?x1024xf32>
// CHECK: %[[RESULT0_SIZE:.+]] = stream.tensor.sizeof tensor<?x?x1024xf32>{%[[DIM1]], %[[DIM3]]}
// CHECK: %[[RESULT1_SIZE:.+]] = stream.tensor.sizeof tensor<1024x?x?xf32>{%[[DIM3]], %[[DIM1]]}
// CHECK: %[[RESULTS:.+]]:2 = stream.tensor.dispatch @ex::@entry[%c1, %c2, %c3](%[[INPUT]]) :
// CHECK-SAME: (tensor<7x?x24x?xf32>{%[[DIM1]], %[[DIM3]]} in !stream.resource<*>{%[[INPUT_SIZE]]}) -> (tensor<?x?x1024xf32>{%[[DIM1]], %[[DIM3]]} in !stream.resource<*>{%[[RESULT0_SIZE]]}, tensor<1024x?x?xf32>{%[[DIM3]], %[[DIM1]]} in !stream.resource<*>{%[[RESULT1_SIZE]]})
%results:2 = flow.dispatch @ex::@entry[%c1, %c2, %c3](%input) : (tensor<7x?x24x?xf32>{%dim1, %dim3}) -> (tensor<?x?x1024xf32>{%dim1, %dim3}, tensor<1024x?x?xf32>{%dim3, %dim1})
// CHECK: util.return %[[RESULTS]]#0, %[[RESULT0_SIZE]], %[[RESULTS]]#1, %[[RESULT1_SIZE]] : !stream.resource<*>, index, !stream.resource<*>, index
util.return %results#0, %results#1 : tensor<?x?x1024xf32>, tensor<1024x?x?xf32>
}

// -----
@@ -36,9 +37,11 @@ util.func public @tiedDispatch(%input0: tensor<i32>, %input1: tensor<2x3xi32>) -
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
// CHECK: %[[T_SIZE:.+]] = stream.tensor.sizeof tensor<3x9xi32> : index
// CHECK: %[[T:.+]] = stream.async.dispatch @ex::@entry0[%c1, %c2, %c3](%[[INPUT0]][%c0 to %[[INPUT0_SIZE]] for %[[INPUT0_SIZE]]]) : (!stream.resource<*>{%[[INPUT0_SIZE]]}) -> !stream.resource<*>{%[[T_SIZE]]}
// CHECK: %[[T:.+]] = stream.tensor.dispatch @ex::@entry0[%c1, %c2, %c3](%[[INPUT0]]) :
// CHECK-SAME: (tensor<i32> in !stream.resource<*>{%[[INPUT0_SIZE]]}) -> tensor<3x9xi32> in !stream.resource<*>{%[[T_SIZE]]}
%0 = flow.dispatch @ex::@entry0[%c1, %c2, %c3](%input0) : (tensor<i32>) -> tensor<3x9xi32>
// CHECK: %[[RESULT:.+]] = stream.async.dispatch @ex::@entry1[%c1, %c2, %c3](%[[INPUT1]][%c0 to %[[INPUT1_SIZE]] for %[[INPUT1_SIZE]]], %[[T]][%c0 to %[[T_SIZE]] for %[[T_SIZE]]]) : (!stream.resource<*>{%[[INPUT1_SIZE]]}, !stream.resource<*>{%[[T_SIZE]]}) -> %[[T]]{%[[T_SIZE]]}
// CHECK: %[[RESULT:.+]] = stream.tensor.dispatch @ex::@entry1[%c1, %c2, %c3](%[[INPUT1]], %[[T]]) :
// CHECK-SAME: (tensor<2x3xi32> in !stream.resource<*>{%[[INPUT1_SIZE]]}, tensor<3x9xi32> in !stream.resource<*>{%[[T_SIZE]]}) -> tensor<3x9xi32> in %[[T]]{%[[T_SIZE]]}
%1 = flow.dispatch @ex::@entry1[%c1, %c2, %c3](%input1, %0) : (tensor<2x3xi32>, tensor<3x9xi32>) -> %0
// CHECK: util.return %[[RESULT]], %[[T_SIZE]] : !stream.resource<*>, index
util.return %1 : tensor<3x9xi32>
@@ -52,18 +55,20 @@ util.global private @device_b : !hal.device
// CHECK-LABEL: @dispatchAffinity
// CHECK-SAME: (%[[INPUT:.+]]: !stream.resource<*>, %[[INPUT_SIZE:.+]]: index, %[[DIM1:.+]]: index, %[[DIM3:.+]]: index)
util.func public @dispatchAffinity(%input: tensor<7x?x24x?xf32>, %dim1: index, %dim3: index) -> (tensor<?x?x1024xf32>, tensor<?x?x1024xf32>) {
// CHECK: %[[INPUT_A:.+]] = stream.async.transfer %[[INPUT]] : !stream.resource<*>{%[[INPUT_SIZE]]} -> to(#hal.device.affinity<@device_a>) !stream.resource<*>{%[[INPUT_SIZE]]}
// CHECK: %[[RESULT0_SIZE:.+]] = stream.tensor.sizeof on(#hal.device.affinity<@device_a>) tensor<?x?x1024xf32>{%[[DIM1]], %[[DIM3]]}
// CHECK: %[[RESULT0:.+]] = stream.async.dispatch on(#hal.device.affinity<@device_a>) @ex::@entry0(%[[INPUT_A]][%c0 to %[[INPUT_SIZE]] for %[[INPUT_SIZE]]])
// CHECK: %[[INPUT_A:.+]] = stream.async.transfer %[[INPUT]] : !stream.resource<*>{%[[INPUT_SIZE]]} -> to(#hal.device.affinity<@device_a>) !stream.resource<*>{%[[INPUT_SIZE]]}
// CHECK: %[[RESULT0_SIZE:.+]] = stream.tensor.sizeof on(#hal.device.affinity<@device_a>) tensor<?x?x1024xf32>{%[[DIM1]], %[[DIM3]]}
// CHECK: %[[RESULT0:.+]] = stream.tensor.dispatch on(#hal.device.affinity<@device_a>) @ex::@entry0(%[[INPUT_A]])
// CHECK-SAME: (tensor<7x?x24x?xf32>{%[[DIM1]], %[[DIM3]]} in !stream.resource<*>{%[[INPUT_SIZE]]}) -> tensor<?x?x1024xf32>{%[[DIM1]], %[[DIM3]]} in !stream.resource<*>{%[[RESULT0_SIZE]]}
%0 = flow.dispatch @ex::@entry0(%input) {
stream.affinity = #hal.device.affinity<@device_a>
} : (tensor<7x?x24x?xf32>{%dim1, %dim3}) -> tensor<?x?x1024xf32>{%dim1, %dim3}
// CHECK: %[[INPUT_B:.+]] = stream.async.transfer %[[INPUT]] : !stream.resource<*>{%[[INPUT_SIZE]]} -> to(#hal.device.affinity<@device_b>) !stream.resource<*>{%[[INPUT_SIZE]]}
// CHECK: %[[RESULT1_SIZE:.+]] = stream.tensor.sizeof on(#hal.device.affinity<@device_b>) tensor<?x?x1024xf32>{%[[DIM3]], %[[DIM1]]}
// CHECK: %[[RESULT1:.+]] = stream.async.dispatch on(#hal.device.affinity<@device_b>) @ex::@entry1(%[[INPUT_B]][%c0 to %[[INPUT_SIZE]] for %[[INPUT_SIZE]]])
// CHECK: %[[INPUT_B:.+]] = stream.async.transfer %[[INPUT]] : !stream.resource<*>{%[[INPUT_SIZE]]} -> to(#hal.device.affinity<@device_b>) !stream.resource<*>{%[[INPUT_SIZE]]}
// CHECK: %[[RESULT1_SIZE:.+]] = stream.tensor.sizeof on(#hal.device.affinity<@device_b>) tensor<?x?x1024xf32>{%[[DIM3]], %[[DIM1]]}
// CHECK: %[[RESULT1:.+]] = stream.tensor.dispatch on(#hal.device.affinity<@device_b>) @ex::@entry1(%[[INPUT_B]])
// CHECK-SAME: (tensor<7x?x24x?xf32>{%[[DIM1]], %[[DIM3]]} in !stream.resource<*>{%[[INPUT_SIZE]]}) -> tensor<?x?x1024xf32>{%[[DIM3]], %[[DIM1]]} in !stream.resource<*>{%[[RESULT1_SIZE]]}
%1 = flow.dispatch @ex::@entry1(%input) {
stream.affinity = #hal.device.affinity<@device_b>
} : (tensor<7x?x24x?xf32>{%dim1, %dim3}) -> tensor<?x?x1024xf32>{%dim3, %dim1}
// return %[[RESULT0]], %[[RESULT0_SIZE]], %[[RESULT1]], %[[RESULT1_SIZE]]
// CHECK: return %[[RESULT0]], %[[RESULT0_SIZE]], %[[RESULT1]], %[[RESULT1_SIZE]]
util.return %0, %1 : tensor<?x?x1024xf32>, tensor<?x?x1024xf32>
}
@@ -141,16 +141,12 @@ util.global private @device : !hal.device
// CHECK-LABEL: @tensorBarrierDispatch
// CHECK-SAME: (%[[INPUT:.+]]: !stream.resource<*>, %[[DIM0:.+]]: index, %[[DIM1:.+]]: index)
util.func public @tensorBarrierDispatch(%input: tensor<?x128xi8>, %dim0: index) -> tensor<?x128xi8> {
%c0 = arith.constant 0 : index
%barrier = flow.tensor.barrier %input : tensor<?x128xi8>{%dim0} on #hal.device.affinity<@device>
%0 = flow.dispatch @ex::@entry[%c0](%barrier) : (tensor<?x128xi8>{%dim0}) -> tensor<?x128xi8>{%dim0}

// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[BARRIER:.+]] = stream.async.barrier %[[INPUT]] : !stream.resource<*>{%[[DIM0]]} -> !stream.resource<*>
// CHECK: %[[C0_2:.+]] = arith.constant 0 : index
%barrier = flow.tensor.barrier %input : tensor<?x128xi8>{%dim0} on #hal.device.affinity<@device>
// CHECK: %[[SIZE:.+]] = stream.tensor.sizeof on(#hal.device.affinity<@device>) tensor<?x128xi8>{%arg2} : index
// CHECK: %[[DISP:.+]] = stream.async.dispatch on(#hal.device.affinity<@device>) @ex::@entry[%[[C0]]](%[[BARRIER]][%[[C0_2]] to %[[DIM0]] for %[[DIM0]]])
// CHECK: util.return %[[DISP]], %[[SIZE]]
// CHECK: %[[RESULT:.+]] = stream.tensor.dispatch on(#hal.device.affinity<@device>) @ex::@entry(%[[BARRIER]])
%0 = flow.dispatch @ex::@entry(%barrier) : (tensor<?x128xi8>{%dim0}) -> tensor<?x128xi8>{%dim0}
// CHECK: util.return %[[RESULT]], %[[SIZE]]
util.return %0 : tensor<?x128xi8>
}

compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp (30 additions, 0 deletions)
@@ -1363,6 +1363,36 @@ void TensorStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
// TODO(benvanik): combine multiple stores to the same target if contiguous.
}

//===----------------------------------------------------------------------===//
// stream.tensor.dispatch
//===----------------------------------------------------------------------===//

namespace {

struct DeduplicateTensorDispatchEntryRefs final
: public OpRewritePattern<TensorDispatchOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(TensorDispatchOp dispatchOp,
PatternRewriter &rewriter) const override {
auto originalAttr = dispatchOp.getEntryPointsAttr();
auto newAttr = deduplicateArrayElements(originalAttr);
if (newAttr == originalAttr)
return failure();
rewriter.modifyOpInPlace(dispatchOp,
[&]() { dispatchOp.setEntryPointsAttr(newAttr); });
return success();
}
};

} // namespace

void TensorDispatchOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
// TODO(benvanik): maybe tied type/lifetime updates?
results.insert<ElideUnusedOp<TensorDispatchOp>>(context);
results.insert<DeduplicateTensorDispatchEntryRefs>(context);
}

//===----------------------------------------------------------------------===//
// stream.async.alloca
//===----------------------------------------------------------------------===//
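The effect of the two patterns registered for `stream.tensor.dispatch` above,
as a rough sketch (the brace syntax for multiple entry-point alternatives is
assumed to mirror `flow.dispatch`; op names and types are placeholders):

```mlir
// DeduplicateTensorDispatchEntryRefs: repeated entry-point alternatives in
// the entry_points array collapse into a single reference.
%0 = stream.tensor.dispatch {@ex::@entry, @ex::@entry}(%input) :
    (tensor<4xf32> in !stream.resource<*>{%size}) -> tensor<4xf32> in !stream.resource<*>{%size}
// ...canonicalizes to:
%1 = stream.tensor.dispatch @ex::@entry(%input) :
    (tensor<4xf32> in !stream.resource<*>{%size}) -> tensor<4xf32> in !stream.resource<*>{%size}

// ElideUnusedOp: a dispatch whose results have no uses is erased entirely.
```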