forked from iree-org/iree
-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[VectorExt] Add layout iterator classes (iree-org#16004)
This PR adds iterator classes to iterate over the layout and to concatenate iterators with other frozen iterators. This is required when distributing reductions. Also adds a test to check for correctness.
- Loading branch information
Showing
9 changed files
with
307 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
55 changes: 55 additions & 0 deletions
55
llvm-external-projects/iree-dialects/test/Dialect/iree_vector_ext/iterators.mlir
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
// RUN: iree-dialects-opt --split-input-file --test-vector-ext-iterators %s | FileCheck %s | ||
|
||
// CHECK: VECTORY:0, BATCHX:0, VECTORX:0, BATCHY:0, | ||
// CHECK: VECTORY:1, BATCHX:0, VECTORX:0, BATCHY:0, | ||
// CHECK: VECTORY:0, BATCHX:1, VECTORX:0, BATCHY:0, | ||
// CHECK: VECTORY:1, BATCHX:1, VECTORX:0, BATCHY:0, | ||
// CHECK: VECTORY:0, BATCHX:0, VECTORX:1, BATCHY:0, | ||
// CHECK: VECTORY:1, BATCHX:0, VECTORX:1, BATCHY:0, | ||
// CHECK: VECTORY:0, BATCHX:1, VECTORX:1, BATCHY:0, | ||
// CHECK: VECTORY:1, BATCHX:1, VECTORX:1, BATCHY:0, | ||
// CHECK: VECTORY:0, BATCHX:0, VECTORX:0, BATCHY:1, | ||
// CHECK: VECTORY:1, BATCHX:0, VECTORX:0, BATCHY:1, | ||
// CHECK: VECTORY:0, BATCHX:1, VECTORX:0, BATCHY:1, | ||
// CHECK: VECTORY:1, BATCHX:1, VECTORX:0, BATCHY:1, | ||
// CHECK: VECTORY:0, BATCHX:0, VECTORX:1, BATCHY:1, | ||
// CHECK: VECTORY:1, BATCHX:0, VECTORX:1, BATCHY:1, | ||
// CHECK: VECTORY:0, BATCHX:1, VECTORX:1, BATCHY:1, | ||
// CHECK: VECTORY:1, BATCHX:1, VECTORX:1, BATCHY:1, | ||
#row_layout1 = #iree_vector_ext.per_dim_layout<[BATCHX, LANEX, VECTORY], [2, 1, 2]> | ||
#col_layout1 = #iree_vector_ext.per_dim_layout<[BATCHY, LANEY, VECTORX], [2, 1, 2]> | ||
#layout1 = #iree_vector_ext.layout<#row_layout1, #col_layout1> | ||
func.func @iterator_test(%lhs: memref<4x4xf16>) -> vector<4x4xf16> { | ||
%cst_0 = arith.constant 0.0 : f16 | ||
%c0 = arith.constant 0 : index | ||
%result = vector.transfer_read %lhs[%c0, %c0], %cst_0 {in_bounds = [true, true], __test_iterator_layout__ = #layout1} : memref<4x4xf16>, vector<4x4xf16> | ||
return %result : vector<4x4xf16> | ||
} | ||
|
||
// ----- | ||
|
||
// CHECK: VECTORY:0, BATCHX:0, VECTORX:0, BATCHY:0, VECTORZ:0, | ||
// CHECK: VECTORY:1, BATCHX:0, VECTORX:0, BATCHY:0, VECTORZ:0, | ||
// CHECK: VECTORY:0, BATCHX:1, VECTORX:0, BATCHY:0, VECTORZ:0, | ||
// CHECK: VECTORY:1, BATCHX:1, VECTORX:0, BATCHY:0, VECTORZ:0, | ||
// CHECK: VECTORY:0, BATCHX:0, VECTORX:1, BATCHY:0, VECTORZ:0, | ||
// CHECK: VECTORY:1, BATCHX:0, VECTORX:1, BATCHY:0, VECTORZ:0, | ||
// CHECK: VECTORY:0, BATCHX:1, VECTORX:1, BATCHY:0, VECTORZ:0, | ||
// CHECK: VECTORY:1, BATCHX:1, VECTORX:1, BATCHY:0, VECTORZ:0, | ||
// CHECK: VECTORY:0, BATCHX:0, VECTORX:0, BATCHY:1, VECTORZ:0, | ||
// CHECK: VECTORY:1, BATCHX:0, VECTORX:0, BATCHY:1, VECTORZ:0, | ||
// CHECK: VECTORY:0, BATCHX:1, VECTORX:0, BATCHY:1, VECTORZ:0, | ||
// CHECK: VECTORY:1, BATCHX:1, VECTORX:0, BATCHY:1, VECTORZ:0, | ||
// CHECK: VECTORY:0, BATCHX:0, VECTORX:1, BATCHY:1, VECTORZ:0, | ||
// CHECK: VECTORY:1, BATCHX:0, VECTORX:1, BATCHY:1, VECTORZ:0, | ||
// CHECK: VECTORY:0, BATCHX:1, VECTORX:1, BATCHY:1, VECTORZ:0, | ||
// CHECK: VECTORY:1, BATCHX:1, VECTORX:1, BATCHY:1, VECTORZ:0, | ||
#row_layout1 = #iree_vector_ext.per_dim_layout<[BATCHX, LANEX, VECTORY], [2, 1, 2]> | ||
#col_layout1 = #iree_vector_ext.per_dim_layout<[BATCHY, LANEY, VECTORX], [2, 1, 2]> | ||
#layout1 = #iree_vector_ext.layout<#row_layout1, #col_layout1> | ||
func.func @frozen_iterator_test(%lhs: memref<4x4xf16>) -> vector<4x4xf16> { | ||
%cst_0 = arith.constant 0.0 : f16 | ||
%c0 = arith.constant 0 : index | ||
%result = vector.transfer_read %lhs[%c0, %c0], %cst_0 {in_bounds = [true, true], __test_frozen_iterator_layout__ = #layout1} : memref<4x4xf16>, vector<4x4xf16> | ||
return %result : vector<4x4xf16> | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
add_subdirectory(Transforms) | ||
add_subdirectory(VectorExt) |
13 changes: 13 additions & 0 deletions
13
llvm-external-projects/iree-dialects/test/lib/VectorExt/CMakeLists.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
add_mlir_library(IREEVectorExtTestPasses | ||
TestIterators.cpp | ||
|
||
DEPENDS | ||
mlir-headers | ||
|
||
EXCLUDE_FROM_LIBMLIR | ||
|
||
LINK_LIBS PUBLIC | ||
IREEVectorExtDialect | ||
MLIRPass | ||
) | ||
|
90 changes: 90 additions & 0 deletions
90
llvm-external-projects/iree-dialects/test/lib/VectorExt/TestIterators.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
// Copyright 2023 The IREE Authors | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
#include "iree-dialects/Dialect/VectorExt/IR/VectorExtDialect.h" | ||
#include "iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h" | ||
#include "mlir/Pass/Pass.h" | ||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||
|
||
using namespace mlir; | ||
using namespace mlir::iree_compiler::IREE::VectorExt; | ||
|
||
namespace { | ||
|
||
static const StringRef kIteratorMarker = "__test_iterator_layout__"; | ||
static const StringRef kFrozenIteratorMarker = | ||
"__test_frozen_iterator_layout__"; | ||
|
||
struct TestVectorExtIteratorPass | ||
: public PassWrapper<TestVectorExtIteratorPass, Pass> { | ||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorExtIteratorPass) | ||
TestVectorExtIteratorPass() = default; | ||
TestVectorExtIteratorPass(const TestVectorExtIteratorPass &other) | ||
: PassWrapper(other) {} | ||
StringRef getArgument() const final { return "test-vector-ext-iterators"; } | ||
StringRef getDescription() const final { | ||
return "Test VectorExt Iterator pass."; | ||
} | ||
bool canScheduleOn(RegisteredOperationName opName) const override { | ||
return true; | ||
} | ||
// Prints the layout so that LIT can test it for correctness. | ||
static void printFn(const LayoutIterator::State &state) { | ||
for (const auto &[dim, it] : state) { | ||
llvm::outs() << stringifyLayoutDimension(dim).str() + ":" + | ||
std::to_string(*it) + ", "; | ||
} | ||
llvm::outs() << "\n"; | ||
} | ||
void testIterator(Operation *op) { | ||
auto layout = dyn_cast_or_null<LayoutAttr>(op->getAttr(kIteratorMarker)); | ||
DenseMap<LayoutDimension, int64_t> strides; | ||
LayoutIterator iterator(layout, strides); | ||
iterator.apply(printFn); | ||
} | ||
LayoutDimensionAttr createLayoutDimensionAttr(MLIRContext *ctx, | ||
LayoutDimension dim) { | ||
return LayoutDimensionAttr::get(ctx, dim); | ||
} | ||
LayoutIterator | ||
createFrozenIterator(MLIRContext *ctx, | ||
DenseMap<LayoutDimension, int64_t> &strides) { | ||
SmallVector<LayoutDimensionAttr> labels{ | ||
createLayoutDimensionAttr(ctx, LayoutDimension::VECTORZ), | ||
createLayoutDimensionAttr(ctx, LayoutDimension::VECTORX)}; | ||
auto newLayout = | ||
LayoutAttr::get(ctx, {PerDimLayoutAttr::get(ctx, labels[0], {1}), | ||
PerDimLayoutAttr::get(ctx, labels[1], {1})}); | ||
return LayoutIterator(newLayout, strides); | ||
} | ||
void testFrozenIterator(Operation *op) { | ||
auto layout = | ||
dyn_cast_or_null<LayoutAttr>(op->getAttr(kFrozenIteratorMarker)); | ||
DenseMap<LayoutDimension, int64_t> strides; | ||
LayoutIterator iterator(layout, strides); | ||
auto frozenIterator = createFrozenIterator(op->getContext(), strides); | ||
iterator.maybeFreezeAndConcatenate(frozenIterator); | ||
iterator.apply(printFn); | ||
} | ||
void runOnOperation() override { | ||
getOperation()->walk([&](Operation *op) { | ||
if (op->hasAttr(kIteratorMarker)) { | ||
return testIterator(op); | ||
} | ||
if (op->hasAttr(kFrozenIteratorMarker)) { | ||
return testFrozenIterator(op); | ||
} | ||
}); | ||
} | ||
}; | ||
|
||
} // namespace | ||
|
||
namespace mlir::test_ext { | ||
void registerVectorExtTestPasses() { | ||
PassRegistration<TestVectorExtIteratorPass>(); | ||
} | ||
} // namespace mlir::test_ext |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters