Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Signed-off-by: Stanley Winata <[email protected]>
  • Loading branch information
raikonenfnu authored Aug 13, 2024
1 parent b297d5b commit 08583d5
Show file tree
Hide file tree
Showing 6 changed files with 12 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -423,8 +423,7 @@ struct DistributeReductions final
LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
DistributionSignature &signature,
PatternRewriter &rewriter) const override {
auto reductionDims = llvm::to_vector<4>(
reductionOp.getReductionDims().getAsRange<IntegerAttr>());
ArrayRef<int64_t> reductionDims = reductionOp.getReductionDims();
// TODO: Add support for reductions along multiple dimensions.
if (reductionDims.size() > 1)
return failure();
Expand Down Expand Up @@ -461,7 +460,7 @@ struct DistributeReductions final
Value storeVec = rewriter.create<arith::ConstantOp>(
loc, storeVectorType, rewriter.getZeroAttr(storeVectorType));

int reductionDim = reductionDims[0].getInt();
int reductionDim = reductionDims[0];
int parallelDim = reductionDim ^ 1;
if (!sourceLayout.getLane(reductionDim))
return failure();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -297,8 +297,7 @@ static void propagateLayoutToReduceBroadcastTranspose(
if (!layoutMap.count(reductionSrc))
return;
// Get the reduction dims
auto reductionDims =
llvm::to_vector(reductionOp.getReductionDims().getAsRange<IntegerAttr>());
ArrayRef<int64_t> reductionDims = reductionOp.getReductionDims();
// Get the transpose permutation
ArrayRef<int64_t> perm = transposeOp.getPermutation();
// Don't support dim-1 broadcasted dims
Expand All @@ -325,8 +324,7 @@ static void propagateLayoutToReduceBroadcastTranspose(
return;
// Check that transpose(reductionDim) == broadcastDim
// and that the shapes match
for (IntegerAttr dimAttr : reductionDims) {
int64_t dim = dimAttr.getInt();
for (int64_t dim : reductionDims) {
int64_t transposedDim = perm[dim];
if (!broadcastedDims.contains(transposedDim))
return;
Expand Down Expand Up @@ -816,13 +814,12 @@ static void distributeReductionBroadcastTranspose(
return;
Location loc = reductionOp.getLoc();
Layout layout = layoutMap.at(source);
auto reductionDims =
llvm::to_vector(reductionOp.getReductionDims().getAsRange<IntegerAttr>());
ArrayRef<int64_t> reductionDims = reductionOp.getReductionDims();
vector::CombiningKind combiningKind = reductionOp.getKind();
// Only support reduction on one dimension
if (reductionDims.size() > 1)
return;
int reductionDim = reductionDims[0].getInt();
int reductionDim = reductionDims[0];
std::array<int, 4> reductionOrder = layout.order[reductionDim];
std::array<int, 4> parallelOrder = layout.order[!reductionDim];
Value acc = reductionOp.getAcc();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,9 @@ SmallVector<int64_t> getNativeVectorShapeImpl(vector::MultiDimReductionOp op) {
// Unroll all reduction dimensions by size 1 for vector.multi_reduction.
VectorType srcVectorType = op.getSourceVectorType();
auto nativeSize = llvm::to_vector(srcVectorType.getShape());
auto dims = op.getReductionDims().getAsValueRange<IntegerAttr>();
for (const auto &dimAttr : dims) {
nativeSize[dimAttr.getZExtValue()] = 1;
ArrayRef<int64_t> dims = op.getReductionDims();
for (const int64_t dim : dims) {
nativeSize[dim] = 1;
}
return nativeSize;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ module @static_1d_mesh_grouping_along_axis_0 {
// CHECK-NOT: util.global private @_mesh_mesh_1d_axes_0
mesh.mesh @mesh_1d(shape = 2)
util.func public @f(%arg0: tensor<1xi8>) -> tensor<1xi8> {
%0 = mesh.all_reduce %arg0 on @mesh_1d mesh_axes = [0] reduction = <sum> : tensor<1xi8> -> tensor<1xi8>
%0 = mesh.all_reduce %arg0 on @mesh_1d mesh_axes = [0] reduction = sum : tensor<1xi8> -> tensor<1xi8>
util.return %0 : tensor<1xi8>
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ util.func public @all_reduce_min_non_default_channel(%arg: tensor<1xi8>) -> tens
// CHECK-DAG: %[[INITIAL_VAL:.+]] = tensor.empty() : tensor<1xi8>
// CHECK: %[[RES:.+]] = flow.collective.all_reduce minimum, ui8, %[[INITIAL_VAL]], %[[ARG]], %[[CHANNEL]]
// CHECK-SAME: (tensor<1xi8>, tensor<1xi8>, !flow.channel) -> %[[INITIAL_VAL]] as tensor<1xi8>
%0 = mesh.all_reduce %arg on @mesh_2d mesh_axes = [1, 0] reduction = <min>
%0 = mesh.all_reduce %arg on @mesh_2d mesh_axes = [1, 0] reduction = min
: tensor<1xi8> -> tensor<1xi8>
// CHECK: util.return %[[RES]] : tensor<1xi8>
util.return %0 : tensor<1xi8>
Expand Down
2 changes: 1 addition & 1 deletion third_party/llvm-project
Submodule llvm-project updated 924 files

0 comments on commit 08583d5

Please sign in to comment.