From 16e51afe60658d5508516feee84cef15f9391128 Mon Sep 17 00:00:00 2001 From: Han-Chung Wang Date: Thu, 21 Nov 2024 13:06:11 -0800 Subject: [PATCH] [Encoding] Introduce "layouts" field to EncodingAttr. (#19215) The revision introduces an optional "layouts" field to EncodingAttr. It is an array of attributes that describes the potential layouts on the device. It is an array because a device could have several executable targets. Note that it can be any attribute that implements EncodingLayoutAttrInterface. The expectation of the field is to bridge the logics between host codes and device codes. If an attribute does not implement the interface, it could be discarded anytime. The revision also updates the TODO item for `round_dims_to` field. Because IREE is going to use the new "layouts" field and upcoming attribute interface to handle the allocation problem. It is a step towards https://github.com/iree-org/iree/issues/17924 Signed-off-by: hanhanW --- .../Dialect/Encoding/IR/EncodingAttrs.cpp | 8 +++++--- .../Dialect/Encoding/IR/EncodingAttrs.td | 18 +++++++++++++++--- .../Dialect/Encoding/IR/test/roundtrip.mlir | 12 ++++++++++++ 3 files changed, 32 insertions(+), 6 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp index ec163eaf14c7..333145d0e8c3 100644 --- a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp +++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp @@ -25,7 +25,8 @@ EncodingAttr EncodingAttr::get(MLIRContext *ctx, int64_t operandIndex, EncodingOpType opType, ArrayRef elemTypes, ArrayRef maps, std::optional bcastMap, - ArrayRef roundDimsTo) { + ArrayRef roundDimsTo, + ArrayRef layouts) { Builder b(ctx); auto opTypeAttr = EncodingOpTypeAttr::get(ctx, opType); auto roundDimsToAttr = roundDimsTo.empty() @@ -34,9 +35,10 @@ EncodingAttr EncodingAttr::get(MLIRContext *ctx, int64_t operandIndex, auto bcastMapAttr = bcastMap.has_value() ? AffineMapAttr::get(bcastMap.value()) : AffineMapAttr(); + auto layoutsAttr = layouts.empty() ? ArrayAttr() : b.getArrayAttr(layouts); return get(ctx, b.getIndexAttr(operandIndex), opTypeAttr, b.getTypeArrayAttr(elemTypes), b.getAffineMapArrayAttr(maps), - bcastMapAttr, roundDimsToAttr); + bcastMapAttr, roundDimsToAttr, layoutsAttr); } AffineMap EncodingAttr::getMapForOperandIndex() { @@ -106,7 +108,7 @@ SmallVector EncodingAttr::getElementTypesArray() { EncodingAttr EncodingAttr::clone(AffineMap bcastMap) { return get(bcastMap.getContext(), getOperandIndex(), getOpType(), getElementTypes(), getUserIndexingMaps(), - AffineMapAttr::get(bcastMap), getRoundDimsTo()); + AffineMapAttr::get(bcastMap), getRoundDimsTo(), getLayouts()); } MatmulNarrowDim getMatmulNarrowDim(EncodingAttr encoding) { diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.td b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.td index 3ec4bd0d0408..9086c10b20b3 100644 --- a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.td +++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.td @@ -63,8 +63,19 @@ def EncodingAttr : AttrParameter<"ArrayAttr", "element types of the user's operands">:$element_types, OptionalParameter<"ArrayAttr", "Indexing maps of the operation using this tensor">:$user_indexing_maps, OptionalParameter<"AffineMapAttr", "Indexing map that represents the broadcasting dims in the producer">:$bcast_map, - // TODO(hanchung): The round_dims_to parameter can be revisited. We explicitly map them to M,N,K dimension for now. - OptionalParameter<"DenseArrayAttr", "Values for padding M,N,K dimensions">:$round_dims_to + // TODO(hanchung): Deprecate the round_dims_to field when we plumb the layouts + // field through the whole stack. See https://github.com/iree-org/iree/issues/17924 + // for details. Note that today we abuse the attribute to carry narrow + // matrix information. The end goal is deprecating the field and add a + // "iteration_space_size" field to describe the shape. It is useful to + // handle narrow matrix cases. + OptionalParameter<"DenseArrayAttr", "Values for padding M,N,K dimensions">:$round_dims_to, + OptionalParameter<"ArrayAttr", "An array of attributes that describes the " + "potential layouts on the device. It is an array because a device could " + "have several executable targets. Note that it can be any attribute that " + "implements EncodingLayoutAttrInterface. The expectation of the field " + "is to bridge the logics between host codes and device codes. If an " + "attribute does not implement the interface, it could be discarded anytime.">:$layouts ); let builders = [ @@ -73,7 +84,8 @@ def EncodingAttr : "ArrayRef":$elemTypes, CArg<"ArrayRef", "{}">:$maps, CArg<"std::optional", "{}">:$bcastMap, - CArg<"ArrayRef", "{}">:$roundDimsTo)> + CArg<"ArrayRef", "{}">:$roundDimsTo, + CArg<"ArrayRef", "{}">:$layouts)> ]; let extraClassDeclaration = [{ diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/test/roundtrip.mlir b/compiler/src/iree/compiler/Dialect/Encoding/IR/test/roundtrip.mlir index 775f2fa66324..b4f38d2001ed 100644 --- a/compiler/src/iree/compiler/Dialect/Encoding/IR/test/roundtrip.mlir +++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/test/roundtrip.mlir @@ -157,3 +157,15 @@ func.func @set_encoding_ops_with_indexing_maps(%arg0: tensor) -> tensor // CHECK: func.func @set_encoding_ops_with_indexing_maps( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: // CHECK: iree_encoding.set_encoding %[[ARG0]] : tensor -> tensor + +// ----- + +#encoding = #iree_encoding.encoding +func.func @set_encoding_ops_with_layouts(%arg0: tensor) -> tensor { + %0 = iree_encoding.set_encoding %arg0 : tensor -> tensor + return %0 : tensor +} +// CHECK-DAG: #[[ENCODING:.+]] = #iree_encoding.encoding +// CHECK: func.func @set_encoding_ops_with_layouts( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: +// CHECK: iree_encoding.set_encoding %[[ARG0]] : tensor -> tensor