diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 8671c8412f4ec8..60f7e95ade689f 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -523,7 +523,7 @@ static Value getOrCreateAccumulator(ConversionPatternRewriter &rewriter, llvmType); } -/// Creates a constant value with the 1-D vector shape provided in `llvmType`. +/// Creates a value with the 1-D vector shape provided in `llvmType`. /// This is used as effective vector length by some intrinsics supporting /// dynamic vector lengths at runtime. static Value createVectorLengthValue(ConversionPatternRewriter &rewriter, @@ -532,9 +532,20 @@ static Value createVectorLengthValue(ConversionPatternRewriter &rewriter, auto vShape = vType.getShape(); assert(vShape.size() == 1 && "Unexpected multi-dim vector type"); - return rewriter.create<LLVM::ConstantOp>( + Value baseVecLength = rewriter.create<LLVM::ConstantOp>( loc, rewriter.getI32Type(), rewriter.getIntegerAttr(rewriter.getI32Type(), vShape[0])); + + if (!vType.getScalableDims()[0]) + return baseVecLength; + + // For a scalable vector type, create and return `vScale * baseVecLength`. 
+ Value vScale = rewriter.create<vector::VectorScaleOp>(loc); + vScale = + rewriter.create<arith::IndexCastOp>(loc, rewriter.getI32Type(), vScale); + Value scalableVecLength = + rewriter.create<arith::MulIOp>(loc, baseVecLength, vScale); + return scalableVecLength; } /// Helper method to lower a `vector.reduction` op that performs an arithmetic diff --git a/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir index f98a05f8d17e2c..c7d9e22fb24234 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir @@ -79,6 +79,25 @@ func.func @masked_reduce_add_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>) - // CHECK: "llvm.intr.vp.reduce.fadd"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (f32, vector<16xf32>, vector<16xi1>, i32) -> f32 +// ----- + +func.func @masked_reduce_add_f32_scalable(%arg0: vector<[16]xf32>, %mask : vector<[16]xi1>) -> f32 { + %0 = vector.mask %mask { vector.reduction <add>, %arg0 : vector<[16]xf32> into f32 } : vector<[16]xi1> -> f32 + return %0 : f32 +} + +// CHECK-LABEL: func.func @masked_reduce_add_f32_scalable( +// CHECK-SAME: %[[INPUT:.*]]: vector<[16]xf32>, +// CHECK-SAME: %[[MASK:.*]]: vector<[16]xi1>) -> f32 { +// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32 +// CHECK: %[[VL_BASE:.*]] = llvm.mlir.constant(16 : i32) : i32 +// CHECK: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64 +// CHECK: %[[CAST_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index +// CHECK: %[[CAST_I32:.*]] = arith.index_cast %[[CAST_IDX]] : index to i32 +// CHECK: %[[VL_MUL:.*]] = arith.muli %[[VL_BASE]], %[[CAST_I32]] : i32 +// CHECK: "llvm.intr.vp.reduce.fadd"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL_MUL]]) : (f32, vector<[16]xf32>, vector<[16]xi1>, i32) -> f32 + + // ----- func.func @masked_reduce_mul_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>) -> f32 { %0 = vector.mask %mask { vector.reduction <mul>, %arg0 : vector<16xf32> into f32 } : vector<16xi1> -> f32 return %0 : f32 } @@ -110,6 +129,24 @@ func.func 
@masked_reduce_minf_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>) // ----- +func.func @masked_reduce_minf_f32_scalable(%arg0: vector<[16]xf32>, %mask : vector<[16]xi1>) -> f32 { + %0 = vector.mask %mask { vector.reduction <minf>, %arg0 : vector<[16]xf32> into f32 } : vector<[16]xi1> -> f32 + return %0 : f32 +} + +// CHECK-LABEL: func.func @masked_reduce_minf_f32_scalable( +// CHECK-SAME: %[[INPUT:.*]]: vector<[16]xf32>, +// CHECK-SAME: %[[MASK:.*]]: vector<[16]xi1>) -> f32 { +// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(0xFFC00000 : f32) : f32 +// CHECK: %[[VL_BASE:.*]] = llvm.mlir.constant(16 : i32) : i32 +// CHECK: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64 +// CHECK: %[[CAST_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index +// CHECK: %[[CAST_I32:.*]] = arith.index_cast %[[CAST_IDX]] : index to i32 +// CHECK: %[[VL_MUL:.*]] = arith.muli %[[VL_BASE]], %[[CAST_I32]] : i32 +// CHECK: "llvm.intr.vp.reduce.fmin"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL_MUL]]) : (f32, vector<[16]xf32>, vector<[16]xi1>, i32) -> f32 + +// ----- + func.func @masked_reduce_maxf_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>) -> f32 { %0 = vector.mask %mask { vector.reduction <maxf>, %arg0 : vector<16xf32> into f32 } : vector<16xi1> -> f32 return %0 : f32 } @@ -167,6 +204,25 @@ func.func @masked_reduce_add_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> // CHECK: "llvm.intr.vp.reduce.add"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8 +// ----- + +func.func @masked_reduce_add_i8_scalable(%arg0: vector<[32]xi8>, %mask : vector<[32]xi1>) -> i8 { + %0 = vector.mask %mask { vector.reduction <add>, %arg0 : vector<[32]xi8> into i8 } : vector<[32]xi1> -> i8 + return %0 : i8 +} + +// CHECK-LABEL: func.func @masked_reduce_add_i8_scalable( +// CHECK-SAME: %[[INPUT:.*]]: vector<[32]xi8>, +// CHECK-SAME: %[[MASK:.*]]: vector<[32]xi1>) -> i8 { +// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(0 : i8) : i8 +// CHECK: %[[VL_BASE:.*]] = 
llvm.mlir.constant(32 : i32) : i32 +// CHECK: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64 +// CHECK: %[[CAST_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index +// CHECK: %[[CAST_I32:.*]] = arith.index_cast %[[CAST_IDX]] : index to i32 +// CHECK: %[[VL_MUL:.*]] = arith.muli %[[VL_BASE]], %[[CAST_I32]] : i32 +// CHECK: "llvm.intr.vp.reduce.add"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL_MUL]]) : (i8, vector<[32]xi8>, vector<[32]xi1>, i32) -> i8 + + // ----- func.func @masked_reduce_mul_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 { %0 = vector.mask %mask { vector.reduction <mul>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8 return %0 : i8 } @@ -197,6 +253,24 @@ func.func @masked_reduce_minui_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) - // ----- +func.func @masked_reduce_minui_i8_scalable(%arg0: vector<[32]xi8>, %mask : vector<[32]xi1>) -> i8 { + %0 = vector.mask %mask { vector.reduction <minui>, %arg0 : vector<[32]xi8> into i8 } : vector<[32]xi1> -> i8 + return %0 : i8 +} + +// CHECK-LABEL: func.func @masked_reduce_minui_i8_scalable( +// CHECK-SAME: %[[INPUT:.*]]: vector<[32]xi8>, +// CHECK-SAME: %[[MASK:.*]]: vector<[32]xi1>) -> i8 { +// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(-1 : i8) : i8 +// CHECK: %[[VL_BASE:.*]] = llvm.mlir.constant(32 : i32) : i32 +// CHECK: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64 +// CHECK: %[[CAST_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index +// CHECK: %[[CAST_I32:.*]] = arith.index_cast %[[CAST_IDX]] : index to i32 +// CHECK: %[[VL_MUL:.*]] = arith.muli %[[VL_BASE]], %[[CAST_I32]] : i32 +// CHECK: "llvm.intr.vp.reduce.umin"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL_MUL]]) : (i8, vector<[32]xi8>, vector<[32]xi1>, i32) -> i8 + +// ----- + func.func @masked_reduce_maxui_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 { %0 = vector.mask %mask { vector.reduction <maxui>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8 return %0 : i8 } @@ -239,6 +313,24 @@ func.func @masked_reduce_maxsi_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) - // ----- +func.func 
@masked_reduce_maxsi_i8_scalable(%arg0: vector<[32]xi8>, %mask : vector<[32]xi1>) -> i8 { + %0 = vector.mask %mask { vector.reduction <maxsi>, %arg0 : vector<[32]xi8> into i8 } : vector<[32]xi1> -> i8 + return %0 : i8 +} + +// CHECK-LABEL: func.func @masked_reduce_maxsi_i8_scalable( +// CHECK-SAME: %[[INPUT:.*]]: vector<[32]xi8>, +// CHECK-SAME: %[[MASK:.*]]: vector<[32]xi1>) -> i8 { +// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(-128 : i8) : i8 +// CHECK: %[[VL_BASE:.*]] = llvm.mlir.constant(32 : i32) : i32 +// CHECK: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64 +// CHECK: %[[CAST_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index +// CHECK: %[[CAST_I32:.*]] = arith.index_cast %[[CAST_IDX]] : index to i32 +// CHECK: %[[VL_MUL:.*]] = arith.muli %[[VL_BASE]], %[[CAST_I32]] : i32 +// CHECK: "llvm.intr.vp.reduce.smax"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL_MUL]]) : (i8, vector<[32]xi8>, vector<[32]xi1>, i32) -> i8 + +// ----- + func.func @masked_reduce_or_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 { %0 = vector.mask %mask { vector.reduction <or>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8 return %0 : i8 } @@ -280,4 +372,22 @@ func.func @masked_reduce_xor_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> // CHECK: %[[VL:.*]] = llvm.mlir.constant(32 : i32) : i32 // CHECK: "llvm.intr.vp.reduce.xor"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8 +// ----- + +func.func @masked_reduce_xor_i8_scalable(%arg0: vector<[32]xi8>, %mask : vector<[32]xi1>) -> i8 { + %0 = vector.mask %mask { vector.reduction <xor>, %arg0 : vector<[32]xi8> into i8 } : vector<[32]xi1> -> i8 + return %0 : i8 +} + +// CHECK-LABEL: func.func @masked_reduce_xor_i8_scalable( +// CHECK-SAME: %[[INPUT:.*]]: vector<[32]xi8>, +// CHECK-SAME: %[[MASK:.*]]: vector<[32]xi1>) -> i8 { +// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(0 : i8) : i8 +// CHECK: %[[VL_BASE:.*]] = llvm.mlir.constant(32 : i32) : i32 +// CHECK: %[[VSCALE:.*]] 
= "llvm.intr.vscale"() : () -> i64 +// CHECK: %[[CAST_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index +// CHECK: %[[CAST_I32:.*]] = arith.index_cast %[[CAST_IDX]] : index to i32 +// CHECK: %[[VL_MUL:.*]] = arith.muli %[[VL_BASE]], %[[CAST_I32]] : i32 +// CHECK: "llvm.intr.vp.reduce.xor"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL_MUL]]) : (i8, vector<[32]xi8>, vector<[32]xi1>, i32) -> i8 +