Skip to content

Commit

Permalink
[Encoding] Introduce "layouts" field to EncodingAttr. (iree-org#19215)
Browse files Browse the repository at this point in the history
The revision introduces an optional "layouts" field to EncodingAttr. It
is an array of attributes that describes the potential layouts on the
device. It is an array because a device could have several executable
targets. Note that it can be any attribute that implements
EncodingLayoutAttrInterface. The expectation of the field is to bridge
the logics between host codes and device codes. If an attribute does not
implement the interface, it could be discarded anytime.

The revision also updates the TODO item for `round_dims_to` field.
Because IREE is going to use the new "layouts" field and upcoming
attribute interface to handle the allocation problem.

It is a step towards iree-org#17924

Signed-off-by: hanhanW <[email protected]>
  • Loading branch information
hanhanW authored Nov 21, 2024
1 parent 4ee5d19 commit 16e51af
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ EncodingAttr EncodingAttr::get(MLIRContext *ctx, int64_t operandIndex,
EncodingOpType opType, ArrayRef<Type> elemTypes,
ArrayRef<AffineMap> maps,
std::optional<AffineMap> bcastMap,
ArrayRef<int64_t> roundDimsTo) {
ArrayRef<int64_t> roundDimsTo,
ArrayRef<Attribute> layouts) {
Builder b(ctx);
auto opTypeAttr = EncodingOpTypeAttr::get(ctx, opType);
auto roundDimsToAttr = roundDimsTo.empty()
Expand All @@ -34,9 +35,10 @@ EncodingAttr EncodingAttr::get(MLIRContext *ctx, int64_t operandIndex,
auto bcastMapAttr = bcastMap.has_value()
? AffineMapAttr::get(bcastMap.value())
: AffineMapAttr();
auto layoutsAttr = layouts.empty() ? ArrayAttr() : b.getArrayAttr(layouts);
return get(ctx, b.getIndexAttr(operandIndex), opTypeAttr,
b.getTypeArrayAttr(elemTypes), b.getAffineMapArrayAttr(maps),
bcastMapAttr, roundDimsToAttr);
bcastMapAttr, roundDimsToAttr, layoutsAttr);
}

AffineMap EncodingAttr::getMapForOperandIndex() {
Expand Down Expand Up @@ -106,7 +108,7 @@ SmallVector<Type> EncodingAttr::getElementTypesArray() {
EncodingAttr EncodingAttr::clone(AffineMap bcastMap) {
return get(bcastMap.getContext(), getOperandIndex(), getOpType(),
getElementTypes(), getUserIndexingMaps(),
AffineMapAttr::get(bcastMap), getRoundDimsTo());
AffineMapAttr::get(bcastMap), getRoundDimsTo(), getLayouts());
}

MatmulNarrowDim getMatmulNarrowDim(EncodingAttr encoding) {
Expand Down
18 changes: 15 additions & 3 deletions compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,19 @@ def EncodingAttr :
AttrParameter<"ArrayAttr", "element types of the user's operands">:$element_types,
OptionalParameter<"ArrayAttr", "Indexing maps of the operation using this tensor">:$user_indexing_maps,
OptionalParameter<"AffineMapAttr", "Indexing map that represents the broadcasting dims in the producer">:$bcast_map,
// TODO(hanchung): The round_dims_to parameter can be revisited. We explicitly map them to M,N,K dimension for now.
OptionalParameter<"DenseArrayAttr", "Values for padding M,N,K dimensions">:$round_dims_to
// TODO(hanchung): Deprecate the round_dims_to field when we plumb the layouts
// field through the whole stack. See https://github.com/iree-org/iree/issues/17924
// for details. Note that today we abuse the attribute to carry narrow
// matrix information. The end goal is deprecating the field and add a
// "iteration_space_size" field to describe the shape. It is useful to
// handle narrow matrix cases.
OptionalParameter<"DenseArrayAttr", "Values for padding M,N,K dimensions">:$round_dims_to,
OptionalParameter<"ArrayAttr", "An array of attributes that describes the "
"potential layouts on the device. It is an array because a device could "
"have several executable targets. Note that it can be any attribute that "
"implements EncodingLayoutAttrInterface. The expectation of the field "
"is to bridge the logics between host codes and device codes. If an "
"attribute does not implement the interface, it could be discarded anytime.">:$layouts
);

let builders = [
Expand All @@ -73,7 +84,8 @@ def EncodingAttr :
"ArrayRef<Type>":$elemTypes,
CArg<"ArrayRef<AffineMap>", "{}">:$maps,
CArg<"std::optional<AffineMap>", "{}">:$bcastMap,
CArg<"ArrayRef<int64_t>", "{}">:$roundDimsTo)>
CArg<"ArrayRef<int64_t>", "{}">:$roundDimsTo,
CArg<"ArrayRef<Attribute>", "{}">:$layouts)>
];

let extraClassDeclaration = [{
Expand Down
12 changes: 12 additions & 0 deletions compiler/src/iree/compiler/Dialect/Encoding/IR/test/roundtrip.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -157,3 +157,15 @@ func.func @set_encoding_ops_with_indexing_maps(%arg0: tensor<?x?xf32>) -> tensor
// CHECK: func.func @set_encoding_ops_with_indexing_maps(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]:
// CHECK: iree_encoding.set_encoding %[[ARG0]] : tensor<?x?xf32> -> tensor<?x?xf32, #[[ENCODING]]>

// -----

#encoding = #iree_encoding.encoding<operand_index = 0 : i64, op_type = matmul, element_types = [f32, f32, f32], layouts = [{}]>
func.func @set_encoding_ops_with_layouts(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32, #encoding> {
%0 = iree_encoding.set_encoding %arg0 : tensor<?x?xf32> -> tensor<?x?xf32, #encoding>
return %0 : tensor<?x?xf32, #encoding>
}
// CHECK-DAG: #[[ENCODING:.+]] = #iree_encoding.encoding<operand_index = 0 : i64, op_type = matmul, element_types = [f32, f32, f32], layouts = [{}]>
// CHECK: func.func @set_encoding_ops_with_layouts(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]:
// CHECK: iree_encoding.set_encoding %[[ARG0]] : tensor<?x?xf32> -> tensor<?x?xf32, #[[ENCODING]]>

0 comments on commit 16e51af

Please sign in to comment.