Skip to content

Commit

Permalink
[CPU][SVE] Enable scalable vectorization and tiling for non-padded matmuls (iree-org#15108)
Browse files Browse the repository at this point in the history

This patch implements a "vertical slice" that allows for the scalable
vectorization of matmuls on AArch64 targets with SVE. This only updates
lowerings along that path, so more is needed to enable scalable
vectorization everywhere.

This required a few changes:

- The default vector sizes of matmuls on AArch64+SVE are now (8, [32],
16)
   * That is, the middle dimension is scalable (i.e. 32 x vscale)
- `iree-llvmcpu-tile-and-fuse` now generates vscale-bounded loops for
scalable tiles
- `TilingConfig` now returns a pair of tile sizes and scalable flags
(`SizesAndScalableFlags`) for vector sizes
   * This allows connecting the scalable sizes to the generic vectorizer
(which passes them down to the linalg vectorizer)
A few unit tests have been added for this, but a complete e2e test is
not possible without SVE testing infrastructure.
  • Loading branch information
MacDue authored Oct 10, 2023
1 parent 729bb75 commit ba3e6a7
Show file tree
Hide file tree
Showing 15 changed files with 327 additions and 79 deletions.
28 changes: 18 additions & 10 deletions compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,21 +128,27 @@ inferVectorSizesFromIR(linalg::LinalgOp linalgOp) {

// Return the vector sizes from the local lowering config or try to infer them
// from the tensor shapes and tiled loops in the IR.
static FailureOr<SmallVector<int64_t>>
static FailureOr<SizesAndScalableFlags>
getVectorSizes(linalg::LinalgOp linalgOp) {
// Get vector sizes from the lowering config, if available in the op itself.
IREE::Codegen::LoweringConfigAttr loweringConfig =
getLoweringConfig(linalgOp);
if (loweringConfig) {
TilingConfig tilingConfig(loweringConfig);
SmallVector<int64_t> vectorSizes = tilingConfig.getVectorTileSizes();
auto [vectorSizes, scalableFlags] = tilingConfig.getVectorTileSizes();
// Replace zeros in canonical vector shape to turn it into a valid shape.
std::replace(vectorSizes.begin(), vectorSizes.end(), 0, 1);
return vectorSizes;
return std::make_pair(vectorSizes, scalableFlags);
}

// Try to infer the vector sizes from the IR.
return inferVectorSizesFromIR(linalgOp);
auto vectorSizes = inferVectorSizesFromIR(linalgOp);
if (succeeded(vectorSizes)) {
// This can't identify scalable flags, so pad them with `false`.
return std::make_pair(*vectorSizes,
SmallVector<bool>(vectorSizes->size(), false));
}
return failure();
}

static LogicalResult isWithinVectorSizeLimit(linalg::LinalgOp linalgOp,
Expand Down Expand Up @@ -193,14 +199,16 @@ void GenericVectorizationPass::runOnOperation() {
});
for (auto op : candidates) {
SmallVector<int64_t> vectorSizes;
SmallVector<bool> scalableVecDims;
if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
// Do not vectorize the op if the vector size is greater than or equal
// to limit.
if (enableVectorMasking) {
auto maybeVectorSizes = getVectorSizes(linalgOp);
if (succeeded(maybeVectorSizes)) {
vectorSizes.append(maybeVectorSizes->begin(),
maybeVectorSizes->end());
auto vectorSizesAndScalableDims = getVectorSizes(linalgOp);
if (succeeded(vectorSizesAndScalableDims)) {
auto [sizes, scalableDims] = *vectorSizesAndScalableDims;
vectorSizes.append(sizes.begin(), sizes.end());
scalableVecDims.append(scalableDims.begin(), scalableDims.end());
}
if (std::accumulate(vectorSizes.begin(), vectorSizes.end(), 1,
std::multiplies<int64_t>()) >= maxVectorSize)
Expand All @@ -217,8 +225,8 @@ void GenericVectorizationPass::runOnOperation() {
continue;
vectorSizes.append(ty.getShape().begin(), ty.getShape().end());
}

SmallVector<bool> scalableVecDims(vectorSizes.size(), false);
// Pad scalable dims with `false` to match the vector sizes.
scalableVecDims.resize(vectorSizes.size());
(void)linalg::vectorize(rewriter, op, vectorSizes, scalableVecDims,
vectorizeGatherAccesses);
};
Expand Down
26 changes: 17 additions & 9 deletions compiler/src/iree/compiler/Codegen/Common/TileSizeSelection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,24 +54,32 @@ TilingConfig::TilingConfig(IREE::Codegen::LoweringConfigAttr lc)

/// Returns the tile sizes of all the vector dimensions, including parallel
/// and reduction dimensions.
SmallVector<int64_t> TilingConfig::getVectorTileSizes() {
SizesAndScalableFlags TilingConfig::getVectorTileSizes() {
unsigned numDims = getNumDimensions();
SmallVector<int64_t> vectorSizes(numDims, 0);
SmallVector<int64_t> parallelCommonSizes = getVectorCommonParallelSizes();
SmallVector<int64_t> reductionSizes = getVectorReductionSizes();
SmallVector<int64_t> parallelInnerSizes = getVectorInnerParallelSizes();
SmallVector<bool> scalableFlags(numDims, false);
auto [parallelCommonSizes, parallelCommonScalableFlags] =
getVectorCommonParallelSizes();
auto [reductionSizes, reductionScalableFlags] = getVectorReductionSizes();
auto [parallelInnerSizes, parallelInnerScalableFlags] =
getVectorInnerParallelSizes();
for (int i = 0; i < numDims; ++i) {
unsigned nonZeroCnt = llvm::count_if(
ArrayRef<int64_t>{parallelCommonSizes[i], reductionSizes[i],
parallelInnerSizes[i]},
[](auto v) { return v != 0; });
unsigned nonZeroCnt = llvm::count(
ArrayRef<bool>{
!!parallelCommonSizes[i] || parallelCommonScalableFlags[i],
!!reductionSizes[i] || reductionScalableFlags[i],
!!parallelInnerSizes[i] || parallelInnerScalableFlags[i]},
true);
assert(nonZeroCnt <= 1 && "expected one tile size at most to be non-zero");
(void)nonZeroCnt;
vectorSizes[i] =
parallelCommonSizes[i] ^ reductionSizes[i] ^ parallelInnerSizes[i];
scalableFlags[i] = parallelCommonScalableFlags[i] ||
reductionScalableFlags[i] ||
parallelInnerScalableFlags[i];
}

return vectorSizes;
return std::make_pair(vectorSizes, scalableFlags);
}

/// Returns a list with the tiling levels that can be fused for this
Expand Down
44 changes: 32 additions & 12 deletions compiler/src/iree/compiler/Codegen/Common/TileSizeSelection.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
namespace mlir {
namespace iree_compiler {

using SizesAndScalableFlags =
std::pair<SmallVector<int64_t>, SmallVector<bool>>;

/// Provides unified API to get access to all the tile size needed during the
/// CPU lowering process, while abstracting the representation and verification
/// details of such information in the IR.
Expand All @@ -35,12 +38,20 @@ class TilingConfig {

/// Returns the number of dimensions of the configuration. All the tiling
/// levels must have the same number of dimensions.
unsigned getNumDimensions() { return getDistributionTileSizes().size(); }
unsigned getNumDimensions() {
return getNumTilingLevels() > 0
? loweringConfig.getTilingLevels()[0].getSizes().size()
: 0;
}

/// Returns the number of parallel dimensions to tile at vector level, i.e.
/// the number of non-zero tile sizes stored for the vector common parallel
/// tiling level. Returns 0 when that level index is not below
/// getNumTilingLevels().
unsigned getNumVectorParallelTiles() {
  unsigned parallelLevel = getVectorCommonParallelLevel();
  // Only index the tiling levels with an in-range level. The guard must
  // reject `parallelLevel >= getNumTilingLevels()`; rejecting `<=` instead
  // would return 0 for every in-range level and index past the end otherwise.
  if (parallelLevel >= getNumTilingLevels())
    return 0;
  return llvm::count_if(
      loweringConfig.getTilingLevels()[parallelLevel].getSizes(),
      [](int64_t tileSize) { return tileSize != 0; });
}

/// Returns the tiling level for cache parallel dimensions.
Expand Down Expand Up @@ -76,28 +87,28 @@ class TilingConfig {

/// Returns the distribution tile sizes of the configuration.
SmallVector<int64_t> getDistributionTileSizes() {
return loweringConfig.getTileSizeVals(getActualLevel(DistributionTiles));
return getTileSizesForLevel(getActualLevel(DistributionTiles));
}

SmallVector<int64_t> getCacheReductionSizes() {
return loweringConfig.getTileSizeVals(getCacheReductionLevel());
return getTileSizesForLevel(getCacheReductionLevel());
}

SmallVector<int64_t> getVectorCommonParallelSizes() {
return loweringConfig.getTileSizeVals(getVectorCommonParallelLevel());
SizesAndScalableFlags getVectorCommonParallelSizes() {
return getVectorSizesForLevel(getVectorCommonParallelLevel());
}

SmallVector<int64_t> getVectorReductionSizes() {
return loweringConfig.getTileSizeVals(getVectorReductionLevel());
SizesAndScalableFlags getVectorReductionSizes() {
return getVectorSizesForLevel(getVectorReductionLevel());
}

SmallVector<int64_t> getVectorInnerParallelSizes() {
return loweringConfig.getTileSizeVals(getVectorInnerParallelLevel());
SizesAndScalableFlags getVectorInnerParallelSizes() {
return getVectorSizesForLevel(getVectorInnerParallelLevel());
}

/// Returns the tile sizes of all the vector dimensions, including parallel
/// and reduction dimensions.
SmallVector<int64_t> getVectorTileSizes();
SizesAndScalableFlags getVectorTileSizes();

/// Returns a list with the tiling levels that can be fused for this
/// configuration.
Expand All @@ -112,6 +123,15 @@ class TilingConfig {
}

private:
/// Bundles the tile sizes and the scalable tile flags that the lowering
/// config reports for tiling level `level` into a single pair.
SizesAndScalableFlags getVectorSizesForLevel(unsigned level) {
  SmallVector<int64_t> tileSizes = loweringConfig.getTileSizeVals(level);
  SmallVector<bool> scalableFlags =
      loweringConfig.getScalableTileFlagVals(level);
  return SizesAndScalableFlags(tileSizes, scalableFlags);
}

/// Returns the tile sizes that the lowering config reports for tiling level
/// `level` (sizes only, without scalable flags).
SmallVector<int64_t> getTileSizesForLevel(unsigned level) {
  SmallVector<int64_t> tileSizes = loweringConfig.getTileSizeVals(level);
  return tileSizes;
}

/// Internal representation for all the supported tiling levels. All or just
/// a subset of them may be available in a valid configuration.
enum TilingLevel : unsigned {
Expand Down
39 changes: 34 additions & 5 deletions compiler/src/iree/compiler/Codegen/Dialect/IREECodegenAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,23 +191,35 @@ LogicalResult LoweringConfigTilingLevelAttr::verify(
// iree_codegen.lowering_config
//===----------------------------------------------------------------------===//

LoweringConfigAttr LoweringConfigAttr::get(MLIRContext *context,
TileSizesListTypeRef tileSizes,
TileSizesListTypeRef tileInterchange,
ArrayRef<int64_t> nativeVectorSize) {
LoweringConfigAttr
LoweringConfigAttr::get(MLIRContext *context, TileSizesListTypeRef tileSizes,
ScalableTileFlagsListTypeRef scalableTileFlags,
TileSizesListTypeRef tileInterchange,
ArrayRef<int64_t> nativeVectorSize) {
SmallVector<LoweringConfigTilingLevelAttr> tilinglevels;
for (auto [level, sizes] : llvm::enumerate(tileSizes)) {
ArrayRef<int64_t> interchange = level < tileInterchange.size()
? tileInterchange[level]
: ArrayRef<int64_t>{};
ArrayRef<bool> scalableFlags = level < scalableTileFlags.size()
? scalableTileFlags[level]
: ArrayRef<bool>{};
tilinglevels.push_back(LoweringConfigTilingLevelAttr::get(
context, sizes, interchange, ArrayRef<bool>{}));
context, sizes, interchange, scalableFlags));
}
return get(context,
LoweringConfigTilingLevelsAttr::get(context, tilinglevels),
nativeVectorSize);
}

/// Builder for lowering configs that carry no scalable tile flags: forwards
/// to the overload taking `scalableTileFlags`, passing an empty flag list.
LoweringConfigAttr LoweringConfigAttr::get(MLIRContext *context,
                                           TileSizesListTypeRef tileSizes,
                                           TileSizesListTypeRef tileInterchange,
                                           ArrayRef<int64_t> nativeVectorSize) {
  ScalableTileFlagsListTypeRef noScalableTileFlags{};
  return get(context, tileSizes, noScalableTileFlags, tileInterchange,
             nativeVectorSize);
}

TileSizesListType LoweringConfigAttr::getTileSizeVals() {
TileSizesListType tileSizes;
for (auto &level : getTilingLevels())
Expand All @@ -222,6 +234,23 @@ SmallVector<int64_t> LoweringConfigAttr::getTileSizeVals(unsigned level) {
return SmallVector<int64_t>(levels[level].getSizes());
}

/// Collects the scalable tile flags of every tiling level, in level order.
/// Each entry is a copy of that level's flags exactly as stored (no padding).
ScalableTileFlagsListType LoweringConfigAttr::getScalableTileFlagVals() {
  ScalableTileFlagsListType flagsPerLevel;
  for (const auto &tilingLevel : getTilingLevels()) {
    SmallVector<bool> levelFlags(tilingLevel.getScalableFlags());
    flagsPerLevel.push_back(levelFlags);
  }
  return flagsPerLevel;
}

/// Returns the scalable tile flags of tiling level `level`, resized so that
/// there is exactly one flag per tile size of that level (missing flags are
/// `false`). Returns an empty vector when `level` is out of range.
SmallVector<bool> LoweringConfigAttr::getScalableTileFlagVals(unsigned level) {
  auto tilingLevels = getTilingLevels();
  if (level < tilingLevels.size()) {
    SmallVector<bool> flags(tilingLevels[level].getScalableFlags());
    // One flag per tile size; any flag not stored explicitly is `false`.
    flags.resize(tilingLevels[level].getSizes().size(), false);
    return flags;
  }
  return {};
}

SmallVector<int64_t>
LoweringConfigAttr::getTileInterchangeVals(unsigned level) {
auto levels = getTilingLevels();
Expand Down
21 changes: 20 additions & 1 deletion compiler/src/iree/compiler/Codegen/Dialect/IREECodegenAttrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ namespace iree_compiler {
/// Typedef for tile sizes to use at different levels of tiling.
using TileSizesListType = SmallVector<SmallVector<int64_t>>;
using TileSizesListTypeRef = ArrayRef<SmallVector<int64_t>>;
/// Typedef for scalable tile flags at different levels of tiling.
using ScalableTileFlagsListType = SmallVector<SmallVector<bool>>;
using ScalableTileFlagsListTypeRef = ArrayRef<SmallVector<bool>>;
} // namespace iree_compiler
} // namespace mlir

Expand Down Expand Up @@ -121,13 +124,15 @@ void setLoweringConfig(Operation *op, IREE::Codegen::LoweringConfigAttr config);
/// translation.
inline LogicalResult setOpConfigAndEntryPointFnTranslation(
func::FuncOp entryPointFn, Operation *op, TileSizesListTypeRef tileSizes,
ScalableTileFlagsListTypeRef scalableTileFlags,
IREE::Codegen::DispatchLoweringPassPipeline passPipeline,
ArrayRef<int64_t> workgroupSize = {},
std::optional<int64_t> subgroupSize = {},
unsigned softwarePipelineDepth = 0,
unsigned softwarePipelineStoreStage = 1) {
MLIRContext *context = entryPointFn.getContext();
auto config = IREE::Codegen::LoweringConfigAttr::get(context, tileSizes);
auto config = IREE::Codegen::LoweringConfigAttr::get(context, tileSizes,
scalableTileFlags);
setLoweringConfig(op, config);
if (failed(setDispatchConfig(entryPointFn, workgroupSize, subgroupSize)))
return failure();
Expand All @@ -137,6 +142,20 @@ inline LogicalResult setOpConfigAndEntryPointFnTranslation(
return setTranslationInfo(entryPointFn, translationInfo);
}

/// Variant of setOpConfigAndEntryPointFnTranslation() for callers that have
/// no scalable tile flags. Delegates to the overload that accepts
/// `scalableTileFlags`, supplying an empty flag list and forwarding every
/// other argument unchanged.
inline LogicalResult setOpConfigAndEntryPointFnTranslation(
    func::FuncOp entryPointFn, Operation *op, TileSizesListTypeRef tileSizes,
    IREE::Codegen::DispatchLoweringPassPipeline passPipeline,
    ArrayRef<int64_t> workgroupSize = {},
    std::optional<int64_t> subgroupSize = {},
    unsigned softwarePipelineDepth = 0,
    unsigned softwarePipelineStoreStage = 1) {
  ScalableTileFlagsListTypeRef noScalableTileFlags{};
  return setOpConfigAndEntryPointFnTranslation(
      entryPointFn, op, tileSizes, noScalableTileFlags, passPipeline,
      workgroupSize, subgroupSize, softwarePipelineDepth,
      softwarePipelineStoreStage);
}

//===----------------------------------------------------------------------===//
// Helpers for getting/setting `iree_codegen.compilation_info` attribute on root
// operations to override IREEs default compilation.
Expand Down
10 changes: 10 additions & 0 deletions compiler/src/iree/compiler/Codegen/Dialect/IREECodegenAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,10 @@ def IREECodegen_LoweringConfigAttr :
);
let builders = [
AttrBuilder<(ins "TileSizesListTypeRef":$tileSizes,
CArg<"TileSizesListTypeRef", "{}">:$tileInterchange,
CArg<"ArrayRef<int64_t>", "{}">:$nativeVectorSize)>,
AttrBuilder<(ins "TileSizesListTypeRef":$tileSizes,
"ScalableTileFlagsListTypeRef":$scalableTileFlags,
CArg<"TileSizesListTypeRef", "{}">:$tileInterchange,
CArg<"ArrayRef<int64_t>", "{}">:$nativeVectorSize)>
];
Expand All @@ -217,6 +221,12 @@ def IREECodegen_LoweringConfigAttr :
// Returns the tile sizes for a level set for the op.
SmallVector<int64_t> getTileSizeVals(unsigned level);

// Returns the scalable tile flags for all levels set for the op.
ScalableTileFlagsListType getScalableTileFlagVals();

// Returns the scalable tile flags for a level set for the op.
SmallVector<bool> getScalableTileFlagVals(unsigned level);

// Returns the tile interchange for a level set for the op.
SmallVector<int64_t> getTileInterchangeVals(unsigned level);

Expand Down
Loading

0 comments on commit ba3e6a7

Please sign in to comment.