Skip to content

Commit 566dfd1

Browse files
committed
simd intrinsics with mask: accept unsigned integer masks
1 parent a7fc463 commit 566dfd1

18 files changed

+119
-148
lines changed

Diff for: compiler/rustc_codegen_gcc/src/intrinsic/simd.rs

+11-14
Original file line numberDiff line numberDiff line change
@@ -443,9 +443,14 @@ pub fn generic_simd_intrinsic<'a, 'gcc, 'tcx>(
443443
m_len == v_len,
444444
InvalidMonomorphization::MismatchedLengths { span, name, m_len, v_len }
445445
);
446+
// TODO: also support unsigned integers.
446447
match *m_elem_ty.kind() {
447448
ty::Int(_) => {}
448-
_ => return_error!(InvalidMonomorphization::MaskType { span, name, ty: m_elem_ty }),
449+
_ => return_error!(InvalidMonomorphization::MaskWrongElementType {
450+
span,
451+
name,
452+
ty: m_elem_ty
453+
}),
449454
}
450455
return Ok(bx.vector_select(args[0].immediate(), args[1].immediate(), args[2].immediate()));
451456
}
@@ -987,19 +992,15 @@ pub fn generic_simd_intrinsic<'a, 'gcc, 'tcx>(
987992
assert_eq!(pointer_count - 1, ptr_count(element_ty0));
988993
assert_eq!(underlying_ty, non_ptr(element_ty0));
989994

990-
// The element type of the third argument must be a signed integer type of any width:
995+
// The element type of the third argument must be an integer type of any width; this backend currently accepts only signed integers:
996+
// TODO: also support unsigned integers.
991997
let (_, element_ty2) = arg_tys[2].simd_size_and_type(bx.tcx());
992998
match *element_ty2.kind() {
993999
ty::Int(_) => (),
9941000
_ => {
9951001
require!(
9961002
false,
997-
InvalidMonomorphization::ThirdArgElementType {
998-
span,
999-
name,
1000-
expected_element: element_ty2,
1001-
third_arg: arg_tys[2]
1002-
}
1003+
InvalidMonomorphization::MaskWrongElementType { span, name, ty: element_ty2 }
10031004
);
10041005
}
10051006
}
@@ -1105,17 +1106,13 @@ pub fn generic_simd_intrinsic<'a, 'gcc, 'tcx>(
11051106
assert_eq!(underlying_ty, non_ptr(element_ty0));
11061107

11071108
// The element type of the third argument must be a signed integer type of any width:
1109+
// TODO: also support unsigned integers.
11081110
match *element_ty2.kind() {
11091111
ty::Int(_) => (),
11101112
_ => {
11111113
require!(
11121114
false,
1113-
InvalidMonomorphization::ThirdArgElementType {
1114-
span,
1115-
name,
1116-
expected_element: element_ty2,
1117-
third_arg: arg_tys[2]
1118-
}
1115+
InvalidMonomorphization::MaskWrongElementType { span, name, ty: element_ty2 }
11191116
);
11201117
}
11211118
}

Diff for: compiler/rustc_codegen_llvm/src/intrinsic.rs

+12-44
Original file line numberDiff line numberDiff line change
@@ -1184,18 +1184,6 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
11841184
}};
11851185
}
11861186

1187-
/// Returns the bitwidth of the `$ty` argument if it is an `Int` type.
1188-
macro_rules! require_int_ty {
1189-
($ty: expr, $diag: expr) => {
1190-
match $ty {
1191-
ty::Int(i) => i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits()),
1192-
_ => {
1193-
return_error!($diag);
1194-
}
1195-
}
1196-
};
1197-
}
1198-
11991187
/// Returns the bitwidth of the `$ty` argument if it is an `Int` or `Uint` type.
12001188
macro_rules! require_int_or_uint_ty {
12011189
($ty: expr, $diag: expr) => {
@@ -1476,9 +1464,9 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
14761464
m_len == v_len,
14771465
InvalidMonomorphization::MismatchedLengths { span, name, m_len, v_len }
14781466
);
1479-
let in_elem_bitwidth = require_int_ty!(
1467+
let in_elem_bitwidth = require_int_or_uint_ty!(
14801468
m_elem_ty.kind(),
1481-
InvalidMonomorphization::MaskType { span, name, ty: m_elem_ty }
1469+
InvalidMonomorphization::MaskWrongElementType { span, name, ty: m_elem_ty }
14821470
);
14831471
let m_i1s = vector_mask_to_bitmask(bx, args[0].immediate(), in_elem_bitwidth, m_len);
14841472
return Ok(bx.select(m_i1s, args[1].immediate(), args[2].immediate()));
@@ -1499,7 +1487,7 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
14991487
// Integer vector <i{in_bitwidth} x in_len>:
15001488
let in_elem_bitwidth = require_int_or_uint_ty!(
15011489
in_elem.kind(),
1502-
InvalidMonomorphization::VectorArgument { span, name, in_ty, in_elem }
1490+
InvalidMonomorphization::MaskWrongElementType { span, name, ty: in_elem }
15031491
);
15041492

15051493
let i1xn = vector_mask_to_bitmask(bx, args[0].immediate(), in_elem_bitwidth, in_len);
@@ -1723,14 +1711,9 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
17231711
}
17241712
);
17251713

1726-
let mask_elem_bitwidth = require_int_ty!(
1714+
let mask_elem_bitwidth = require_int_or_uint_ty!(
17271715
element_ty2.kind(),
1728-
InvalidMonomorphization::ThirdArgElementType {
1729-
span,
1730-
name,
1731-
expected_element: element_ty2,
1732-
third_arg: arg_tys[2]
1733-
}
1716+
InvalidMonomorphization::MaskWrongElementType { span, name, ty: element_ty2 }
17341717
);
17351718

17361719
// Alignment of T, must be a constant integer value:
@@ -1825,14 +1808,9 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
18251808
}
18261809
);
18271810

1828-
let m_elem_bitwidth = require_int_ty!(
1811+
let m_elem_bitwidth = require_int_or_uint_ty!(
18291812
mask_elem.kind(),
1830-
InvalidMonomorphization::ThirdArgElementType {
1831-
span,
1832-
name,
1833-
expected_element: values_elem,
1834-
third_arg: mask_ty,
1835-
}
1813+
InvalidMonomorphization::MaskWrongElementType { span, name, ty: mask_elem }
18361814
);
18371815

18381816
let mask = vector_mask_to_bitmask(bx, args[0].immediate(), m_elem_bitwidth, mask_len);
@@ -1915,14 +1893,9 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
19151893
}
19161894
);
19171895

1918-
let m_elem_bitwidth = require_int_ty!(
1896+
let m_elem_bitwidth = require_int_or_uint_ty!(
19191897
mask_elem.kind(),
1920-
InvalidMonomorphization::ThirdArgElementType {
1921-
span,
1922-
name,
1923-
expected_element: values_elem,
1924-
third_arg: mask_ty,
1925-
}
1898+
InvalidMonomorphization::MaskWrongElementType { span, name, ty: mask_elem }
19261899
);
19271900

19281901
let mask = vector_mask_to_bitmask(bx, args[0].immediate(), m_elem_bitwidth, mask_len);
@@ -2010,15 +1983,10 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
20101983
}
20111984
);
20121985

2013-
// The element type of the third argument must be a signed integer type of any width:
2014-
let mask_elem_bitwidth = require_int_ty!(
1986+
// The element type of the third argument must be an integer type of any width:
1987+
let mask_elem_bitwidth = require_int_or_uint_ty!(
20151988
element_ty2.kind(),
2016-
InvalidMonomorphization::ThirdArgElementType {
2017-
span,
2018-
name,
2019-
expected_element: element_ty2,
2020-
third_arg: arg_tys[2]
2021-
}
1989+
InvalidMonomorphization::MaskWrongElementType { span, name, ty: element_ty2 }
20221990
);
20231991

20241992
// Alignment of T, must be a constant integer value:

Diff for: compiler/rustc_codegen_ssa/messages.ftl

+1-6
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,7 @@ codegen_ssa_invalid_monomorphization_inserted_type = invalid monomorphization of
133133
134134
codegen_ssa_invalid_monomorphization_invalid_bitmask = invalid monomorphization of `{$name}` intrinsic: invalid bitmask `{$mask_ty}`, expected `u{$expected_int_bits}` or `[u8; {$expected_bytes}]`
135135
136-
codegen_ssa_invalid_monomorphization_mask_type = invalid monomorphization of `{$name}` intrinsic: found mask element type is `{$ty}`, expected a signed integer type
137-
.note = the mask may be widened, which only has the correct behavior for signed integers
136+
codegen_ssa_invalid_monomorphization_mask_wrong_element_type = invalid monomorphization of `{$name}` intrinsic: expected mask element type to be an integer, found `{$ty}`
138137
139138
codegen_ssa_invalid_monomorphization_mismatched_lengths = invalid monomorphization of `{$name}` intrinsic: mismatched lengths: mask length `{$m_len}` != other vector length `{$v_len}`
140139
@@ -166,8 +165,6 @@ codegen_ssa_invalid_monomorphization_simd_shuffle = invalid monomorphization of
166165
167166
codegen_ssa_invalid_monomorphization_simd_third = invalid monomorphization of `{$name}` intrinsic: expected SIMD third type, found non-SIMD `{$ty}`
168167
169-
codegen_ssa_invalid_monomorphization_third_arg_element_type = invalid monomorphization of `{$name}` intrinsic: expected element type `{$expected_element}` of third argument `{$third_arg}` to be a signed integer type
170-
171168
codegen_ssa_invalid_monomorphization_third_argument_length = invalid monomorphization of `{$name}` intrinsic: expected third argument with length {$in_len} (same as input type `{$in_ty}`), found `{$arg_ty}` with length {$out_len}
172169
173170
codegen_ssa_invalid_monomorphization_unrecognized_intrinsic = invalid monomorphization of `{$name}` intrinsic: unrecognized intrinsic `{$name}`
@@ -180,8 +177,6 @@ codegen_ssa_invalid_monomorphization_unsupported_symbol = invalid monomorphizati
180177
181178
codegen_ssa_invalid_monomorphization_unsupported_symbol_of_size = invalid monomorphization of `{$name}` intrinsic: unsupported {$symbol} from `{$in_ty}` with element `{$in_elem}` of size `{$size}` to `{$ret_ty}`
182179
183-
codegen_ssa_invalid_monomorphization_vector_argument = invalid monomorphization of `{$name}` intrinsic: vector argument `{$in_ty}`'s element type `{$in_elem}`, expected integer element type
184-
185180
codegen_ssa_invalid_no_sanitize = invalid argument for `no_sanitize`
186181
.note = expected one of: `address`, `cfi`, `hwaddress`, `kcfi`, `memory`, `memtag`, `shadow-call-stack`, or `thread`
187182

Diff for: compiler/rustc_codegen_ssa/src/errors.rs

+2-21
Original file line numberDiff line numberDiff line change
@@ -1059,24 +1059,14 @@ pub enum InvalidMonomorphization<'tcx> {
10591059
v_len: u64,
10601060
},
10611061

1062-
#[diag(codegen_ssa_invalid_monomorphization_mask_type, code = E0511)]
1063-
#[note]
1064-
MaskType {
1062+
#[diag(codegen_ssa_invalid_monomorphization_mask_wrong_element_type, code = E0511)]
1063+
MaskWrongElementType {
10651064
#[primary_span]
10661065
span: Span,
10671066
name: Symbol,
10681067
ty: Ty<'tcx>,
10691068
},
10701069

1071-
#[diag(codegen_ssa_invalid_monomorphization_vector_argument, code = E0511)]
1072-
VectorArgument {
1073-
#[primary_span]
1074-
span: Span,
1075-
name: Symbol,
1076-
in_ty: Ty<'tcx>,
1077-
in_elem: Ty<'tcx>,
1078-
},
1079-
10801070
#[diag(codegen_ssa_invalid_monomorphization_cannot_return, code = E0511)]
10811071
CannotReturn {
10821072
#[primary_span]
@@ -1099,15 +1089,6 @@ pub enum InvalidMonomorphization<'tcx> {
10991089
mutability: ExpectedPointerMutability,
11001090
},
11011091

1102-
#[diag(codegen_ssa_invalid_monomorphization_third_arg_element_type, code = E0511)]
1103-
ThirdArgElementType {
1104-
#[primary_span]
1105-
span: Span,
1106-
name: Symbol,
1107-
expected_element: Ty<'tcx>,
1108-
third_arg: Ty<'tcx>,
1109-
},
1110-
11111092
#[diag(codegen_ssa_invalid_monomorphization_unsupported_symbol_of_size, code = E0511)]
11121093
UnsupportedSymbolOfSize {
11131094
#[primary_span]

Diff for: library/core/src/intrinsics/simd.rs

+5-5
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ pub unsafe fn simd_shuffle<T, U, V>(x: T, y: T, idx: U) -> V;
271271
///
272272
/// `U` must be a vector of pointers to the element type of `T`, with the same length as `T`.
273273
///
274-
/// `V` must be a vector of signed integers with the same length as `T` (but any element size).
274+
/// `V` must be a vector of integers with the same length as `T` (but any element size).
275275
///
276276
/// For each pointer in `ptr`, if the corresponding value in `mask` is `!0`, read the pointer.
277277
/// Otherwise if the corresponding value in `mask` is `0`, return the corresponding value from
@@ -292,7 +292,7 @@ pub unsafe fn simd_gather<T, U, V>(val: T, ptr: U, mask: V) -> T;
292292
///
293293
/// `U` must be a vector of pointers to the element type of `T`, with the same length as `T`.
294294
///
295-
/// `V` must be a vector of signed integers with the same length as `T` (but any element size).
295+
/// `V` must be a vector of integers with the same length as `T` (but any element size).
296296
///
297297
/// For each pointer in `ptr`, if the corresponding value in `mask` is `!0`, write the
298298
/// corresponding value in `val` to the pointer.
@@ -316,7 +316,7 @@ pub unsafe fn simd_scatter<T, U, V>(val: T, ptr: U, mask: V);
316316
///
317317
/// `U` must be a pointer to the element type of `T`.
318318
///
319-
/// `V` must be a vector of signed integers with the same length as `T` (but any element size).
319+
/// `V` must be a vector of integers with the same length as `T` (but any element size).
320320
///
321321
/// For each element, if the corresponding value in `mask` is `!0`, read the corresponding
322322
/// pointer offset from `ptr`.
@@ -339,7 +339,7 @@ pub unsafe fn simd_masked_load<V, U, T>(mask: V, ptr: U, val: T) -> T;
339339
///
340340
/// `U` must be a pointer to the element type of `T`.
341341
///
342-
/// `V` must be a vector of signed integers with the same length as `T` (but any element size).
342+
/// `V` must be a vector of integers with the same length as `T` (but any element size).
343343
///
344344
/// For each element, if the corresponding value in `mask` is `!0`, write the corresponding
345345
/// value in `val` to the pointer offset from `ptr`.
@@ -523,7 +523,7 @@ pub unsafe fn simd_bitmask<T, U>(x: T) -> U;
523523
///
524524
/// `T` must be a vector.
525525
///
526-
/// `M` must be a signed integer vector with the same length as `T` (but any element size).
526+
/// `M` must be an integer vector with the same length as `T` (but any element size).
527527
///
528528
/// For each element, if the corresponding value in `mask` is `!0`, select the element from
529529
/// `if_true`. If the corresponding value in `mask` is `0`, select the element from

Diff for: src/tools/miri/src/helpers.rs

+5
Original file line numberDiff line numberDiff line change
@@ -1295,6 +1295,11 @@ pub(crate) fn bool_to_simd_element(b: bool, size: Size) -> Scalar {
12951295
}
12961296

12971297
pub(crate) fn simd_element_to_bool(elem: ImmTy<'_>) -> InterpResult<'_, bool> {
1298+
assert!(
1299+
matches!(elem.layout.ty.kind(), ty::Int(_) | ty::Uint(_)),
1300+
"SIMD mask element type must be an integer, but this is `{}`",
1301+
elem.layout.ty
1302+
);
12981303
let val = elem.to_scalar().to_int(elem.layout.size)?;
12991304
interp_ok(match val {
13001305
0 => false,

Diff for: tests/codegen/simd-intrinsic/simd-intrinsic-generic-gather.rs

+13
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,19 @@ pub unsafe fn gather_f32x2(
2929
simd_gather(values, pointers, mask)
3030
}
3131

32+
// CHECK-LABEL: @gather_f32x2_unsigned
33+
#[no_mangle]
34+
pub unsafe fn gather_f32x2_unsigned(
35+
pointers: Vec2<*const f32>,
36+
mask: Vec2<u32>,
37+
values: Vec2<f32>,
38+
) -> Vec2<f32> {
39+
// CHECK: [[A:%[0-9]+]] = lshr <2 x i32> {{.*}}, {{<i32 31, i32 31>|splat \(i32 31\)}}
40+
// CHECK: [[B:%[0-9]+]] = trunc <2 x i32> [[A]] to <2 x i1>
41+
// CHECK: call <2 x float> @llvm.masked.gather.v2f32.v2p0(<2 x ptr> {{.*}}, i32 {{.*}}, <2 x i1> [[B]], <2 x float> {{.*}})
42+
simd_gather(values, pointers, mask)
43+
}
44+
3245
// CHECK-LABEL: @gather_pf32x2
3346
#[no_mangle]
3447
pub unsafe fn gather_pf32x2(

Diff for: tests/codegen/simd-intrinsic/simd-intrinsic-generic-masked-load.rs

+13
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,19 @@ pub unsafe fn load_f32x2(mask: Vec2<i32>, pointer: *const f32, values: Vec2<f32>
2323
simd_masked_load(mask, pointer, values)
2424
}
2525

26+
// CHECK-LABEL: @load_f32x2_unsigned
27+
#[no_mangle]
28+
pub unsafe fn load_f32x2_unsigned(
29+
mask: Vec2<u32>,
30+
pointer: *const f32,
31+
values: Vec2<f32>,
32+
) -> Vec2<f32> {
33+
// CHECK: [[A:%[0-9]+]] = lshr <2 x i32> {{.*}}, {{<i32 31, i32 31>|splat \(i32 31\)}}
34+
// CHECK: [[B:%[0-9]+]] = trunc <2 x i32> [[A]] to <2 x i1>
35+
// CHECK: call <2 x float> @llvm.masked.load.v2f32.p0(ptr {{.*}}, i32 4, <2 x i1> [[B]], <2 x float> {{.*}})
36+
simd_masked_load(mask, pointer, values)
37+
}
38+
2639
// CHECK-LABEL: @load_pf32x4
2740
#[no_mangle]
2841
pub unsafe fn load_pf32x4(

Diff for: tests/codegen/simd-intrinsic/simd-intrinsic-generic-masked-store.rs

+9
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,15 @@ pub unsafe fn store_f32x2(mask: Vec2<i32>, pointer: *mut f32, values: Vec2<f32>)
2323
simd_masked_store(mask, pointer, values)
2424
}
2525

26+
// CHECK-LABEL: @store_f32x2_unsigned
27+
#[no_mangle]
28+
pub unsafe fn store_f32x2_unsigned(mask: Vec2<u32>, pointer: *mut f32, values: Vec2<f32>) {
29+
// CHECK: [[A:%[0-9]+]] = lshr <2 x i32> {{.*}}, {{<i32 31, i32 31>|splat \(i32 31\)}}
30+
// CHECK: [[B:%[0-9]+]] = trunc <2 x i32> [[A]] to <2 x i1>
31+
// CHECK: call void @llvm.masked.store.v2f32.p0(<2 x float> {{.*}}, ptr {{.*}}, i32 4, <2 x i1> [[B]])
32+
simd_masked_store(mask, pointer, values)
33+
}
34+
2635
// CHECK-LABEL: @store_pf32x4
2736
#[no_mangle]
2837
pub unsafe fn store_pf32x4(mask: Vec4<i32>, pointer: *mut *const f32, values: Vec4<*const f32>) {

Diff for: tests/codegen/simd-intrinsic/simd-intrinsic-generic-scatter.rs

+9
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,15 @@ pub unsafe fn scatter_f32x2(pointers: Vec2<*mut f32>, mask: Vec2<i32>, values: V
2525
simd_scatter(values, pointers, mask)
2626
}
2727

28+
// CHECK-LABEL: @scatter_f32x2_unsigned
29+
#[no_mangle]
30+
pub unsafe fn scatter_f32x2_unsigned(pointers: Vec2<*mut f32>, mask: Vec2<u32>, values: Vec2<f32>) {
31+
// CHECK: [[A:%[0-9]+]] = lshr <2 x i32> {{.*}}, {{<i32 31, i32 31>|splat \(i32 31\)}}
32+
// CHECK: [[B:%[0-9]+]] = trunc <2 x i32> [[A]] to <2 x i1>
33+
// CHECK: call void @llvm.masked.scatter.v2f32.v2p0(<2 x float> {{.*}}, <2 x ptr> {{.*}}, i32 {{.*}}, <2 x i1> [[B]]
34+
simd_scatter(values, pointers, mask)
35+
}
36+
2837
// CHECK-LABEL: @scatter_pf32x2
2938
#[no_mangle]
3039
pub unsafe fn scatter_pf32x2(

Diff for: tests/codegen/simd-intrinsic/simd-intrinsic-generic-select.rs

+13
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ pub struct b8x4(pub [i8; 4]);
2222
#[derive(Copy, Clone, PartialEq, Debug)]
2323
pub struct i32x4([i32; 4]);
2424

25+
#[repr(simd)]
26+
#[derive(Copy, Clone, PartialEq, Debug)]
27+
pub struct u32x4([u32; 4]);
28+
2529
// CHECK-LABEL: @select_m8
2630
#[no_mangle]
2731
pub unsafe fn select_m8(m: b8x4, a: f32x4, b: f32x4) -> f32x4 {
@@ -40,6 +44,15 @@ pub unsafe fn select_m32(m: i32x4, a: f32x4, b: f32x4) -> f32x4 {
4044
simd_select(m, a, b)
4145
}
4246

47+
// CHECK-LABEL: @select_m32_unsigned
48+
#[no_mangle]
49+
pub unsafe fn select_m32_unsigned(m: u32x4, a: f32x4, b: f32x4) -> f32x4 {
50+
// CHECK: [[A:%[0-9]+]] = lshr <4 x i32> %{{.*}}, {{<i32 31, i32 31, i32 31, i32 31>|splat \(i32 31\)}}
51+
// CHECK: [[B:%[0-9]+]] = trunc <4 x i32> [[A]] to <4 x i1>
52+
// CHECK: select <4 x i1> [[B]]
53+
simd_select(m, a, b)
54+
}
55+
4356
// CHECK-LABEL: @select_bitmask
4457
#[no_mangle]
4558
pub unsafe fn select_bitmask(m: i8, a: f32x8, b: f32x8) -> f32x8 {

0 commit comments

Comments
 (0)