Skip to content

Commit

Permalink
update for comments
Browse files Browse the repository at this point in the history
Signed-off-by: Rob Suderman <[email protected]>
  • Loading branch information
rsuderman committed Feb 8, 2025
1 parent d47e8a1 commit 0f8da4a
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 48 deletions.
33 changes: 0 additions & 33 deletions compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2495,39 +2495,6 @@ LogicalResult AsyncTransferOp::verify() {
return success();
}

IREE::Stream::AffinityAttr AsyncTransferOp::getAffinityAttr() {
if (getExecAffinityAttr()) {
return getExecAffinityAttr();
}

auto sourceType = cast<IREE::Stream::ResourceType>(getSource().getType());
auto resultType = cast<IREE::Stream::ResourceType>(getResult().getType());
if (sourceType.getLifetime() == IREE::Stream::Lifetime::Staging &&
resultType.getLifetime() == IREE::Stream::Lifetime::Staging) {
// TODO(multi-device): figure out how to model staging->staging transfers.
return getSourceAffinityAttr();
} else if (sourceType.getLifetime() == IREE::Stream::Lifetime::External ||
sourceType.getLifetime() == IREE::Stream::Lifetime::Staging) {
// If source is staging then the op should execute on the consumer.
return getResultAffinityAttr();
} else if (resultType.getLifetime() == IREE::Stream::Lifetime::External ||
resultType.getLifetime() == IREE::Stream::Lifetime::Staging) {
// If result is staging then the op should execute on the producer.
return getSourceAffinityAttr();
} else {
// Default to result affinity.
return getSourceAffinityAttr();
}
}

void AsyncTransferOp::setAffinityAttr(IREE::Stream::AffinityAttr value) {
if (value) {
setExecAffinityAttr(value);
} else {
removeExecAffinityAttr();
}
}

void AsyncTransferOp::build(OpBuilder &builder, OperationState &state,
Type type, Value source, Value source_size,
Value result_size, AffinityAttr source_attr,
Expand Down
9 changes: 3 additions & 6 deletions compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2344,10 +2344,7 @@ def Stream_AsyncBarrierOp : Stream_Op<"async.barrier", [
}

def Stream_AsyncTransferOp : Stream_Op<"async.transfer", [
DeclareOpInterfaceMethods<Stream_AffinityOp, [
"getAffinityAttr",
"setAffinityAttr",
]>,
DeclareOpInterfaceMethods<Stream_AffinityOp>,
Stream_AsyncPhaseOp,
Stream_StreamableOp,
DeclareOpInterfaceMethods<Stream_AsyncAccessOp, [
Expand All @@ -2372,7 +2369,7 @@ def Stream_AsyncTransferOp : Stream_Op<"async.transfer", [
Stream_Size:$result_size,
OptionalAttr<Stream_AffinityAttr>:$source_affinity,
OptionalAttr<Stream_AffinityAttr>:$result_affinity,
OptionalAttr<Stream_AffinityAttr>:$exec_affinity
OptionalAttr<Stream_AffinityAttr>:$affinity
);
let results = (outs
AnyTypeOf<[
Expand All @@ -2384,10 +2381,10 @@ def Stream_AsyncTransferOp : Stream_Op<"async.transfer", [
let assemblyFormat = [{
$source `:` type($source)
`` `{` $source_size `}`
(`on` `(` $affinity^ `)`)?
(`from` `(` $source_affinity^ `)`)?
`->`
(`to` `(` $result_affinity^ `)`)?
(`on` `(` $exec_affinity^ `)`)?
type($result) `` `{` $result_size `}`
attr-dict-with-keyword
}];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,21 +43,22 @@ struct ExecutionPlacementPass
void runOnOperation() override {

getOperation()->walk([](IREE::Stream::AsyncTransferOp transfer) {
if (transfer.getExecAffinityAttr())
if (transfer.getAffinityAttr())
return;

auto operand = transfer.getSource();
auto producer = operand.getDefiningOp();
auto streamable =
dyn_cast_or_null<IREE::Stream::StreamableOpInterface>(producer);
auto srcAffinity = dyn_cast<IREE::Stream::AffinityOpInterface>(producer);
auto srcAffinity =
dyn_cast_or_null<IREE::Stream::AffinityOpInterface>(producer);

bool hasOneUse = operand.hasOneUse();
if (hasOneUse && streamable && srcAffinity) {
transfer.setExecAffinityAttr(srcAffinity.getAffinityAttr());
} else {
transfer.setExecAffinityAttr(transfer.getResultAffinityAttr());
transfer.setAffinityAttr(srcAffinity.getAffinityAttr());
return;
}
transfer.setAffinityAttr(transfer.getResultAffinityAttr());
});
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,14 +100,14 @@ util.func public @deviceTripleSync(%arg0: i1) -> (!stream.resource<transient>, !
// CHECK: stream.async.dispatch
// CHECK: stream.async.transfer
%3 = stream.async.dispatch on(#hal.device.affinity<@device1>) @ex::@dispatch0[%c1, %c1, %c1](%0[%c0 to %c128 for %c128]) : (!stream.resource<transient>{%c128}) -> !stream.resource<transient>{%c128}
%4 = stream.async.transfer %3 : !stream.resource<transient>{%c128} from(#hal.device.affinity<@device1>) -> to(#hal.device.affinity<@device0>) on(#hal.device.affinity<@device1>) !stream.resource<transient>{%c128}
%4 = stream.async.transfer %3 : !stream.resource<transient>{%c128} on(#hal.device.affinity<@device1>) from(#hal.device.affinity<@device1>) -> to(#hal.device.affinity<@device0>) !stream.resource<transient>{%c128}

// CHECK: stream.async.execute
// CHECK: stream.async.splat
// CHECK: stream.async.dispatch
// CHECK: stream.async.transfer
%5 = stream.async.dispatch on(#hal.device.affinity<@device2>) @ex::@dispatch0[%c1, %c1, %c1](%0[%c0 to %c128 for %c128]) : (!stream.resource<transient>{%c128}) -> !stream.resource<transient>{%c128}
%6 = stream.async.transfer %5 : !stream.resource<transient>{%c128} from(#hal.device.affinity<@device2>) -> to(#hal.device.affinity<@device0>) on(#hal.device.affinity<@device2>) !stream.resource<transient>{%c128}
%6 = stream.async.transfer %5 : !stream.resource<transient>{%c128} on(#hal.device.affinity<@device2>) from(#hal.device.affinity<@device2>) -> to(#hal.device.affinity<@device0>) !stream.resource<transient>{%c128}

// CHECK: stream.async.execute
// CHECK: stream.async.dispatch
Expand All @@ -121,13 +121,13 @@ util.func public @deviceTripleSync(%arg0: i1) -> (!stream.resource<transient>, !
// CHECK: stream.async.execute
// CHECK: stream.async.transfer
// CHECK: stream.async.dispatch
%9 = stream.async.transfer %7 : !stream.resource<transient>{%c128} from(#hal.device.affinity<@device0>) -> to(#hal.device.affinity<@device1>) on(#hal.device.affinity<@device1>) !stream.resource<transient>{%c128}
%9 = stream.async.transfer %7 : !stream.resource<transient>{%c128} on(#hal.device.affinity<@device1>) from(#hal.device.affinity<@device0>) -> to(#hal.device.affinity<@device1>) !stream.resource<transient>{%c128}
%12 = stream.async.dispatch on(#hal.device.affinity<@device1>) @ex::@dispatch0[%c1, %c1, %c1](%9[%c0 to %c128 for %c128]) : (!stream.resource<transient>{%c128}) -> !stream.resource<transient>{%c128}

// CHECK: stream.async.execute
// CHECK: stream.async.transfer
// CHECK: stream.async.dispatch
%10 = stream.async.transfer %7 : !stream.resource<transient>{%c128} from(#hal.device.affinity<@device0>) -> to(#hal.device.affinity<@device2>) on(#hal.device.affinity<@device2>) !stream.resource<transient>{%c128}
%10 = stream.async.transfer %7 : !stream.resource<transient>{%c128} on(#hal.device.affinity<@device2>) from(#hal.device.affinity<@device0>) -> to(#hal.device.affinity<@device2>) !stream.resource<transient>{%c128}
%13 = stream.async.dispatch on(#hal.device.affinity<@device2>) @ex::@dispatch0[%c1, %c1, %c1](%10[%c0 to %c128 for %c128]) : (!stream.resource<transient>{%c128}) -> !stream.resource<transient>{%c128}

util.return %11, %12, %13 : !stream.resource<transient>, !stream.resource<transient>, !stream.resource<transient>
Expand Down

0 comments on commit 0f8da4a

Please sign in to comment.