Skip to content

Commit

Permalink
[MLIR][VectorToLLVM] Handle scalable dim in createVectorLengthValue() (
Browse files Browse the repository at this point in the history
…llvm#93361)

LLVM's Vector Predication Intrinsics require an explicit vector length
parameter:
https://llvm.org/docs/LangRef.html#vector-predication-intrinsics.

For a scalable vector type, this should be calculated as VectorScaleOp
multiplied by the base vector length, e.g.: for <[4]xf32> we should return:
vscale * 4.
  • Loading branch information
zhaoshiz authored Jun 13, 2024
1 parent 19b43e1 commit abcbbe7
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 2 deletions.
15 changes: 13 additions & 2 deletions mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,7 @@ static Value getOrCreateAccumulator(ConversionPatternRewriter &rewriter,
llvmType);
}

/// Creates a constant value with the 1-D vector shape provided in `llvmType`.
/// Creates a value with the 1-D vector shape provided in `llvmType`.
/// This is used as effective vector length by some intrinsics supporting
/// dynamic vector lengths at runtime.
static Value createVectorLengthValue(ConversionPatternRewriter &rewriter,
Expand All @@ -532,9 +532,20 @@ static Value createVectorLengthValue(ConversionPatternRewriter &rewriter,
// VP intrinsics describe a single dimension, so only rank-1 vector types are
// expected here.
auto vShape = vType.getShape();
assert(vShape.size() == 1 && "Unexpected multi-dim vector type");

// Static element count of the vector as an i32 constant; VP intrinsics take
// the explicit vector length as an i32 operand.
Value baseVecLength = rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getI32Type(),
rewriter.getIntegerAttr(rewriter.getI32Type(), vShape[0]));

// For a fixed-length vector the static count is the effective length.
if (!vType.getScalableDims()[0])
return baseVecLength;

// For a scalable vector type, create and return `vScale * baseVecLength`.
// vector.vscale yields an `index` value, so cast it to i32 before the
// multiply to match the VP intrinsic's expected operand type.
Value vScale = rewriter.create<vector::VectorScaleOp>(loc);
vScale =
rewriter.create<arith::IndexCastOp>(loc, rewriter.getI32Type(), vScale);
Value scalableVecLength =
rewriter.create<arith::MulIOp>(loc, baseVecLength, vScale);
return scalableVecLength;
}

/// Helper method to lower a `vector.reduction` op that performs an arithmetic
Expand Down
110 changes: 110 additions & 0 deletions mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,25 @@ func.func @masked_reduce_add_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>) -
// CHECK: "llvm.intr.vp.reduce.fadd"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (f32, vector<16xf32>, vector<16xi1>, i32) -> f32


// -----

// Scalable variant of the masked fadd reduction: the explicit vector length
// passed to the VP intrinsic is vscale * 16 (i32) instead of a constant.
func.func @masked_reduce_add_f32_scalable(%arg0: vector<[16]xf32>, %mask : vector<[16]xi1>) -> f32 {
%0 = vector.mask %mask { vector.reduction <add>, %arg0 : vector<[16]xf32> into f32 } : vector<[16]xi1> -> f32
return %0 : f32
}

// CHECK-LABEL: func.func @masked_reduce_add_f32_scalable(
// CHECK-SAME: %[[INPUT:.*]]: vector<[16]xf32>,
// CHECK-SAME: %[[MASK:.*]]: vector<[16]xi1>) -> f32 {
// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
// CHECK: %[[VL_BASE:.*]] = llvm.mlir.constant(16 : i32) : i32
// CHECK: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
// CHECK: %[[CAST_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
// CHECK: %[[CAST_I32:.*]] = arith.index_cast %[[CAST_IDX]] : index to i32
// CHECK: %[[VL_MUL:.*]] = arith.muli %[[VL_BASE]], %[[CAST_I32]] : i32
// CHECK: "llvm.intr.vp.reduce.fadd"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL_MUL]]) : (f32, vector<[16]xf32>, vector<[16]xi1>, i32) -> f32


// -----

func.func @masked_reduce_mul_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>) -> f32 {
Expand Down Expand Up @@ -110,6 +129,24 @@ func.func @masked_reduce_minf_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>)

// -----

// Scalable masked minnumf reduction; NaN (0xFFC00000) is the neutral value
// and the effective vector length is computed as vscale * 16.
func.func @masked_reduce_minf_f32_scalable(%arg0: vector<[16]xf32>, %mask : vector<[16]xi1>) -> f32 {
%0 = vector.mask %mask { vector.reduction <minnumf>, %arg0 : vector<[16]xf32> into f32 } : vector<[16]xi1> -> f32
return %0 : f32
}

// CHECK-LABEL: func.func @masked_reduce_minf_f32_scalable(
// CHECK-SAME: %[[INPUT:.*]]: vector<[16]xf32>,
// CHECK-SAME: %[[MASK:.*]]: vector<[16]xi1>) -> f32 {
// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(0xFFC00000 : f32) : f32
// CHECK: %[[VL_BASE:.*]] = llvm.mlir.constant(16 : i32) : i32
// CHECK: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
// CHECK: %[[CAST_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
// CHECK: %[[CAST_I32:.*]] = arith.index_cast %[[CAST_IDX]] : index to i32
// CHECK: %[[VL_MUL:.*]] = arith.muli %[[VL_BASE]], %[[CAST_I32]] : i32
// CHECK: "llvm.intr.vp.reduce.fmin"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL_MUL]]) : (f32, vector<[16]xf32>, vector<[16]xi1>, i32) -> f32

// -----

func.func @masked_reduce_maxf_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>) -> f32 {
%0 = vector.mask %mask { vector.reduction <maxnumf>, %arg0 : vector<16xf32> into f32 } : vector<16xi1> -> f32
return %0 : f32
Expand Down Expand Up @@ -167,6 +204,25 @@ func.func @masked_reduce_add_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) ->
// CHECK: "llvm.intr.vp.reduce.add"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8


// -----

// Scalable masked integer add reduction over i8; the VP intrinsic receives
// vscale * 32 as its explicit vector length.
func.func @masked_reduce_add_i8_scalable(%arg0: vector<[32]xi8>, %mask : vector<[32]xi1>) -> i8 {
%0 = vector.mask %mask { vector.reduction <add>, %arg0 : vector<[32]xi8> into i8 } : vector<[32]xi1> -> i8
return %0 : i8
}

// CHECK-LABEL: func.func @masked_reduce_add_i8_scalable(
// CHECK-SAME: %[[INPUT:.*]]: vector<[32]xi8>,
// CHECK-SAME: %[[MASK:.*]]: vector<[32]xi1>) -> i8 {
// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(0 : i8) : i8
// CHECK: %[[VL_BASE:.*]] = llvm.mlir.constant(32 : i32) : i32
// CHECK: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
// CHECK: %[[CAST_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
// CHECK: %[[CAST_I32:.*]] = arith.index_cast %[[CAST_IDX]] : index to i32
// CHECK: %[[VL_MUL:.*]] = arith.muli %[[VL_BASE]], %[[CAST_I32]] : i32
// CHECK: "llvm.intr.vp.reduce.add"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL_MUL]]) : (i8, vector<[32]xi8>, vector<[32]xi1>, i32) -> i8


// -----

func.func @masked_reduce_mul_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
Expand Down Expand Up @@ -197,6 +253,24 @@ func.func @masked_reduce_minui_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -

// -----

// Scalable masked unsigned-min reduction; -1 (all ones, the unsigned max of
// i8) is the neutral value, and the vector length is vscale * 32.
func.func @masked_reduce_minui_i8_scalable(%arg0: vector<[32]xi8>, %mask : vector<[32]xi1>) -> i8 {
%0 = vector.mask %mask { vector.reduction <minui>, %arg0 : vector<[32]xi8> into i8 } : vector<[32]xi1> -> i8
return %0 : i8
}

// CHECK-LABEL: func.func @masked_reduce_minui_i8_scalable(
// CHECK-SAME: %[[INPUT:.*]]: vector<[32]xi8>,
// CHECK-SAME: %[[MASK:.*]]: vector<[32]xi1>) -> i8 {
// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(-1 : i8) : i8
// CHECK: %[[VL_BASE:.*]] = llvm.mlir.constant(32 : i32) : i32
// CHECK: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
// CHECK: %[[CAST_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
// CHECK: %[[CAST_I32:.*]] = arith.index_cast %[[CAST_IDX]] : index to i32
// CHECK: %[[VL_MUL:.*]] = arith.muli %[[VL_BASE]], %[[CAST_I32]] : i32
// CHECK: "llvm.intr.vp.reduce.umin"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL_MUL]]) : (i8, vector<[32]xi8>, vector<[32]xi1>, i32) -> i8

// -----

func.func @masked_reduce_maxui_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
%0 = vector.mask %mask { vector.reduction <maxui>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8
return %0 : i8
Expand Down Expand Up @@ -239,6 +313,24 @@ func.func @masked_reduce_maxsi_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -

// -----

// Scalable masked signed-max reduction; -128 (the signed minimum of i8) is
// the neutral value, and the vector length is vscale * 32.
func.func @masked_reduce_maxsi_i8_scalable(%arg0: vector<[32]xi8>, %mask : vector<[32]xi1>) -> i8 {
%0 = vector.mask %mask { vector.reduction <maxsi>, %arg0 : vector<[32]xi8> into i8 } : vector<[32]xi1> -> i8
return %0 : i8
}

// CHECK-LABEL: func.func @masked_reduce_maxsi_i8_scalable(
// CHECK-SAME: %[[INPUT:.*]]: vector<[32]xi8>,
// CHECK-SAME: %[[MASK:.*]]: vector<[32]xi1>) -> i8 {
// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(-128 : i8) : i8
// CHECK: %[[VL_BASE:.*]] = llvm.mlir.constant(32 : i32) : i32
// CHECK: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
// CHECK: %[[CAST_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
// CHECK: %[[CAST_I32:.*]] = arith.index_cast %[[CAST_IDX]] : index to i32
// CHECK: %[[VL_MUL:.*]] = arith.muli %[[VL_BASE]], %[[CAST_I32]] : i32
// CHECK: "llvm.intr.vp.reduce.smax"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL_MUL]]) : (i8, vector<[32]xi8>, vector<[32]xi1>, i32) -> i8

// -----

func.func @masked_reduce_or_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
%0 = vector.mask %mask { vector.reduction <or>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8
return %0 : i8
Expand Down Expand Up @@ -280,4 +372,22 @@ func.func @masked_reduce_xor_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) ->
// CHECK: %[[VL:.*]] = llvm.mlir.constant(32 : i32) : i32
// CHECK: "llvm.intr.vp.reduce.xor"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8

// -----

// Scalable masked xor reduction; 0 is the neutral value for xor, and the
// explicit vector length is vscale * 32.
func.func @masked_reduce_xor_i8_scalable(%arg0: vector<[32]xi8>, %mask : vector<[32]xi1>) -> i8 {
%0 = vector.mask %mask { vector.reduction <xor>, %arg0 : vector<[32]xi8> into i8 } : vector<[32]xi1> -> i8
return %0 : i8
}

// CHECK-LABEL: func.func @masked_reduce_xor_i8_scalable(
// CHECK-SAME: %[[INPUT:.*]]: vector<[32]xi8>,
// CHECK-SAME: %[[MASK:.*]]: vector<[32]xi1>) -> i8 {
// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(0 : i8) : i8
// CHECK: %[[VL_BASE:.*]] = llvm.mlir.constant(32 : i32) : i32
// CHECK: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
// CHECK: %[[CAST_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
// CHECK: %[[CAST_I32:.*]] = arith.index_cast %[[CAST_IDX]] : index to i32
// CHECK: %[[VL_MUL:.*]] = arith.muli %[[VL_BASE]], %[[CAST_I32]] : i32
// CHECK: "llvm.intr.vp.reduce.xor"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL_MUL]]) : (i8, vector<[32]xi8>, vector<[32]xi1>, i32) -> i8


0 comments on commit abcbbe7

Please sign in to comment.