Skip to content

Commit

Permalink
[naga wgsl-in] Ensure constant evaluation correctly handles Composes …
Browse files Browse the repository at this point in the history
…of vector ZeroValues (#7138)
  • Loading branch information
jamienicol authored Feb 14, 2025
1 parent eea3dde commit 4bb09e1
Show file tree
Hide file tree
Showing 7 changed files with 174 additions and 110 deletions.
11 changes: 11 additions & 0 deletions naga/src/proc/constant_evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1411,6 +1411,17 @@ impl<'a> ConstantEvaluator<'a> {
mut expr: Handle<Expression>,
span: Span,
) -> Result<Handle<Expression>, ConstantEvaluatorError> {
// If expr is a Compose expression, eliminate ZeroValue and Splat expressions for
// each of its components.
if let Expression::Compose { ty, ref components } = self.expressions[expr] {
let components = components
.clone()
.iter()
.map(|component| self.eval_zero_value_and_splat(*component, span))
.collect::<Result<_, _>>()?;
expr = self.register_evaluated_expr(Expression::Compose { ty, components }, span)?;
}

// The result of the splat() for a Splat of a scalar ZeroValue is a
// vector ZeroValue, so we must call eval_zero_value_impl() after
// splat() in order to ensure we have no ZeroValues remaining.
Expand Down
7 changes: 7 additions & 0 deletions naga/tests/in/const-exprs.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,10 @@ fn compose_of_splat() {

const add_vec = vec2(1.0f) + vec2(3.0f, 4.0f);
const compare_vec = vec2(3.0f) == vec2(3.0f, 4.0f);

// Ensure binary ops correctly flatten compositions of vector zero values
fn compose_vector_zero_val_binop() {
var a = vec3(vec2i(), 0) + vec3(1);
var b = vec3(vec2i(), 0) + vec3(0, 1, 2);
var c = vec3(vec2i(), 2) + vec3(1, vec2i());
}
7 changes: 7 additions & 0 deletions naga/tests/out/glsl/const-exprs.main.Compute.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,13 @@ uint map_texture_kind(int texture_kind) {
}
}

void compose_vector_zero_val_binop() {
ivec3 a = ivec3(1, 1, 1);
ivec3 b = ivec3(0, 1, 2);
ivec3 c = ivec3(1, 0, 2);
return;
}

void main() {
swizzle_of_compose();
index_of_compose();
Expand Down
9 changes: 9 additions & 0 deletions naga/tests/out/hlsl/const-exprs.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,15 @@ uint map_texture_kind(int texture_kind)
}
}

void compose_vector_zero_val_binop()
{
int3 a = int3(1, 1, 1);
int3 b = int3(0, 1, 2);
int3 c = int3(1, 0, 2);

return;
}

[numthreads(2, 3, 1)]
void main()
{
Expand Down
8 changes: 8 additions & 0 deletions naga/tests/out/msl/const-exprs.msl
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,14 @@ uint map_texture_kind(
}
}

void compose_vector_zero_val_binop(
) {
metal::int3 a = metal::int3(1, 1, 1);
metal::int3 b = metal::int3(0, 1, 2);
metal::int3 c = metal::int3(1, 0, 2);
return;
}

kernel void main_(
) {
swizzle_of_compose();
Expand Down
234 changes: 124 additions & 110 deletions naga/tests/out/spv/const-exprs.spvasm
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
; SPIR-V
; Version: 1.1
; Generator: rspirv
; Bound: 109
; Bound: 120
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %100 "main"
OpExecutionMode %100 LocalSize 2 3 1
OpEntryPoint GLCompute %111 "main"
OpExecutionMode %111 LocalSize 2 3 1
%2 = OpTypeVoid
%3 = OpTypeInt 32 0
%4 = OpTypeInt 32 1
Expand All @@ -16,137 +16,151 @@ OpExecutionMode %100 LocalSize 2 3 1
%8 = OpTypeVector %6 2
%10 = OpTypeBool
%9 = OpTypeVector %10 2
%11 = OpConstant %3 2
%12 = OpConstant %4 3
%13 = OpConstant %4 4
%14 = OpConstant %4 8
%15 = OpConstant %6 3.141
%16 = OpConstant %6 6.282
%17 = OpConstant %6 0.44444445
%18 = OpConstant %6 0.0
%19 = OpConstantComposite %7 %17 %18 %18 %18
%20 = OpConstant %4 0
%21 = OpConstant %4 1
%22 = OpConstant %4 2
%23 = OpConstant %6 4.0
%24 = OpConstant %6 5.0
%25 = OpConstantComposite %8 %23 %24
%26 = OpConstantTrue %10
%27 = OpConstantFalse %10
%28 = OpConstantComposite %9 %26 %27
%31 = OpTypeFunction %2
%32 = OpConstantComposite %5 %13 %12 %22 %21
%34 = OpTypePointer Function %5
%39 = OpTypePointer Function %4
%43 = OpConstant %4 6
%48 = OpConstant %4 30
%49 = OpConstant %4 70
%52 = OpConstantNull %4
%54 = OpConstantNull %4
%57 = OpConstantNull %5
%68 = OpConstant %4 -4
%69 = OpConstantComposite %5 %68 %68 %68 %68
%78 = OpConstant %6 1.0
%79 = OpConstant %6 2.0
%80 = OpConstantComposite %7 %79 %78 %78 %78
%82 = OpTypePointer Function %7
%87 = OpTypeFunction %3 %4
%88 = OpConstant %3 10
%89 = OpConstant %3 20
%90 = OpConstant %3 30
%91 = OpConstant %3 0
%98 = OpConstantNull %3
%30 = OpFunction %2 None %31
%29 = OpLabel
%33 = OpVariable %34 Function %32
OpBranch %35
%35 = OpLabel
OpReturn
OpFunctionEnd
%37 = OpFunction %2 None %31
%11 = OpTypeVector %4 3
%12 = OpConstant %3 2
%13 = OpConstant %4 3
%14 = OpConstant %4 4
%15 = OpConstant %4 8
%16 = OpConstant %6 3.141
%17 = OpConstant %6 6.282
%18 = OpConstant %6 0.44444445
%19 = OpConstant %6 0.0
%20 = OpConstantComposite %7 %18 %19 %19 %19
%21 = OpConstant %4 0
%22 = OpConstant %4 1
%23 = OpConstant %4 2
%24 = OpConstant %6 4.0
%25 = OpConstant %6 5.0
%26 = OpConstantComposite %8 %24 %25
%27 = OpConstantTrue %10
%28 = OpConstantFalse %10
%29 = OpConstantComposite %9 %27 %28
%32 = OpTypeFunction %2
%33 = OpConstantComposite %5 %14 %13 %23 %22
%35 = OpTypePointer Function %5
%40 = OpTypePointer Function %4
%44 = OpConstant %4 6
%49 = OpConstant %4 30
%50 = OpConstant %4 70
%53 = OpConstantNull %4
%55 = OpConstantNull %4
%58 = OpConstantNull %5
%69 = OpConstant %4 -4
%70 = OpConstantComposite %5 %69 %69 %69 %69
%79 = OpConstant %6 1.0
%80 = OpConstant %6 2.0
%81 = OpConstantComposite %7 %80 %79 %79 %79
%83 = OpTypePointer Function %7
%88 = OpTypeFunction %3 %4
%89 = OpConstant %3 10
%90 = OpConstant %3 20
%91 = OpConstant %3 30
%92 = OpConstant %3 0
%99 = OpConstantNull %3
%102 = OpConstantComposite %11 %22 %22 %22
%103 = OpConstantComposite %11 %21 %22 %23
%104 = OpConstantComposite %11 %22 %21 %23
%106 = OpTypePointer Function %11
%31 = OpFunction %2 None %32
%30 = OpLabel
%34 = OpVariable %35 Function %33
OpBranch %36
%36 = OpLabel
%38 = OpVariable %39 Function %22
OpBranch %40
%40 = OpLabel
OpReturn
OpFunctionEnd
%42 = OpFunction %2 None %31
%38 = OpFunction %2 None %32
%37 = OpLabel
%39 = OpVariable %40 Function %23
OpBranch %41
%41 = OpLabel
%44 = OpVariable %39 Function %43
OpBranch %45
%45 = OpLabel
OpReturn
OpFunctionEnd
%47 = OpFunction %2 None %31
%43 = OpFunction %2 None %32
%42 = OpLabel
%45 = OpVariable %40 Function %44
OpBranch %46
%46 = OpLabel
%56 = OpVariable %34 Function %57
%51 = OpVariable %39 Function %52
%55 = OpVariable %39 Function %49
%50 = OpVariable %39 Function %48
%53 = OpVariable %39 Function %54
OpBranch %58
%58 = OpLabel
%59 = OpLoad %4 %50
OpStore %51 %59
%60 = OpLoad %4 %51
OpStore %53 %60
%61 = OpLoad %4 %50
%62 = OpLoad %4 %51
%63 = OpLoad %4 %53
%64 = OpLoad %4 %55
%65 = OpCompositeConstruct %5 %61 %62 %63 %64
OpStore %56 %65
OpReturn
OpFunctionEnd
%67 = OpFunction %2 None %31
%66 = OpLabel
%70 = OpVariable %34 Function %69
OpBranch %71
%71 = OpLabel
%48 = OpFunction %2 None %32
%47 = OpLabel
%57 = OpVariable %35 Function %58
%52 = OpVariable %40 Function %53
%56 = OpVariable %40 Function %50
%51 = OpVariable %40 Function %49
%54 = OpVariable %40 Function %55
OpBranch %59
%59 = OpLabel
%60 = OpLoad %4 %51
OpStore %52 %60
%61 = OpLoad %4 %52
OpStore %54 %61
%62 = OpLoad %4 %51
%63 = OpLoad %4 %52
%64 = OpLoad %4 %54
%65 = OpLoad %4 %56
%66 = OpCompositeConstruct %5 %62 %63 %64 %65
OpStore %57 %66
OpReturn
OpFunctionEnd
%73 = OpFunction %2 None %31
%68 = OpFunction %2 None %32
%67 = OpLabel
%71 = OpVariable %35 Function %70
OpBranch %72
%72 = OpLabel
%74 = OpVariable %34 Function %69
OpBranch %75
%75 = OpLabel
OpReturn
OpFunctionEnd
%77 = OpFunction %2 None %31
%74 = OpFunction %2 None %32
%73 = OpLabel
%75 = OpVariable %35 Function %70
OpBranch %76
%76 = OpLabel
%81 = OpVariable %82 Function %80
OpBranch %83
%83 = OpLabel
OpReturn
OpFunctionEnd
%86 = OpFunction %3 None %87
%85 = OpFunctionParameter %4
%78 = OpFunction %2 None %32
%77 = OpLabel
%82 = OpVariable %83 Function %81
OpBranch %84
%84 = OpLabel
OpBranch %92
%92 = OpLabel
OpSelectionMerge %93 None
OpSwitch %85 %97 0 %94 1 %95 2 %96
%94 = OpLabel
OpReturnValue %88
OpReturn
OpFunctionEnd
%87 = OpFunction %3 None %88
%86 = OpFunctionParameter %4
%85 = OpLabel
OpBranch %93
%93 = OpLabel
OpSelectionMerge %94 None
OpSwitch %86 %98 0 %95 1 %96 2 %97
%95 = OpLabel
OpReturnValue %89
%96 = OpLabel
OpReturnValue %90
%97 = OpLabel
OpReturnValue %91
%93 = OpLabel
OpReturnValue %98
%98 = OpLabel
OpReturnValue %92
%94 = OpLabel
OpReturnValue %99
OpFunctionEnd
%101 = OpFunction %2 None %32
%100 = OpLabel
%105 = OpVariable %106 Function %102
%107 = OpVariable %106 Function %103
%108 = OpVariable %106 Function %104
OpBranch %109
%109 = OpLabel
OpReturn
OpFunctionEnd
%100 = OpFunction %2 None %31
%99 = OpLabel
OpBranch %101
%101 = OpLabel
%102 = OpFunctionCall %2 %30
%103 = OpFunctionCall %2 %37
%104 = OpFunctionCall %2 %42
%105 = OpFunctionCall %2 %47
%106 = OpFunctionCall %2 %67
%107 = OpFunctionCall %2 %73
%108 = OpFunctionCall %2 %77
%111 = OpFunction %2 None %32
%110 = OpLabel
OpBranch %112
%112 = OpLabel
%113 = OpFunctionCall %2 %31
%114 = OpFunctionCall %2 %38
%115 = OpFunctionCall %2 %43
%116 = OpFunctionCall %2 %48
%117 = OpFunctionCall %2 %68
%118 = OpFunctionCall %2 %74
%119 = OpFunctionCall %2 %78
OpReturn
OpFunctionEnd
8 changes: 8 additions & 0 deletions naga/tests/out/wgsl/const-exprs.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,14 @@ fn map_texture_kind(texture_kind: i32) -> u32 {
}
}

fn compose_vector_zero_val_binop() {
var a: vec3<i32> = vec3<i32>(1i, 1i, 1i);
var b: vec3<i32> = vec3<i32>(0i, 1i, 2i);
var c: vec3<i32> = vec3<i32>(1i, 0i, 2i);

return;
}

@compute @workgroup_size(2, 3, 1)
fn main() {
swizzle_of_compose();
Expand Down

0 comments on commit 4bb09e1

Please sign in to comment.