Skip to content

Commit 702cb97

Browse files
committed
add support for leading_zeros and trailing_zeros, panics during linking
1 parent 6e2c84d commit 702cb97

File tree

4 files changed

+88
-60
lines changed

4 files changed

+88
-60
lines changed

crates/rustc_codegen_spirv/src/builder/ext_inst.rs

+3-30
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
use super::Builder;
22
use crate::builder_spirv::{SpirvValue, SpirvValueExt};
33
use crate::custom_insts;
4+
use rspirv::dr::Operand;
45
use rspirv::spirv::{GLOp, Word};
5-
use rspirv::{dr::Operand, spirv::Capability};
66

77
const GLSL_STD_450: &str = "GLSL.std.450";
88

@@ -13,7 +13,6 @@ pub struct ExtInst {
1313
custom: Option<Word>,
1414

1515
glsl: Option<Word>,
16-
integer_functions_2_intel: bool,
1716
}
1817

1918
impl ExtInst {
@@ -38,32 +37,11 @@ impl ExtInst {
3837
id
3938
}
4039
}
41-
42-
pub fn require_integer_functions_2_intel(&mut self, bx: &Builder<'_, '_>, to_zombie: Word) {
43-
if !self.integer_functions_2_intel {
44-
self.integer_functions_2_intel = true;
45-
if !bx
46-
.builder
47-
.has_capability(Capability::IntegerFunctions2INTEL)
48-
{
49-
bx.zombie(to_zombie, "capability IntegerFunctions2INTEL is required");
50-
}
51-
if !bx
52-
.builder
53-
.has_extension(bx.sym.spv_intel_shader_integer_functions2)
54-
{
55-
bx.zombie(
56-
to_zombie,
57-
"extension SPV_INTEL_shader_integer_functions2 is required",
58-
);
59-
}
60-
}
61-
}
6240
}
6341

6442
impl<'a, 'tcx> Builder<'a, 'tcx> {
6543
pub fn custom_inst(
66-
&mut self,
44+
&self,
6745
result_type: Word,
6846
inst: custom_insts::CustomInst<Operand>,
6947
) -> SpirvValue {
@@ -80,12 +58,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
8058
.with_type(result_type)
8159
}
8260

83-
pub fn gl_op(
84-
&mut self,
85-
op: GLOp,
86-
result_type: Word,
87-
args: impl AsRef<[SpirvValue]>,
88-
) -> SpirvValue {
61+
pub fn gl_op(&self, op: GLOp, result_type: Word, args: impl AsRef<[SpirvValue]>) -> SpirvValue {
8962
let args = args.as_ref();
9063
let glsl = self.ext_inst.borrow_mut().import_glsl(self);
9164
self.emit()

crates/rustc_codegen_spirv/src/builder/intrinsics.rs

+48-26
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use crate::codegen_cx::CodegenCx;
88
use crate::custom_insts::CustomInst;
99
use crate::spirv_type::SpirvType;
1010
use rspirv::dr::Operand;
11-
use rspirv::spirv::GLOp;
11+
use rspirv::spirv::{GLOp, Word};
1212
use rustc_codegen_ssa::mir::operand::OperandRef;
1313
use rustc_codegen_ssa::mir::place::PlaceRef;
1414
use rustc_codegen_ssa::traits::{BuilderMethods, IntrinsicCallBuilderMethods};
@@ -211,34 +211,11 @@ impl<'a, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'a, 'tcx> {
211211
self.rotate(val, shift, is_left)
212212
}
213213

214-
// TODO: Do we want to manually implement these instead of using intel instructions?
215214
sym::ctlz | sym::ctlz_nonzero => {
216-
let result = self
217-
.emit()
218-
.u_count_leading_zeros_intel(
219-
args[0].immediate().ty,
220-
None,
221-
args[0].immediate().def(self),
222-
)
223-
.unwrap();
224-
self.ext_inst
225-
.borrow_mut()
226-
.require_integer_functions_2_intel(self, result);
227-
result.with_type(args[0].immediate().ty)
215+
self.count_leading_trailing_zeros(ret_ty, args[0].immediate(), false)
228216
}
229217
sym::cttz | sym::cttz_nonzero => {
230-
let result = self
231-
.emit()
232-
.u_count_trailing_zeros_intel(
233-
args[0].immediate().ty,
234-
None,
235-
args[0].immediate().def(self),
236-
)
237-
.unwrap();
238-
self.ext_inst
239-
.borrow_mut()
240-
.require_integer_functions_2_intel(self, result);
241-
result.with_type(args[0].immediate().ty)
218+
self.count_leading_trailing_zeros(ret_ty, args[0].immediate(), true)
242219
}
243220

244221
sym::ctpop => self
@@ -398,6 +375,51 @@ impl<'a, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'a, 'tcx> {
398375
}
399376

400377
impl Builder<'_, '_> {
378+
pub fn count_leading_trailing_zeros(
379+
&self,
380+
ret_ty: Word,
381+
arg: SpirvValue,
382+
trailing: bool,
383+
) -> SpirvValue {
384+
let ty = arg.ty;
385+
match self.cx.lookup_type(ty) {
386+
SpirvType::Integer(bits, _) => {
387+
let int_0 = self.constant_int(ty, 0);
388+
let int_bits = self.constant_int(ret_ty, bits as u128).def(self);
389+
let glsl = self.ext_inst.borrow_mut().import_glsl(self);
390+
391+
let mut emit = self.emit();
392+
let is_0 = emit
393+
.i_equal(ty, None, arg.def(self), int_0.def(self))
394+
.unwrap();
395+
let end_label = emit.id();
396+
let xsb_label = emit.id();
397+
emit.branch_conditional(is_0, end_label, xsb_label, [])
398+
.unwrap();
399+
400+
emit.begin_block(Some(xsb_label)).unwrap();
401+
// rust is always unsigned
402+
let gl_op = if trailing {
403+
GLOp::FindILsb
404+
} else {
405+
GLOp::FindUMsb
406+
};
407+
let find_xsb = emit
408+
.ext_inst(ret_ty, None, glsl, gl_op as u32, [Operand::IdRef(
409+
arg.def(self),
410+
)])
411+
.unwrap();
412+
emit.branch(end_label).unwrap();
413+
414+
emit.begin_block(Some(end_label)).unwrap();
415+
emit.phi(ret_ty, None, [(end_label, int_bits), (xsb_label, find_xsb)])
416+
.unwrap()
417+
.with_type(ret_ty)
418+
}
419+
_ => self.fatal("counting leading / trailing zeros on a non-integer type"),
420+
}
421+
}
422+
401423
pub fn abort_with_kind_and_message_debug_printf(
402424
&mut self,
403425
kind: &str,

crates/rustc_codegen_spirv/src/symbols.rs

-4
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ pub struct Symbols {
2121
pub spirv: Symbol,
2222
pub libm: Symbol,
2323
pub entry_point_name: Symbol,
24-
pub spv_intel_shader_integer_functions2: Symbol,
2524
pub spv_khr_vulkan_memory_model: Symbol,
2625

2726
descriptor_set: Symbol,
@@ -411,9 +410,6 @@ impl Symbols {
411410
spirv: Symbol::intern("spirv"),
412411
libm: Symbol::intern("libm"),
413412
entry_point_name: Symbol::intern("entry_point_name"),
414-
spv_intel_shader_integer_functions2: Symbol::intern(
415-
"SPV_INTEL_shader_integer_functions2",
416-
),
417413
spv_khr_vulkan_memory_model: Symbol::intern("SPV_KHR_vulkan_memory_model"),
418414

419415
descriptor_set: Symbol::intern("descriptor_set"),
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// Test all trailing and leading zeros. No need to test ones, they just call the zero variant with !value
2+
3+
// build-pass
4+
5+
use spirv_std::spirv;
6+
7+
#[spirv(fragment)]
8+
pub fn leading_zeros_u32(
9+
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buffer: &u32,
10+
out: &mut u32,
11+
) {
12+
*out = u32::leading_zeros(*buffer);
13+
}
14+
15+
#[spirv(fragment)]
16+
pub fn trailing_zeros_u32(
17+
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buffer: &u32,
18+
out: &mut u32,
19+
) {
20+
*out = u32::trailing_zeros(*buffer);
21+
}
22+
23+
#[spirv(fragment)]
24+
pub fn leading_zeros_i32(
25+
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buffer: &i32,
26+
out: &mut u32,
27+
) {
28+
*out = i32::leading_zeros(*buffer);
29+
}
30+
31+
#[spirv(fragment)]
32+
pub fn trailing_zeros_i32(
33+
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buffer: &i32,
34+
out: &mut u32,
35+
) {
36+
*out = i32::trailing_zeros(*buffer);
37+
}

0 commit comments

Comments
 (0)