Skip to content

Commit a2efe8b

Browse files
committed
[naga spv-out msl-out hlsl-out] Make infinite loop workaround count down instead of up
To avoid generating code containing infinite loops, and therefore incurring the wrath of undefined behaviour, we insert a counter into each loop that will break after 2^64 iterations. This was previously implemented as two u32 variables counting up from zero. We have been informed that this construct can cause certain Intel drivers to hang. Instead, we must count down from u32::MAX. Counting down is more fun, anyway.
1 parent 8474132 commit a2efe8b

32 files changed

+247
-242
lines changed

naga/src/back/hlsl/writer.rs

+6-4
Original file line numberDiff line numberDiff line change
@@ -171,12 +171,14 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
171171
}
172172

173173
let loop_bound_name = self.namer.call("loop_bound");
174-
let decl = format!("{level}uint2 {loop_bound_name} = uint2(0u, 0u);");
175-
let level = level.next();
176174
let max = u32::MAX;
175+
// Count down from u32::MAX rather than up from 0 to avoid hang on
176+
// certain intel drivers. See https://github.com/gfx-rs/wgpu/issues/7319.
177+
let decl = format!("{level}uint2 {loop_bound_name} = uint2({max}u, {max}u);");
178+
let level = level.next();
177179
let break_and_inc = format!(
178-
"{level}if (all({loop_bound_name} == uint2({max}u, {max}u))) {{ break; }}
179-
{level}{loop_bound_name} += uint2({loop_bound_name}.y == {max}u, 1u);"
180+
"{level}if (all({loop_bound_name} == uint2(0u, 0u))) {{ break; }}
181+
{level}{loop_bound_name} -= uint2({loop_bound_name}.y == 0u, 1u);"
180182
);
181183

182184
Some((decl, break_and_inc))

naga/src/back/msl/writer.rs

+5-4
Original file line numberDiff line numberDiff line change
@@ -841,12 +841,13 @@ impl<W: Write> Writer<W> {
841841
}
842842

843843
let loop_bound_name = self.namer.call("loop_bound");
844-
let decl = format!("{level}uint2 {loop_bound_name} = uint2(0u);");
844+
// Count down from u32::MAX rather than up from 0 to avoid hang on
845+
// certain intel drivers. See https://github.com/gfx-rs/wgpu/issues/7319.
846+
let decl = format!("{level}uint2 {loop_bound_name} = uint2({}u);", u32::MAX);
845847
let level = level.next();
846-
let max = u32::MAX;
847848
let break_and_inc = format!(
848-
"{level}if ({NAMESPACE}::all({loop_bound_name} == uint2({max}u))) {{ break; }}
849-
{level}{loop_bound_name} += uint2({loop_bound_name}.y == {max}u, 1u);"
849+
"{level}if ({NAMESPACE}::all({loop_bound_name} == uint2(0u))) {{ break; }}
850+
{level}{loop_bound_name} -= uint2({loop_bound_name}.y == 0u, 1u);"
850851
);
851852

852853
Some((decl, break_and_inc))

naga/src/back/spv/block.rs

+13-11
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ impl BlockContext<'_> {
309309
uint2_ptr_type_id,
310310
loop_counter_var_id,
311311
spirv::StorageClass::Function,
312-
Some(zero_uint2_const_id),
312+
Some(max_uint2_const_id),
313313
),
314314
};
315315
self.function.force_loop_bounding_vars.push(var);
@@ -330,14 +330,14 @@ impl BlockContext<'_> {
330330
None,
331331
));
332332

333-
// If both the high and low u32s have reached u32::MAX then break. ie
334-
// if (all(eq(loop_counter, vec2(u32::MAX)))) { break; }
333+
// If both the high and low u32s have reached 0 then break. ie
334+
// if (all(eq(loop_counter, vec2(0)))) { break; }
335335
let eq_id = self.gen_id();
336336
block.body.push(Instruction::binary(
337337
spirv::Op::IEqual,
338338
bool2_type_id,
339339
eq_id,
340-
max_uint2_const_id,
340+
zero_uint2_const_id,
341341
load_id,
342342
));
343343
let all_eq_id = self.gen_id();
@@ -359,9 +359,11 @@ impl BlockContext<'_> {
359359
);
360360
block = Block::new(inc_counter_block_id);
361361

362-
// To simulate a 64-bit counter we always increment the low u32, and increment
362+
// To simulate a 64-bit counter we always decrement the low u32, and decrement
363363
// the high u32 when the low u32 overflows. ie
364-
// counter += vec2(select(0u, 1u, counter.y == u32::MAX), 1u);
364+
// counter -= vec2(select(0u, 1u, counter.y == 0), 1u);
365+
// Count down from u32::MAX rather than up from 0 to avoid hang on
366+
// certain intel drivers. See https://github.com/gfx-rs/wgpu/issues/7319.
365367
let low_id = self.gen_id();
366368
block.body.push(Instruction::composite_extract(
367369
uint_type_id,
@@ -375,7 +377,7 @@ impl BlockContext<'_> {
375377
bool_type_id,
376378
low_overflow_id,
377379
low_id,
378-
max_uint_const_id,
380+
zero_uint_const_id,
379381
));
380382
let carry_bit_id = self.gen_id();
381383
block.body.push(Instruction::select(
@@ -385,19 +387,19 @@ impl BlockContext<'_> {
385387
one_uint_const_id,
386388
zero_uint_const_id,
387389
));
388-
let increment_id = self.gen_id();
390+
let decrement_id = self.gen_id();
389391
block.body.push(Instruction::composite_construct(
390392
uint2_type_id,
391-
increment_id,
393+
decrement_id,
392394
&[carry_bit_id, one_uint_const_id],
393395
));
394396
let result_id = self.gen_id();
395397
block.body.push(Instruction::binary(
396-
spirv::Op::IAdd,
398+
spirv::Op::ISub,
397399
uint2_type_id,
398400
result_id,
399401
load_id,
400-
increment_id,
402+
decrement_id,
401403
));
402404
block
403405
.body

naga/tests/out/hlsl/boids.hlsl

+3-3
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,11 @@ void main(uint3 global_invocation_id : SV_DispatchThreadID)
4141
vPos = _e8;
4242
float2 _e14 = asfloat(particlesSrc.Load2(8+index*16+0));
4343
vVel = _e14;
44-
uint2 loop_bound = uint2(0u, 0u);
44+
uint2 loop_bound = uint2(4294967295u, 4294967295u);
4545
bool loop_init = true;
4646
while(true) {
47-
if (all(loop_bound == uint2(4294967295u, 4294967295u))) { break; }
48-
loop_bound += uint2(loop_bound.y == 4294967295u, 1u);
47+
if (all(loop_bound == uint2(0u, 0u))) { break; }
48+
loop_bound -= uint2(loop_bound.y == 0u, 1u);
4949
if (!loop_init) {
5050
uint _e91 = i;
5151
i = (_e91 + 1u);

naga/tests/out/hlsl/break-if.hlsl

+12-12
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
void breakIfEmpty()
22
{
3-
uint2 loop_bound = uint2(0u, 0u);
3+
uint2 loop_bound = uint2(4294967295u, 4294967295u);
44
bool loop_init = true;
55
while(true) {
6-
if (all(loop_bound == uint2(4294967295u, 4294967295u))) { break; }
7-
loop_bound += uint2(loop_bound.y == 4294967295u, 1u);
6+
if (all(loop_bound == uint2(0u, 0u))) { break; }
7+
loop_bound -= uint2(loop_bound.y == 0u, 1u);
88
if (!loop_init) {
99
if (true) {
1010
break;
@@ -20,11 +20,11 @@ void breakIfEmptyBody(bool a)
2020
bool b = (bool)0;
2121
bool c = (bool)0;
2222

23-
uint2 loop_bound_1 = uint2(0u, 0u);
23+
uint2 loop_bound_1 = uint2(4294967295u, 4294967295u);
2424
bool loop_init_1 = true;
2525
while(true) {
26-
if (all(loop_bound_1 == uint2(4294967295u, 4294967295u))) { break; }
27-
loop_bound_1 += uint2(loop_bound_1.y == 4294967295u, 1u);
26+
if (all(loop_bound_1 == uint2(0u, 0u))) { break; }
27+
loop_bound_1 -= uint2(loop_bound_1.y == 0u, 1u);
2828
if (!loop_init_1) {
2929
b = a;
3030
bool _e2 = b;
@@ -44,11 +44,11 @@ void breakIf(bool a_1)
4444
bool d = (bool)0;
4545
bool e = (bool)0;
4646

47-
uint2 loop_bound_2 = uint2(0u, 0u);
47+
uint2 loop_bound_2 = uint2(4294967295u, 4294967295u);
4848
bool loop_init_2 = true;
4949
while(true) {
50-
if (all(loop_bound_2 == uint2(4294967295u, 4294967295u))) { break; }
51-
loop_bound_2 += uint2(loop_bound_2.y == 4294967295u, 1u);
50+
if (all(loop_bound_2 == uint2(0u, 0u))) { break; }
51+
loop_bound_2 -= uint2(loop_bound_2.y == 0u, 1u);
5252
if (!loop_init_2) {
5353
bool _e5 = e;
5454
if ((a_1 == _e5)) {
@@ -67,11 +67,11 @@ void breakIfSeparateVariable()
6767
{
6868
uint counter = 0u;
6969

70-
uint2 loop_bound_3 = uint2(0u, 0u);
70+
uint2 loop_bound_3 = uint2(4294967295u, 4294967295u);
7171
bool loop_init_3 = true;
7272
while(true) {
73-
if (all(loop_bound_3 == uint2(4294967295u, 4294967295u))) { break; }
74-
loop_bound_3 += uint2(loop_bound_3.y == 4294967295u, 1u);
73+
if (all(loop_bound_3 == uint2(0u, 0u))) { break; }
74+
loop_bound_3 -= uint2(loop_bound_3.y == 0u, 1u);
7575
if (!loop_init_3) {
7676
uint _e5 = counter;
7777
if ((_e5 == 5u)) {

naga/tests/out/hlsl/collatz.hlsl

+3-3
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@ uint collatz_iterations(uint n_base)
1414
uint i = 0u;
1515

1616
n = n_base;
17-
uint2 loop_bound = uint2(0u, 0u);
17+
uint2 loop_bound = uint2(4294967295u, 4294967295u);
1818
while(true) {
19-
if (all(loop_bound == uint2(4294967295u, 4294967295u))) { break; }
20-
loop_bound += uint2(loop_bound.y == 4294967295u, 1u);
19+
if (all(loop_bound == uint2(0u, 0u))) { break; }
20+
loop_bound -= uint2(loop_bound.y == 0u, 1u);
2121
uint _e4 = n;
2222
if ((_e4 > 1u)) {
2323
} else {

naga/tests/out/hlsl/control-flow.hlsl

+18-18
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,10 @@ void switch_const_expr_case_selectors()
6464

6565
void loop_switch_continue(int x)
6666
{
67-
uint2 loop_bound = uint2(0u, 0u);
67+
uint2 loop_bound = uint2(4294967295u, 4294967295u);
6868
while(true) {
69-
if (all(loop_bound == uint2(4294967295u, 4294967295u))) { break; }
70-
loop_bound += uint2(loop_bound.y == 4294967295u, 1u);
69+
if (all(loop_bound == uint2(0u, 0u))) { break; }
70+
loop_bound -= uint2(loop_bound.y == 0u, 1u);
7171
bool should_continue = false;
7272
switch(x) {
7373
case 1: {
@@ -87,10 +87,10 @@ void loop_switch_continue(int x)
8787

8888
void loop_switch_continue_nesting(int x_1, int y, int z)
8989
{
90-
uint2 loop_bound_1 = uint2(0u, 0u);
90+
uint2 loop_bound_1 = uint2(4294967295u, 4294967295u);
9191
while(true) {
92-
if (all(loop_bound_1 == uint2(4294967295u, 4294967295u))) { break; }
93-
loop_bound_1 += uint2(loop_bound_1.y == 4294967295u, 1u);
92+
if (all(loop_bound_1 == uint2(0u, 0u))) { break; }
93+
loop_bound_1 -= uint2(loop_bound_1.y == 0u, 1u);
9494
bool should_continue_1 = false;
9595
switch(x_1) {
9696
case 1: {
@@ -104,10 +104,10 @@ void loop_switch_continue_nesting(int x_1, int y, int z)
104104
break;
105105
}
106106
default: {
107-
uint2 loop_bound_2 = uint2(0u, 0u);
107+
uint2 loop_bound_2 = uint2(4294967295u, 4294967295u);
108108
while(true) {
109-
if (all(loop_bound_2 == uint2(4294967295u, 4294967295u))) { break; }
110-
loop_bound_2 += uint2(loop_bound_2.y == 4294967295u, 1u);
109+
if (all(loop_bound_2 == uint2(0u, 0u))) { break; }
110+
loop_bound_2 -= uint2(loop_bound_2.y == 0u, 1u);
111111
bool should_continue_2 = false;
112112
switch(z) {
113113
case 1: {
@@ -146,10 +146,10 @@ void loop_switch_continue_nesting(int x_1, int y, int z)
146146
continue;
147147
}
148148
}
149-
uint2 loop_bound_3 = uint2(0u, 0u);
149+
uint2 loop_bound_3 = uint2(4294967295u, 4294967295u);
150150
while(true) {
151-
if (all(loop_bound_3 == uint2(4294967295u, 4294967295u))) { break; }
152-
loop_bound_3 += uint2(loop_bound_3.y == 4294967295u, 1u);
151+
if (all(loop_bound_3 == uint2(0u, 0u))) { break; }
152+
loop_bound_3 -= uint2(loop_bound_3.y == 0u, 1u);
153153
bool should_continue_4 = false;
154154
do {
155155
do {
@@ -171,10 +171,10 @@ void loop_switch_omit_continue_variable_checks(int x_2, int y_1, int z_1, int w)
171171
{
172172
int pos_1 = int(0);
173173

174-
uint2 loop_bound_4 = uint2(0u, 0u);
174+
uint2 loop_bound_4 = uint2(4294967295u, 4294967295u);
175175
while(true) {
176-
if (all(loop_bound_4 == uint2(4294967295u, 4294967295u))) { break; }
177-
loop_bound_4 += uint2(loop_bound_4.y == 4294967295u, 1u);
176+
if (all(loop_bound_4 == uint2(0u, 0u))) { break; }
177+
loop_bound_4 -= uint2(loop_bound_4.y == 0u, 1u);
178178
bool should_continue_5 = false;
179179
switch(x_2) {
180180
case 1: {
@@ -186,10 +186,10 @@ void loop_switch_omit_continue_variable_checks(int x_2, int y_1, int z_1, int w)
186186
}
187187
}
188188
}
189-
uint2 loop_bound_5 = uint2(0u, 0u);
189+
uint2 loop_bound_5 = uint2(4294967295u, 4294967295u);
190190
while(true) {
191-
if (all(loop_bound_5 == uint2(4294967295u, 4294967295u))) { break; }
192-
loop_bound_5 += uint2(loop_bound_5.y == 4294967295u, 1u);
191+
if (all(loop_bound_5 == uint2(0u, 0u))) { break; }
192+
loop_bound_5 -= uint2(loop_bound_5.y == 0u, 1u);
193193
bool should_continue_6 = false;
194194
switch(x_2) {
195195
case 1: {

naga/tests/out/hlsl/do-while.hlsl

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
void fb1_(inout bool cond)
22
{
3-
uint2 loop_bound = uint2(0u, 0u);
3+
uint2 loop_bound = uint2(4294967295u, 4294967295u);
44
bool loop_init = true;
55
while(true) {
6-
if (all(loop_bound == uint2(4294967295u, 4294967295u))) { break; }
7-
loop_bound += uint2(loop_bound.y == 4294967295u, 1u);
6+
if (all(loop_bound == uint2(0u, 0u))) { break; }
7+
loop_bound -= uint2(loop_bound.y == 0u, 1u);
88
if (!loop_init) {
99
bool _e1 = cond;
1010
if (!(_e1)) {

naga/tests/out/hlsl/ray-query.hlsl

+3-3
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,10 @@ RayIntersection query_loop(float3 pos, float3 dir, RaytracingAccelerationStructu
8484
RayQuery<RAY_FLAG_NONE> rq_1;
8585

8686
rq_1.TraceRayInline(acs, ConstructRayDesc_(4u, 255u, 0.1, 100.0, pos, dir).flags, ConstructRayDesc_(4u, 255u, 0.1, 100.0, pos, dir).cull_mask, RayDescFromRayDesc_(ConstructRayDesc_(4u, 255u, 0.1, 100.0, pos, dir)));
87-
uint2 loop_bound = uint2(0u, 0u);
87+
uint2 loop_bound = uint2(4294967295u, 4294967295u);
8888
while(true) {
89-
if (all(loop_bound == uint2(4294967295u, 4294967295u))) { break; }
90-
loop_bound += uint2(loop_bound.y == 4294967295u, 1u);
89+
if (all(loop_bound == uint2(0u, 0u))) { break; }
90+
loop_bound -= uint2(loop_bound.y == 0u, 1u);
9191
const bool _e9 = rq_1.Proceed();
9292
if (_e9) {
9393
} else {

naga/tests/out/hlsl/shadow.hlsl

+6-6
Original file line numberDiff line numberDiff line change
@@ -95,11 +95,11 @@ float4 fs_main(FragmentInput_fs_main fragmentinput_fs_main) : SV_Target0
9595
uint i = 0u;
9696

9797
float3 normal_1 = normalize(in_.world_normal);
98-
uint2 loop_bound = uint2(0u, 0u);
98+
uint2 loop_bound = uint2(4294967295u, 4294967295u);
9999
bool loop_init = true;
100100
while(true) {
101-
if (all(loop_bound == uint2(4294967295u, 4294967295u))) { break; }
102-
loop_bound += uint2(loop_bound.y == 4294967295u, 1u);
101+
if (all(loop_bound == uint2(0u, 0u))) { break; }
102+
loop_bound -= uint2(loop_bound.y == 0u, 1u);
103103
if (!loop_init) {
104104
uint _e40 = i;
105105
i = (_e40 + 1u);
@@ -134,11 +134,11 @@ float4 fs_main_without_storage(FragmentInput_fs_main_without_storage fragmentinp
134134
uint i_1 = 0u;
135135

136136
float3 normal_2 = normalize(in_1.world_normal);
137-
uint2 loop_bound_1 = uint2(0u, 0u);
137+
uint2 loop_bound_1 = uint2(4294967295u, 4294967295u);
138138
bool loop_init_1 = true;
139139
while(true) {
140-
if (all(loop_bound_1 == uint2(4294967295u, 4294967295u))) { break; }
141-
loop_bound_1 += uint2(loop_bound_1.y == 4294967295u, 1u);
140+
if (all(loop_bound_1 == uint2(0u, 0u))) { break; }
141+
loop_bound_1 -= uint2(loop_bound_1.y == 0u, 1u);
142142
if (!loop_init_1) {
143143
uint _e40 = i_1;
144144
i_1 = (_e40 + 1u);

naga/tests/out/msl/atomicCompareExchange.msl

+12-12
Original file line numberDiff line numberDiff line change
@@ -76,11 +76,11 @@ kernel void test_atomic_compare_exchange_i32_(
7676
uint i = 0u;
7777
int old = {};
7878
bool exchanged = {};
79-
uint2 loop_bound = uint2(0u);
79+
uint2 loop_bound = uint2(4294967295u);
8080
bool loop_init = true;
8181
while(true) {
82-
if (metal::all(loop_bound == uint2(4294967295u))) { break; }
83-
loop_bound += uint2(loop_bound.y == 4294967295u, 1u);
82+
if (metal::all(loop_bound == uint2(0u))) { break; }
83+
loop_bound -= uint2(loop_bound.y == 0u, 1u);
8484
if (!loop_init) {
8585
uint _e27 = i;
8686
i = _e27 + 1u;
@@ -96,10 +96,10 @@ kernel void test_atomic_compare_exchange_i32_(
9696
int _e8 = metal::atomic_load_explicit(&arr_i32_.inner[_e6], metal::memory_order_relaxed);
9797
old = _e8;
9898
exchanged = false;
99-
uint2 loop_bound_1 = uint2(0u);
99+
uint2 loop_bound_1 = uint2(4294967295u);
100100
while(true) {
101-
if (metal::all(loop_bound_1 == uint2(4294967295u))) { break; }
102-
loop_bound_1 += uint2(loop_bound_1.y == 4294967295u, 1u);
101+
if (metal::all(loop_bound_1 == uint2(0u))) { break; }
102+
loop_bound_1 -= uint2(loop_bound_1.y == 0u, 1u);
103103
bool _e12 = exchanged;
104104
if (!(_e12)) {
105105
} else {
@@ -127,11 +127,11 @@ kernel void test_atomic_compare_exchange_u32_(
127127
uint i_1 = 0u;
128128
uint old_1 = {};
129129
bool exchanged_1 = {};
130-
uint2 loop_bound_2 = uint2(0u);
130+
uint2 loop_bound_2 = uint2(4294967295u);
131131
bool loop_init_1 = true;
132132
while(true) {
133-
if (metal::all(loop_bound_2 == uint2(4294967295u))) { break; }
134-
loop_bound_2 += uint2(loop_bound_2.y == 4294967295u, 1u);
133+
if (metal::all(loop_bound_2 == uint2(0u))) { break; }
134+
loop_bound_2 -= uint2(loop_bound_2.y == 0u, 1u);
135135
if (!loop_init_1) {
136136
uint _e27 = i_1;
137137
i_1 = _e27 + 1u;
@@ -147,10 +147,10 @@ kernel void test_atomic_compare_exchange_u32_(
147147
uint _e8 = metal::atomic_load_explicit(&arr_u32_.inner[_e6], metal::memory_order_relaxed);
148148
old_1 = _e8;
149149
exchanged_1 = false;
150-
uint2 loop_bound_3 = uint2(0u);
150+
uint2 loop_bound_3 = uint2(4294967295u);
151151
while(true) {
152-
if (metal::all(loop_bound_3 == uint2(4294967295u))) { break; }
153-
loop_bound_3 += uint2(loop_bound_3.y == 4294967295u, 1u);
152+
if (metal::all(loop_bound_3 == uint2(0u))) { break; }
153+
loop_bound_3 -= uint2(loop_bound_3.y == 0u, 1u);
154154
bool _e12 = exchanged_1;
155155
if (!(_e12)) {
156156
} else {

0 commit comments

Comments
 (0)