
Commit b7d1f4c

[naga spv-out] Ensure loops generated by SPIRV backend are bounded (#7080)
If it is undefined behaviour for loops to be infinite, then a downstream compiler that encounters an infinite loop is free to make optimizations that may be unsafe, for example omitting bounds checks. To prevent this, we must ensure that any loops emitted by our backends are provably bounded. We already do this for both the MSL and HLSL backends; this patch makes the SPIR-V backend do so as well.

The construct used is the same as in the HLSL and MSL backends: a vec2<u32> emulates a 64-bit counter, which is incremented on every iteration and causes the loop to break after 2^64 iterations. While the implementation is fairly verbose for the SPIR-V backend, the logic is simple enough.

The one point of note is that SPIR-V requires `OpVariable` instructions with a `Function` storage class to be located at the start of the function's first block. We therefore remember the IDs generated for each loop counter variable in a function while generating the function body's code. The instructions declaring these variables are then emitted in `Function::to_words()` prior to emitting the function's body.

As this may negatively impact shader performance, the workaround can be disabled using the same mechanism as for the other backends, e.g. by calling `Device::create_shader_module_trusted()` and setting the `ShaderRuntimeChecks::force_loop_bounding` flag to false.
1 parent e0f0185 commit b7d1f4c

20 files changed: +2989 -2269 lines
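Before the per-file diffs, here is an illustrative Rust model of the counter construct described in the commit message. The names are invented for this sketch and are not part of the patch; the backend emits the equivalent SPIR-V, shown in the `6220-break-from-loop.spvasm` diff further down.

```rust
// Illustrative model of the injected loop bound: a vec2<u32> counter where
// element 0 is the high word and element 1 is the low word.
fn run_bounded_loop() {
    let mut counter: [u32; 2] = [0, 0]; // [high, low]
    loop {
        // Break once both words have saturated, i.e. after roughly 2^64 iterations.
        if counter == [u32::MAX, u32::MAX] {
            break;
        }
        // Always increment the low word; carry into the high word only when the
        // low word is about to wrap. In shader terms:
        //   counter += vec2(select(0u, 1u, counter.y == u32::MAX), 1u);
        let carry = u32::from(counter[1] == u32::MAX);
        counter = [counter[0].wrapping_add(carry), counter[1].wrapping_add(1)];

        // ... the original loop body runs here, including its own breaks ...
    }
}
```

The check and increment run at the top of every iteration, so a loop that never breaks on its own still terminates after about 2^64 iterations.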

CHANGELOG.md (+11)

```diff
@@ -117,6 +117,17 @@ pub enum PollError {
 By @cwfitzgerald in [#6942](https://github.com/gfx-rs/wgpu/pull/6942).
 By @cwfitzgerald in [#7030](https://github.com/gfx-rs/wgpu/pull/7030).
 
+#### Naga
+
+##### Ensure loops generated by SPIR-V and HLSL Naga backends are bounded
+
+Make sure that all loops in shaders generated by these naga backends are bounded
+to avoid undefined behaviour due to infinite loops. Note that this may have a
+performance cost. As with the existing implementation for the MSL backend this
+can be disabled by using `Device::create_shader_module_trusted()`.
+
+By @jamienicol in [#6929](https://github.com/gfx-rs/wgpu/pull/6929) and [#7080](https://github.com/gfx-rs/wgpu/pull/7080).
+
 ### New Features
 
 #### General
```
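The changelog entry above names `Device::create_shader_module_trusted()` as the opt-out. A minimal sketch of what that might look like on the wgpu side follows; the exact `ShaderRuntimeChecks` construction and the `unsafe` signature are assumptions based on the commit message and may differ between wgpu versions.

```rust
// Sketch, not a definitive API reference: `create_shader_module_trusted` and the
// `force_loop_bounding` flag are named in the commit message; how `ShaderRuntimeChecks`
// is built here is assumed.
fn create_unbounded_shader(device: &wgpu::Device, wgsl: &str) -> wgpu::ShaderModule {
    let mut checks = wgpu::ShaderRuntimeChecks::default(); // assumed: all checks enabled by default
    checks.force_loop_bounding = false; // skip the injected loop counter

    // Safety: the caller promises this shader cannot loop forever.
    unsafe {
        device.create_shader_module_trusted(
            wgpu::ShaderModuleDescriptor {
                label: Some("trusted shader"),
                source: wgpu::ShaderSource::Wgsl(wgsl.into()),
            },
            checks,
        )
    }
}
```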

naga/src/back/spv/block.rs (+153)

```diff
@@ -261,6 +261,155 @@ impl Writer {
     }
 
 impl BlockContext<'_> {
+    /// Generates code to ensure that a loop is bounded. Should be called immediately
+    /// after adding the OpLoopMerge instruction to `block`. This function will
+    /// [`consume()`](crate::back::spv::Function::consume) `block` and append its
+    /// instructions to a new [`Block`], which will be returned to the caller for it to
+    /// consumed prior to writing the loop body.
+    ///
+    /// Additionally this function will populate [`force_loop_bounding_vars`](crate::back::spv::Function::force_loop_bounding_vars),
+    /// ensuring that [`Function::to_words()`](crate::back::spv::Function::to_words) will
+    /// declare the required variables.
+    ///
+    /// See [`crate::back::msl::Writer::gen_force_bounded_loop_statements`] for details
+    /// of why this is required.
+    fn write_force_bounded_loop_instructions(&mut self, mut block: Block, merge_id: Word) -> Block {
+        let uint_type_id = self.writer.get_uint_type_id();
+        let uint2_type_id = self.writer.get_uint2_type_id();
+        let uint2_ptr_type_id = self
+            .writer
+            .get_uint2_pointer_type_id(spirv::StorageClass::Function);
+        let bool_type_id = self.writer.get_bool_type_id();
+        let bool2_type_id = self.writer.get_bool2_type_id();
+        let zero_uint_const_id = self.writer.get_constant_scalar(crate::Literal::U32(0));
+        let zero_uint2_const_id = self.writer.get_constant_composite(
+            LookupType::Local(LocalType::Numeric(NumericType::Vector {
+                size: crate::VectorSize::Bi,
+                scalar: crate::Scalar::U32,
+            })),
+            &[zero_uint_const_id, zero_uint_const_id],
+        );
+        let one_uint_const_id = self.writer.get_constant_scalar(crate::Literal::U32(1));
+        let max_uint_const_id = self
+            .writer
+            .get_constant_scalar(crate::Literal::U32(u32::MAX));
+        let max_uint2_const_id = self.writer.get_constant_composite(
+            LookupType::Local(LocalType::Numeric(NumericType::Vector {
+                size: crate::VectorSize::Bi,
+                scalar: crate::Scalar::U32,
+            })),
+            &[max_uint_const_id, max_uint_const_id],
+        );
+
+        let loop_counter_var_id = self.gen_id();
+        if self.writer.flags.contains(WriterFlags::DEBUG) {
+            self.writer
+                .debugs
+                .push(Instruction::name(loop_counter_var_id, "loop_bound"));
+        }
+        let var = super::LocalVariable {
+            id: loop_counter_var_id,
+            instruction: Instruction::variable(
+                uint2_ptr_type_id,
+                loop_counter_var_id,
+                spirv::StorageClass::Function,
+                Some(zero_uint2_const_id),
+            ),
+        };
+        self.function.force_loop_bounding_vars.push(var);
+
+        let break_if_block = self.gen_id();
+
+        self.function
+            .consume(block, Instruction::branch(break_if_block));
+        block = Block::new(break_if_block);
+
+        // Load the current loop counter value from its variable. We use a vec2<u32> to
+        // simulate a 64-bit counter.
+        let load_id = self.gen_id();
+        block.body.push(Instruction::load(
+            uint2_type_id,
+            load_id,
+            loop_counter_var_id,
+            None,
+        ));
+
+        // If both the high and low u32s have reached u32::MAX then break. ie
+        // if (all(eq(loop_counter, vec2(u32::MAX)))) { break; }
+        let eq_id = self.gen_id();
+        block.body.push(Instruction::binary(
+            spirv::Op::IEqual,
+            bool2_type_id,
+            eq_id,
+            max_uint2_const_id,
+            load_id,
+        ));
+        let all_eq_id = self.gen_id();
+        block.body.push(Instruction::relational(
+            spirv::Op::All,
+            bool_type_id,
+            all_eq_id,
+            eq_id,
+        ));
+
+        let inc_counter_block_id = self.gen_id();
+        block.body.push(Instruction::selection_merge(
+            inc_counter_block_id,
+            spirv::SelectionControl::empty(),
+        ));
+        self.function.consume(
+            block,
+            Instruction::branch_conditional(all_eq_id, merge_id, inc_counter_block_id),
+        );
+        block = Block::new(inc_counter_block_id);
+
+        // To simulate a 64-bit counter we always increment the low u32, and increment
+        // the high u32 when the low u32 overflows. ie
+        // counter += vec2(select(0u, 1u, counter.y == u32::MAX), 1u);
+        let low_id = self.gen_id();
+        block.body.push(Instruction::composite_extract(
+            uint_type_id,
+            low_id,
+            load_id,
+            &[1],
+        ));
+        let low_overflow_id = self.gen_id();
+        block.body.push(Instruction::binary(
+            spirv::Op::IEqual,
+            bool_type_id,
+            low_overflow_id,
+            low_id,
+            max_uint_const_id,
+        ));
+        let carry_bit_id = self.gen_id();
+        block.body.push(Instruction::select(
+            uint_type_id,
+            carry_bit_id,
+            low_overflow_id,
+            one_uint_const_id,
+            zero_uint_const_id,
+        ));
+        let increment_id = self.gen_id();
+        block.body.push(Instruction::composite_construct(
+            uint2_type_id,
+            increment_id,
+            &[carry_bit_id, one_uint_const_id],
+        ));
+        let result_id = self.gen_id();
+        block.body.push(Instruction::binary(
+            spirv::Op::IAdd,
+            uint2_type_id,
+            result_id,
+            load_id,
+            increment_id,
+        ));
+        block
+            .body
+            .push(Instruction::store(loop_counter_var_id, result_id, None));
+
+        block
+    }
+
     /// Cache an expression for a value.
     pub(super) fn cache_expression_value(
         &mut self,
@@ -2558,6 +2707,10 @@ impl BlockContext<'_> {
                     continuing_id,
                     spirv::SelectionControl::NONE,
                 ));
+
+                if self.force_loop_bounding {
+                    block = self.write_force_bounded_loop_instructions(block, merge_id);
+                }
                 self.function.consume(block, Instruction::branch(body_id));
 
                 // We can ignore the `BlockExitDisposition` returned here because,
```

naga/src/back/spv/mod.rs (+10)

```diff
@@ -144,6 +144,8 @@ struct Function {
     signature: Option<Instruction>,
     parameters: Vec<FunctionArgument>,
     variables: crate::FastHashMap<Handle<crate::LocalVariable>, LocalVariable>,
+    /// List of local variables used as a counters to ensure that all loops are bounded.
+    force_loop_bounding_vars: Vec<LocalVariable>,
 
     /// A map taking an expression that yields a composite value (array, matrix)
     /// to the temporary variables we have spilled it to, if any. Spilling
@@ -726,6 +728,8 @@ struct BlockContext<'w> {
 
     /// Tracks the constness of `Expression`s residing in `self.ir_function.expressions`
     expression_constness: ExpressionConstnessTracker,
+
+    force_loop_bounding: bool,
 }
 
 impl BlockContext<'_> {
@@ -779,6 +783,7 @@ pub struct Writer {
     flags: WriterFlags,
     bounds_check_policies: BoundsCheckPolicies,
    zero_initialize_workgroup_memory: ZeroInitializeWorkgroupMemoryMode,
+    force_loop_bounding: bool,
     void_type: Word,
     //TODO: convert most of these into vectors, addressable by handle indices
     lookup_type: crate::FastHashMap<LookupType, Word>,
@@ -882,6 +887,10 @@ pub struct Options<'a> {
     /// Dictates the way workgroup variables should be zero initialized
     pub zero_initialize_workgroup_memory: ZeroInitializeWorkgroupMemoryMode,
 
+    /// If set, loops will have code injected into them, forcing the compiler
+    /// to think the number of iterations is bounded.
+    pub force_loop_bounding: bool,
+
     pub debug_info: Option<DebugInfo<'a>>,
 }
 
@@ -900,6 +909,7 @@
             capabilities: None,
             bounds_check_policies: BoundsCheckPolicies::default(),
             zero_initialize_workgroup_memory: ZeroInitializeWorkgroupMemoryMode::Polyfill,
+            force_loop_bounding: true,
             debug_info: None,
         }
     }
```
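For code driving naga's SPIR-V backend directly, the new `Options::force_loop_bounding` field shown above is the switch. A small sketch follows, assuming naga's usual WGSL front end, validator, and `write_vec` entry point; everything outside the `Options` field is an assumption about naga's public API rather than part of this commit.

```rust
// Sketch: translate WGSL to SPIR-V with the new loop bounding disabled.
// Only `Options::force_loop_bounding` comes from this commit; the surrounding
// calls are assumed from naga's public API and feature flags (wgsl-in, spv-out).
fn wgsl_to_spirv_unbounded(source: &str) -> Vec<u32> {
    let module = naga::front::wgsl::parse_str(source).expect("WGSL parse error");
    let info = naga::valid::Validator::new(
        naga::valid::ValidationFlags::all(),
        naga::valid::Capabilities::default(),
    )
    .validate(&module)
    .expect("validation error");

    let options = naga::back::spv::Options {
        force_loop_bounding: false, // defaults to true, as shown in the diff above
        ..Default::default()
    };
    naga::back::spv::write_vec(&module, &info, &options, None).expect("SPIR-V write error")
}
```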

naga/src/back/spv/writer.rs (+33)

```diff
@@ -32,6 +32,9 @@ impl Function {
         for local_var in self.variables.values() {
             local_var.instruction.to_words(sink);
         }
+        for local_var in self.force_loop_bounding_vars.iter() {
+            local_var.instruction.to_words(sink);
+        }
         for internal_var in self.spilled_composites.values() {
             internal_var.instruction.to_words(sink);
         }
@@ -71,6 +74,7 @@ impl Writer {
             flags: options.flags,
             bounds_check_policies: options.bounds_check_policies,
             zero_initialize_workgroup_memory: options.zero_initialize_workgroup_memory,
+            force_loop_bounding: options.force_loop_bounding,
             void_type,
             lookup_type: crate::FastHashMap::default(),
             lookup_function: crate::FastHashMap::default(),
@@ -111,6 +115,7 @@
             flags: self.flags,
             bounds_check_policies: self.bounds_check_policies,
             zero_initialize_workgroup_memory: self.zero_initialize_workgroup_memory,
+            force_loop_bounding: self.force_loop_bounding,
             capabilities_available: take(&mut self.capabilities_available),
             binding_map: take(&mut self.binding_map),
 
@@ -267,6 +272,14 @@
         self.get_type_id(local_type.into())
     }
 
+    pub(super) fn get_uint2_type_id(&mut self) -> Word {
+        let local_type = LocalType::Numeric(NumericType::Vector {
+            size: crate::VectorSize::Bi,
+            scalar: crate::Scalar::U32,
+        });
+        self.get_type_id(local_type.into())
+    }
+
     pub(super) fn get_uint3_type_id(&mut self) -> Word {
         let local_type = LocalType::Numeric(NumericType::Vector {
             size: crate::VectorSize::Tri,
@@ -283,6 +296,17 @@
         self.get_type_id(local_type.into())
     }
 
+    pub(super) fn get_uint2_pointer_type_id(&mut self, class: spirv::StorageClass) -> Word {
+        let local_type = LocalType::LocalPointer {
+            base: NumericType::Vector {
+                size: crate::VectorSize::Bi,
+                scalar: crate::Scalar::U32,
+            },
+            class,
+        };
+        self.get_type_id(local_type.into())
+    }
+
     pub(super) fn get_uint3_pointer_type_id(&mut self, class: spirv::StorageClass) -> Word {
         let local_type = LocalType::LocalPointer {
             base: NumericType::Vector {
@@ -299,6 +323,14 @@
         self.get_type_id(local_type.into())
     }
 
+    pub(super) fn get_bool2_type_id(&mut self) -> Word {
+        let local_type = LocalType::Numeric(NumericType::Vector {
+            size: crate::VectorSize::Bi,
+            scalar: crate::Scalar::BOOL,
+        });
+        self.get_type_id(local_type.into())
+    }
+
     pub(super) fn get_bool3_type_id(&mut self) -> Word {
         let local_type = LocalType::Numeric(NumericType::Vector {
             size: crate::VectorSize::Tri,
@@ -839,6 +871,7 @@
 
             // Steal the Writer's temp list for a bit.
             temp_list: std::mem::take(&mut self.temp_list),
+            force_loop_bounding: self.force_loop_bounding,
             writer: self,
             expression_constness: super::ExpressionConstnessTracker::from_arena(
                 &ir_function.expressions,
```

naga/tests/out/spv/6220-break-from-loop.spvasm (+37 -13)

```diff
@@ -1,7 +1,7 @@
 ; SPIR-V
 ; Version: 1.1
 ; Generator: rspirv
-; Bound: 26
+; Bound: 46
 OpCapability Shader
 OpCapability Linkage
 %1 = OpExtInstImport "GLSL.std.450"
@@ -13,31 +13,55 @@ OpMemoryModel Logical GLSL450
 %8 = OpConstant %3 4
 %9 = OpConstant %3 1
 %11 = OpTypePointer Function %3
-%18 = OpTypeBool
+%17 = OpTypeInt 32 0
+%18 = OpTypeVector %17 2
+%19 = OpTypePointer Function %18
+%20 = OpTypeBool
+%21 = OpTypeVector %20 2
+%22 = OpConstant %17 0
+%23 = OpConstantComposite %18 %22 %22
+%24 = OpConstant %17 1
+%25 = OpConstant %17 4294967295
+%26 = OpConstantComposite %18 %25 %25
 %5 = OpFunction %2 None %6
 %4 = OpLabel
 %10 = OpVariable %11 Function %7
+%27 = OpVariable %19 Function %23
 OpBranch %12
 %12 = OpLabel
 OpBranch %13
 %13 = OpLabel
 OpLoopMerge %14 %16 None
+OpBranch %28
+%28 = OpLabel
+%29 = OpLoad %18 %27
+%30 = OpIEqual %21 %26 %29
+%31 = OpAll %20 %30
+OpSelectionMerge %32 None
+OpBranchConditional %31 %14 %32
+%32 = OpLabel
+%33 = OpCompositeExtract %17 %29 1
+%34 = OpIEqual %20 %33 %25
+%35 = OpSelect %17 %34 %24 %22
+%36 = OpCompositeConstruct %18 %35 %24
+%37 = OpIAdd %18 %29 %36
+OpStore %27 %37
 OpBranch %15
 %15 = OpLabel
-%17 = OpLoad %3 %10
-%19 = OpSLessThan %18 %17 %8
-OpSelectionMerge %20 None
-OpBranchConditional %19 %20 %21
-%21 = OpLabel
+%38 = OpLoad %3 %10
+%39 = OpSLessThan %20 %38 %8
+OpSelectionMerge %40 None
+OpBranchConditional %39 %40 %41
+%41 = OpLabel
 OpBranch %14
-%20 = OpLabel
-OpBranch %22
-%22 = OpLabel
+%40 = OpLabel
+OpBranch %42
+%42 = OpLabel
 OpBranch %14
 %16 = OpLabel
-%24 = OpLoad %3 %10
-%25 = OpIAdd %3 %24 %9
-OpStore %10 %25
+%44 = OpLoad %3 %10
+%45 = OpIAdd %3 %44 %9
+OpStore %10 %45
 OpBranch %13
 %14 = OpLabel
 OpReturn
```
