Skip to content

Commit

Permalink
[VectorExt] Add layout iterator classes (iree-org#16004)
Browse files Browse the repository at this point in the history
This PR adds iterator classes to iterate over the layout and to
concatenate iterators with other frozen iterators. This is required when
distributing reductions. Also adds a test to check for correctness.
  • Loading branch information
harsh-nod authored Dec 27, 2023
1 parent 9cde4e3 commit ccbe33f
Show file tree
Hide file tree
Showing 9 changed files with 307 additions and 1 deletion.
1 change: 1 addition & 0 deletions llvm-external-projects/iree-dialects/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -756,6 +756,7 @@ cc_library(
":IREELinalgExtDialect",
":IREELinalgTransformDialect",
":IREELinalgTransformDialectPasses",
":IREEVectorExtDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Rewrite",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,70 @@

// clang-format on

namespace mlir::iree_compiler::IREE::VectorExt {

/// Dimensional Strided Iterator class used to represent
/// an iterator through a single dimension of the layout.
class DimensionalIterator {
public:
DimensionalIterator(int64_t position = 0, int64_t stride = 1)
: position(position), stride(stride) {}
int64_t operator*() const { return position; }
DimensionalIterator &operator++() {
position += stride;
return *this;
}
bool operator!=(const DimensionalIterator &other) const {
return position != other.position;
}

private:
int64_t position, stride;
};

/// Dimensional Range class used to represent the range of
/// a particular dimension of the layout. Can be iterated on
/// using a DimensionalIterator.
class DimensionalRange {
public:
DimensionalRange() {}
DimensionalRange(int64_t start, int64_t stop, int64_t step = 1)
: start(start), stop(stop), step(step) {}
DimensionalIterator begin() const { return DimensionalIterator(start, step); }
DimensionalIterator end() const { return DimensionalIterator(stop, step); }

private:
int64_t start, stop, step;
};

// Iterator class for LayoutAttrs and PerDimLayoutAttrs.
// Provides O(1) access to state for any given dimension.
// Also preserves insertion order.
// Layout iterators skip lane dimensions as these are not
// required during distribution.
class LayoutIterator {
public:
using State = llvm::MapVector<LayoutDimension, DimensionalIterator>;
using DimensionMapping =
llvm::DenseMap<int64_t, SmallVector<LayoutDimension>>;
void maybeFreezeAndConcatenate(const LayoutIterator &frozenIterator);
LayoutIterator(LayoutAttr &attr, DenseMap<LayoutDimension, int64_t> strides);
LayoutIterator(PerDimLayoutAttr &attr,
DenseMap<LayoutDimension, int64_t> strides);
void apply(std::function<void(const LayoutIterator::State &)>);
LayoutIterator &operator++();
State getState() const { return state; }

private:
void initialize(PerDimLayoutAttr &attr,
DenseMap<LayoutDimension, int64_t> strides);
bool iterationComplete();
State state;
llvm::MapVector<LayoutDimension, DimensionalRange> ranges;
DimensionMapping simdDimensionToLayoutDimension;
DenseSet<LayoutDimension> frozenDimensions;
};

} // namespace mlir::iree_compiler::IREE::VectorExt

#endif // IREE_DIALECTS_DIALECT_VECTOREXT_IR_VECTOREXTOPS_H_
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,84 @@ OpFoldResult ToSIMTOp::fold(FoldAdaptor) {
return {};
}

LayoutIterator &LayoutIterator::operator++() {
for (auto &[dim, it] : state) {
if (frozenDimensions.contains(dim))
continue;
if (it != ranges[dim].end()) {
++it;
break;
}
it = ranges[dim].begin();
}
return *this;
}

void LayoutIterator::maybeFreezeAndConcatenate(
const LayoutIterator &frozenIterator) {
for (auto &[frozenDim, frozenIt] : frozenIterator.getState()) {
if (!state.contains(frozenDim)) {
frozenDimensions.insert(frozenDim);
state[frozenDim] = frozenIt;
}
}
}

static bool isLaneDimension(LayoutDimension dim) {
return (dim == LayoutDimension::LANEX) || (dim == LayoutDimension::LANEY) ||
(dim == LayoutDimension::LANEZ);
}

void LayoutIterator::initialize(PerDimLayoutAttr &attr,
DenseMap<LayoutDimension, int64_t> strides) {
auto reversedLabels = llvm::reverse(attr.getLabels());
auto reversedShapes = llvm::reverse(attr.getShapes());
for (auto [nameAttr, shape] : llvm::zip(reversedLabels, reversedShapes)) {
LayoutDimension dim = nameAttr.getValue();
if (isLaneDimension(dim))
continue;
int64_t stride = strides.contains(dim) ? strides[dim] : 1;
ranges[dim] = DimensionalRange(0, shape - 1, stride);
state[dim] = ranges[dim].begin();
}
}

LayoutIterator::LayoutIterator(LayoutAttr &attr,
DenseMap<LayoutDimension, int64_t> strides) {
for (PerDimLayoutAttr perDimAttr : attr.getLayouts()) {
initialize(perDimAttr, strides);
}
}

LayoutIterator::LayoutIterator(PerDimLayoutAttr &attr,
DenseMap<LayoutDimension, int64_t> strides) {
initialize(attr, strides);
}

/// The iterator is done when it returns back to
/// its begin state.
bool LayoutIterator::iterationComplete() {
bool complete{true};
for (auto &[dim, it] : state) {
if (frozenDimensions.contains(dim))
continue;
if (it != ranges[dim].begin()) {
complete = false;
break;
}
}
return complete;
}

void LayoutIterator::apply(
std::function<void(const LayoutIterator::State &)> callback) {
do {
callback(state);
++(*this);
} while (!iterationComplete());
}

// clang-format off
#define GET_OP_CLASSES
#include "iree-dialects/Dialect/VectorExt/IR/VectorExtOps.cpp.inc" // IWYU pragma: keep
// clang-format: on
// clang-format on
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// RUN: iree-dialects-opt --split-input-file --test-vector-ext-iterators %s | FileCheck %s

// CHECK: VECTORY:0, BATCHX:0, VECTORX:0, BATCHY:0,
// CHECK: VECTORY:1, BATCHX:0, VECTORX:0, BATCHY:0,
// CHECK: VECTORY:0, BATCHX:1, VECTORX:0, BATCHY:0,
// CHECK: VECTORY:1, BATCHX:1, VECTORX:0, BATCHY:0,
// CHECK: VECTORY:0, BATCHX:0, VECTORX:1, BATCHY:0,
// CHECK: VECTORY:1, BATCHX:0, VECTORX:1, BATCHY:0,
// CHECK: VECTORY:0, BATCHX:1, VECTORX:1, BATCHY:0,
// CHECK: VECTORY:1, BATCHX:1, VECTORX:1, BATCHY:0,
// CHECK: VECTORY:0, BATCHX:0, VECTORX:0, BATCHY:1,
// CHECK: VECTORY:1, BATCHX:0, VECTORX:0, BATCHY:1,
// CHECK: VECTORY:0, BATCHX:1, VECTORX:0, BATCHY:1,
// CHECK: VECTORY:1, BATCHX:1, VECTORX:0, BATCHY:1,
// CHECK: VECTORY:0, BATCHX:0, VECTORX:1, BATCHY:1,
// CHECK: VECTORY:1, BATCHX:0, VECTORX:1, BATCHY:1,
// CHECK: VECTORY:0, BATCHX:1, VECTORX:1, BATCHY:1,
// CHECK: VECTORY:1, BATCHX:1, VECTORX:1, BATCHY:1,
#row_layout1 = #iree_vector_ext.per_dim_layout<[BATCHX, LANEX, VECTORY], [2, 1, 2]>
#col_layout1 = #iree_vector_ext.per_dim_layout<[BATCHY, LANEY, VECTORX], [2, 1, 2]>
#layout1 = #iree_vector_ext.layout<#row_layout1, #col_layout1>
func.func @iterator_test(%lhs: memref<4x4xf16>) -> vector<4x4xf16> {
%cst_0 = arith.constant 0.0 : f16
%c0 = arith.constant 0 : index
%result = vector.transfer_read %lhs[%c0, %c0], %cst_0 {in_bounds = [true, true], __test_iterator_layout__ = #layout1} : memref<4x4xf16>, vector<4x4xf16>
return %result : vector<4x4xf16>
}

// -----

// CHECK: VECTORY:0, BATCHX:0, VECTORX:0, BATCHY:0, VECTORZ:0,
// CHECK: VECTORY:1, BATCHX:0, VECTORX:0, BATCHY:0, VECTORZ:0,
// CHECK: VECTORY:0, BATCHX:1, VECTORX:0, BATCHY:0, VECTORZ:0,
// CHECK: VECTORY:1, BATCHX:1, VECTORX:0, BATCHY:0, VECTORZ:0,
// CHECK: VECTORY:0, BATCHX:0, VECTORX:1, BATCHY:0, VECTORZ:0,
// CHECK: VECTORY:1, BATCHX:0, VECTORX:1, BATCHY:0, VECTORZ:0,
// CHECK: VECTORY:0, BATCHX:1, VECTORX:1, BATCHY:0, VECTORZ:0,
// CHECK: VECTORY:1, BATCHX:1, VECTORX:1, BATCHY:0, VECTORZ:0,
// CHECK: VECTORY:0, BATCHX:0, VECTORX:0, BATCHY:1, VECTORZ:0,
// CHECK: VECTORY:1, BATCHX:0, VECTORX:0, BATCHY:1, VECTORZ:0,
// CHECK: VECTORY:0, BATCHX:1, VECTORX:0, BATCHY:1, VECTORZ:0,
// CHECK: VECTORY:1, BATCHX:1, VECTORX:0, BATCHY:1, VECTORZ:0,
// CHECK: VECTORY:0, BATCHX:0, VECTORX:1, BATCHY:1, VECTORZ:0,
// CHECK: VECTORY:1, BATCHX:0, VECTORX:1, BATCHY:1, VECTORZ:0,
// CHECK: VECTORY:0, BATCHX:1, VECTORX:1, BATCHY:1, VECTORZ:0,
// CHECK: VECTORY:1, BATCHX:1, VECTORX:1, BATCHY:1, VECTORZ:0,
#row_layout1 = #iree_vector_ext.per_dim_layout<[BATCHX, LANEX, VECTORY], [2, 1, 2]>
#col_layout1 = #iree_vector_ext.per_dim_layout<[BATCHY, LANEY, VECTORX], [2, 1, 2]>
#layout1 = #iree_vector_ext.layout<#row_layout1, #col_layout1>
func.func @frozen_iterator_test(%lhs: memref<4x4xf16>) -> vector<4x4xf16> {
%cst_0 = arith.constant 0.0 : f16
%c0 = arith.constant 0 : index
%result = vector.transfer_read %lhs[%c0, %c0], %cst_0 {in_bounds = [true, true], __test_frozen_iterator_layout__ = #layout1} : memref<4x4xf16>, vector<4x4xf16>
return %result : vector<4x4xf16>
}
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
add_subdirectory(Transforms)
add_subdirectory(VectorExt)
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
add_mlir_library(IREEVectorExtTestPasses
TestIterators.cpp

DEPENDS
mlir-headers

EXCLUDE_FROM_LIBMLIR

LINK_LIBS PUBLIC
IREEVectorExtDialect
MLIRPass
)

Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree-dialects/Dialect/VectorExt/IR/VectorExtDialect.h"
#include "iree-dialects/Dialect/VectorExt/IR/VectorExtOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
using namespace mlir::iree_compiler::IREE::VectorExt;

namespace {

static const StringRef kIteratorMarker = "__test_iterator_layout__";
static const StringRef kFrozenIteratorMarker =
"__test_frozen_iterator_layout__";

struct TestVectorExtIteratorPass
: public PassWrapper<TestVectorExtIteratorPass, Pass> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorExtIteratorPass)
TestVectorExtIteratorPass() = default;
TestVectorExtIteratorPass(const TestVectorExtIteratorPass &other)
: PassWrapper(other) {}
StringRef getArgument() const final { return "test-vector-ext-iterators"; }
StringRef getDescription() const final {
return "Test VectorExt Iterator pass.";
}
bool canScheduleOn(RegisteredOperationName opName) const override {
return true;
}
// Prints the layout so that LIT can test it for correctness.
static void printFn(const LayoutIterator::State &state) {
for (const auto &[dim, it] : state) {
llvm::outs() << stringifyLayoutDimension(dim).str() + ":" +
std::to_string(*it) + ", ";
}
llvm::outs() << "\n";
}
void testIterator(Operation *op) {
auto layout = dyn_cast_or_null<LayoutAttr>(op->getAttr(kIteratorMarker));
DenseMap<LayoutDimension, int64_t> strides;
LayoutIterator iterator(layout, strides);
iterator.apply(printFn);
}
LayoutDimensionAttr createLayoutDimensionAttr(MLIRContext *ctx,
LayoutDimension dim) {
return LayoutDimensionAttr::get(ctx, dim);
}
LayoutIterator
createFrozenIterator(MLIRContext *ctx,
DenseMap<LayoutDimension, int64_t> &strides) {
SmallVector<LayoutDimensionAttr> labels{
createLayoutDimensionAttr(ctx, LayoutDimension::VECTORZ),
createLayoutDimensionAttr(ctx, LayoutDimension::VECTORX)};
auto newLayout =
LayoutAttr::get(ctx, {PerDimLayoutAttr::get(ctx, labels[0], {1}),
PerDimLayoutAttr::get(ctx, labels[1], {1})});
return LayoutIterator(newLayout, strides);
}
void testFrozenIterator(Operation *op) {
auto layout =
dyn_cast_or_null<LayoutAttr>(op->getAttr(kFrozenIteratorMarker));
DenseMap<LayoutDimension, int64_t> strides;
LayoutIterator iterator(layout, strides);
auto frozenIterator = createFrozenIterator(op->getContext(), strides);
iterator.maybeFreezeAndConcatenate(frozenIterator);
iterator.apply(printFn);
}
void runOnOperation() override {
getOperation()->walk([&](Operation *op) {
if (op->hasAttr(kIteratorMarker)) {
return testIterator(op);
}
if (op->hasAttr(kFrozenIteratorMarker)) {
return testFrozenIterator(op);
}
});
}
};

} // namespace

namespace mlir::test_ext {
void registerVectorExtTestPasses() {
PassRegistration<TestVectorExtIteratorPass>();
}
} // namespace mlir::test_ext
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ set(LIBS
IREELinalgTransformDialectPasses
IREETransformsTestPasses
IREEVectorExtDialect
IREEVectorExtTestPasses
# Core dialects.
MLIRAffineDialect
MLIRArithDialect
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ namespace mlir {
namespace test_ext {
/// Test passes, do not deserve an include.
void registerTestListenerPasses();
void registerVectorExtTestPasses();
} // namespace test_ext
} // namespace mlir

Expand Down Expand Up @@ -88,6 +89,7 @@ int main(int argc, char **argv) {
mlir::linalg::transform::registerDropSchedulePass();
// Local test passes.
mlir::test_ext::registerTestListenerPasses();
mlir::test_ext::registerVectorExtTestPasses();

// External models.
mlir::func::registerInlinerExtension(registry);
Expand Down

0 comments on commit ccbe33f

Please sign in to comment.