Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[naga wgsl-in] Allow global const declarations to have abstract types #7055

Merged
merged 5 commits into from
Feb 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ By @brodycj in [#6924](https://github.com/gfx-rs/wgpu/pull/6924).
- Error if structs have two fields with the same name. By @SparkyPotato in [#7088](https://github.com/gfx-rs/wgpu/pull/7088).
- Forward '--keep-coordinate-space' flag to GLSL backend in naga-cli. By @cloone8 in [#7206](https://github.com/gfx-rs/wgpu/pull/7206).
- Allow template lists to have a trailing comma. By @KentSlaney in [#7142](https://github.com/gfx-rs/wgpu/pull/7142).
- Allow WGSL const declarations to have abstract types. By @jamienicol in [#7055](https://github.com/gfx-rs/wgpu/pull/7055).

#### General

Expand Down
15 changes: 9 additions & 6 deletions naga/src/compact/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,15 +110,18 @@ pub fn compact(module: &mut crate::Module) {
log::trace!("tracing special types");
module_tracer.trace_special_types(&module.special_types);

// We treat all named constants as used by definition.
// We treat all named constants as used by definition, unless they have an
// abstract type as we do not want those reaching the validator.
log::trace!("tracing named constants");
for (handle, constant) in module.constants.iter() {
if constant.name.is_some() {
log::trace!("tracing constant {:?}", constant.name.as_ref().unwrap());
module_tracer.constants_used.insert(handle);
module_tracer.types_used.insert(constant.ty);
module_tracer.global_expressions_used.insert(constant.init);
if constant.name.is_none() || module.types[constant.ty].inner.is_abstract(&module.types) {
continue;
}

log::trace!("tracing constant {:?}", constant.name.as_ref().unwrap());
module_tracer.constants_used.insert(handle);
module_tracer.types_used.insert(constant.ty);
module_tracer.global_expressions_used.insert(constant.init);
}

// We treat all named overrides as used by definition.
Expand Down
9 changes: 1 addition & 8 deletions naga/src/front/wgsl/lower/conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,7 @@ impl<'source> super::ExpressionContext<'source, '_, '_> {
// rather than them being misreported as type conversion errors.
// If the type is an array (of an array, etc) then we must check whether the
// type of the innermost array's base type is abstract.
let mut base_inner = expr_inner;
while let crate::TypeInner::Array { base, .. } = *base_inner {
base_inner = &types[base].inner;
}
if !base_inner
.scalar()
.is_some_and(|scalar| scalar.is_abstract())
{
if !expr_inner.is_abstract(types) {
return Ok(expr);
}

Expand Down
55 changes: 45 additions & 10 deletions naga/src/front/wgsl/lower/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1029,6 +1029,15 @@ impl SubgroupGather {
}
}

/// Whether a declaration accepts abstract types, or concretizes.
enum AbstractRule {
/// This declaration concretizes its initialization expression.
Concretize,

/// This declaration can accept initializers with abstract types.
Allow,
}

pub struct Lowerer<'source, 'temp> {
index: &'temp Index<'source>,
}
Expand Down Expand Up @@ -1072,8 +1081,13 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
v.ty.map(|ast| self.resolve_ast_type(ast, &mut ctx.as_const()))
.transpose()?;

let (ty, initializer) =
self.type_and_init(v.name, v.init, explicit_ty, &mut ctx.as_override())?;
let (ty, initializer) = self.type_and_init(
v.name,
v.init,
explicit_ty,
AbstractRule::Concretize,
&mut ctx.as_override(),
)?;

let binding = if let Some(ref binding) = v.binding {
Some(crate::ResourceBinding {
Expand Down Expand Up @@ -1105,8 +1119,13 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
c.ty.map(|ast| self.resolve_ast_type(ast, &mut ectx))
.transpose()?;

let (ty, init) =
self.type_and_init(c.name, Some(c.init), explicit_ty, &mut ectx)?;
let (ty, init) = self.type_and_init(
c.name,
Some(c.init),
explicit_ty,
AbstractRule::Allow,
&mut ectx,
)?;
let init = init.expect("Global const must have init");

let handle = ctx.module.constants.append(
Expand All @@ -1128,7 +1147,13 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {

let mut ectx = ctx.as_override();

let (ty, init) = self.type_and_init(o.name, o.init, explicit_ty, &mut ectx)?;
let (ty, init) = self.type_and_init(
o.name,
o.init,
explicit_ty,
AbstractRule::Concretize,
&mut ectx,
)?;

let id =
o.id.map(|id| self.const_u32(id, &mut ctx.as_const()))
Expand Down Expand Up @@ -1201,6 +1226,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
name: ast::Ident<'source>,
init: Option<Handle<ast::Expression<'source>>>,
explicit_ty: Option<Handle<crate::Type>>,
abstract_rule: AbstractRule,
ectx: &mut ExpressionContext<'source, '_, '_>,
) -> Result<(Handle<crate::Type>, Option<Handle<crate::Expression>>), Error<'source>> {
let ty;
Expand Down Expand Up @@ -1234,9 +1260,12 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
initializer = Some(init);
}
(Some(init), None) => {
let concretized = self.expression(init, ectx)?;
ty = ectx.register_type(concretized)?;
initializer = Some(concretized);
let mut init = self.expression_for_abstract(init, ectx)?;
if let AbstractRule::Concretize = abstract_rule {
init = ectx.concretize(init)?;
}
ty = ectx.register_type(init)?;
initializer = Some(init);
}
(None, Some(explicit_ty)) => {
ty = explicit_ty;
Expand Down Expand Up @@ -1480,8 +1509,13 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
.transpose()?;

let mut ectx = ctx.as_expression(block, &mut emitter);
let (ty, initializer) =
self.type_and_init(v.name, v.init, explicit_ty, &mut ectx)?;
let (ty, initializer) = self.type_and_init(
v.name,
v.init,
explicit_ty,
AbstractRule::Concretize,
&mut ectx,
)?;

let (const_initializer, initializer) = {
match initializer {
Expand Down Expand Up @@ -1544,6 +1578,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
c.name,
Some(c.init),
explicit_ty,
AbstractRule::Concretize,
&mut ectx.as_const(),
)?;
let init = init.expect("Local const must have init");
Expand Down
2 changes: 2 additions & 0 deletions naga/src/proc/constant_evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1717,6 +1717,8 @@ impl<'a> ConstantEvaluator<'a> {
target: crate::Scalar,
span: Span,
) -> Result<Handle<Expression>, ConstantEvaluatorError> {
let expr = self.check_and_get(expr)?;

let Expression::Compose { ty, ref components } = self.expressions[expr] else {
return self.cast(expr, target, span);
};
Expand Down
22 changes: 22 additions & 0 deletions naga/src/proc/type_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -281,4 +281,26 @@ impl crate::TypeInner {
| crate::TypeInner::BindingArray { .. } => None,
}
}

/// Return true if `self` is an abstract type.
///
/// Use `types` to look up type handles. This is necessary to
/// recognize abstract arrays.
pub fn is_abstract(&self, types: &crate::UniqueArena<crate::Type>) -> bool {
match *self {
crate::TypeInner::Scalar(scalar)
| crate::TypeInner::Vector { scalar, .. }
| crate::TypeInner::Matrix { scalar, .. }
| crate::TypeInner::Atomic(scalar) => scalar.is_abstract(),
crate::TypeInner::Array { base, .. } => types[base].inner.is_abstract(types),
crate::TypeInner::ValuePointer { .. }
| crate::TypeInner::Pointer { .. }
| crate::TypeInner::Struct { .. }
| crate::TypeInner::Image { .. }
| crate::TypeInner::Sampler { .. }
| crate::TypeInner::AccelerationStructure
| crate::TypeInner::RayQuery
| crate::TypeInner::BindingArray { .. } => false,
}
}
}
28 changes: 28 additions & 0 deletions naga/tests/in/wgsl/abstract-types-function-calls.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,53 @@ fn func_au(a: array<u32, 2>) {}

fn func_f_i(a: f32, b: i32) {}

const const_af = 0;
const const_ai = 0;
const const_vec_af = vec2(0.0);
const const_vec_ai = vec2(0);
const const_mat_af = mat2x2(vec2(0.0), vec2(0.0));
const const_arr_af = array(0.0, 0.0);
const const_arr_ai = array(0, 0);

fn main() {
func_f(0.0);
func_f(0);
func_i(0);
func_u(0);

func_f(const_af);
func_f(const_ai);
func_i(const_ai);
func_u(const_ai);

func_vf(vec2(0.0));
func_vf(vec2(0));
func_vi(vec2(0));
func_vu(vec2(0));

func_vf(const_vec_af);
func_vf(const_vec_ai);
func_vi(const_vec_ai);
func_vu(const_vec_ai);

func_mf(mat2x2(vec2(0.0), vec2(0.0)));
func_mf(mat2x2(vec2(0), vec2(0)));

func_mf(const_mat_af);

func_af(array(0.0, 0.0));
func_af(array(0, 0));
func_ai(array(0, 0));
func_au(array(0, 0));

func_af(const_arr_af);
func_af(const_arr_ai);
func_ai(const_arr_ai);
func_au(const_arr_ai);

func_f_i(0.0, 0);
func_f_i(0, 0);

func_f_i(const_af, const_ai);
func_f_i(const_ai, const_ai);
}
5 changes: 5 additions & 0 deletions naga/tests/in/wgsl/abstract-types-return.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,8 @@ fn return_vec2f32_ai() -> vec2<f32> {
fn return_arrf32_ai() -> array<f32, 4> {
return array(1, 1, 1, 1);
}

const one = 1;
fn return_const_f32_const_ai() -> f32 {
return one;
}
7 changes: 6 additions & 1 deletion naga/tests/in/wgsl/const_assert.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,13 @@ const y = 2;
const_assert x < y; // valid at module-scope.
const_assert(y != 0); // parentheses are optional.

// Ensure abstract-typed consts can be compared to different concrete types
const_assert x == 1i;
const_assert x > 0u;
const_assert x < 2.0f;

fn foo() {
const z = x + y - 2;
const_assert z > 0; // valid in functions.
const_assert(z > 0);
}
}
4 changes: 4 additions & 0 deletions naga/tests/out/glsl/abstract-types-return.main.Compute.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ float[4] return_arrf32_ai() {
return float[4](1.0, 1.0, 1.0, 1.0);
}

float return_const_f32_const_ai() {
return 1.0;
}

void main() {
return;
}
Expand Down
2 changes: 0 additions & 2 deletions naga/tests/out/glsl/constructors.main.Compute.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ struct Foo {
vec4 a;
int b;
};
const vec3 const2_ = vec3(0.0, 1.0, 2.0);
const mat2x2 const3_ = mat2x2(vec2(0.0, 1.0), vec2(2.0, 3.0));
const mat2x2 const4_[1] = mat2x2[1](mat2x2(vec2(0.0, 1.0), vec2(2.0, 3.0)));
const bool cz0_ = false;
Expand All @@ -20,7 +19,6 @@ const uvec2 cz4_ = uvec2(0u);
const mat2x2 cz5_ = mat2x2(0.0);
const Foo cz6_[3] = Foo[3](Foo(vec4(0.0), 0), Foo(vec4(0.0), 0), Foo(vec4(0.0), 0));
const Foo cz7_ = Foo(vec4(0.0), 0);
const int cp3_[4] = int[4](0, 1, 2, 3);


void main() {
Expand Down
5 changes: 5 additions & 0 deletions naga/tests/out/hlsl/abstract-types-return.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ ret_return_arrf32_ai return_arrf32_ai()
return Constructarray4_float_(1.0, 1.0, 1.0, 1.0);
}

float return_const_f32_const_ai()
{
return 1.0;
}

[numthreads(1, 1, 1)]
void main()
{
Expand Down
14 changes: 6 additions & 8 deletions naga/tests/out/hlsl/constructors.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,6 @@ ret_Constructarray1_float2x2_ Constructarray1_float2x2_(float2x2 arg0) {
return ret;
}

typedef int ret_Constructarray4_int_[4];
ret_Constructarray4_int_ Constructarray4_int_(int arg0, int arg1, int arg2, int arg3) {
int ret[4] = { arg0, arg1, arg2, arg3 };
return ret;
}

bool ZeroValuebool() {
return (bool)0;
}
Expand Down Expand Up @@ -51,7 +45,6 @@ Foo ZeroValueFoo() {
return (Foo)0;
}

static const float3 const2_ = float3(0.0, 1.0, 2.0);
static const float2x2 const3_ = float2x2(float2(0.0, 1.0), float2(2.0, 3.0));
static const float2x2 const4_[1] = Constructarray1_float2x2_(float2x2(float2(0.0, 1.0), float2(2.0, 3.0)));
static const bool cz0_ = ZeroValuebool();
Expand All @@ -62,7 +55,6 @@ static const uint2 cz4_ = ZeroValueuint2();
static const float2x2 cz5_ = ZeroValuefloat2x2();
static const Foo cz6_[3] = ZeroValuearray3_Foo_();
static const Foo cz7_ = ZeroValueFoo();
static const int cp3_[4] = Constructarray4_int_(int(0), int(1), int(2), int(3));

Foo ConstructFoo(float4 arg0, int arg1) {
Foo ret = (Foo)0;
Expand All @@ -71,6 +63,12 @@ Foo ConstructFoo(float4 arg0, int arg1) {
return ret;
}

typedef int ret_Constructarray4_int_[4];
ret_Constructarray4_int_ Constructarray4_int_(int arg0, int arg1, int arg2, int arg3) {
int ret[4] = { arg0, arg1, arg2, arg3 };
return ret;
}

float2x3 ZeroValuefloat2x3() {
return (float2x3)0;
}
Expand Down
28 changes: 3 additions & 25 deletions naga/tests/out/ir/const_assert.compact.ron
Original file line number Diff line number Diff line change
@@ -1,36 +1,14 @@
(
types: [
(
name: None,
inner: Scalar((
kind: Sint,
width: 4,
)),
),
],
types: [],
special_types: (
ray_desc: None,
ray_intersection: None,
predeclared_types: {},
),
constants: [
(
name: Some("x"),
ty: 0,
init: 0,
),
(
name: Some("y"),
ty: 0,
init: 1,
),
],
constants: [],
overrides: [],
global_variables: [],
global_expressions: [
Literal(I32(1)),
Literal(I32(2)),
],
global_expressions: [],
functions: [
(
name: Some("foo"),
Expand Down
Loading