From e1b62e73d8a1fbf6bcff7788be90d53ec67ac50d Mon Sep 17 00:00:00 2001 From: Kent Slaney Date: Fri, 7 Feb 2025 08:29:10 -0800 Subject: [PATCH] [naga] Correct override resolution in array lengths. When the user provides values for a module's overrides, rather than replacing override-sized array types with ordinary array types (which could require adjusting type handles throughout the module), instead edit all overrides to have initializers that are fully-evaluated constant expressions. Then, change all backends to handle override-sized arrays by retrieving their overrides' values. For arrays whose sizes are override expressions, not simple references to a specific override's value, let front ends built array types that refer to anonymous overrides whose initializers are the necessary expression. This means that all arrays whose sizes are override expressions are references to some `Override`. Remove `naga::PendingArraySize`, and let `ArraySize::Pending` hold a `Handle` in all cases. Expand `tests/gpu-tests/shader/array_size_overrides.rs` to include the test case that motivated this approach. --- CHANGELOG.md | 1 + naga/src/arena/mod.rs | 12 +++ naga/src/back/glsl/mod.rs | 25 ++--- naga/src/back/hlsl/conv.rs | 17 ++- naga/src/back/hlsl/mod.rs | 2 + naga/src/back/hlsl/writer.rs | 18 ++-- naga/src/back/msl/mod.rs | 2 + naga/src/back/msl/writer.rs | 25 ++--- naga/src/back/pipeline_constants.rs | 102 +++++------------- naga/src/back/spv/index.rs | 5 +- naga/src/back/spv/mod.rs | 2 + naga/src/back/spv/writer.rs | 40 +++---- naga/src/compact/mod.rs | 33 ++++-- naga/src/compact/types.rs | 26 ++--- naga/src/front/wgsl/lower/mod.rs | 15 ++- naga/src/lib.rs | 11 +- naga/src/proc/index.rs | 99 +++++++++++++++-- naga/src/proc/mod.rs | 42 ++++++++ naga/src/proc/type_methods.rs | 10 +- naga/src/valid/expression.rs | 11 +- naga/src/valid/handles.rs | 36 ++++--- naga/src/valid/mod.rs | 54 +++++----- naga/src/valid/type.rs | 1 + naga/tests/validation.rs | 6 +- .../gpu-tests/shader/array_size_overrides.rs | 18 +++- 25 files changed, 353 insertions(+), 260 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1cdd99d3ae..4d6c62893b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -199,6 +199,7 @@ By @Vecvec in [#6905](https://github.com/gfx-rs/wgpu/pull/6905), [#7086](https:/ - Forward '--keep-coordinate-space' flag to GLSL backend in naga-cli. By @cloone8 in [#7206](https://github.com/gfx-rs/wgpu/pull/7206). - Allow template lists to have a trailing comma. By @KentSlaney in [#7142](https://github.com/gfx-rs/wgpu/pull/7142). - Allow WGSL const declarations to have abstract types. By @jamienicol in [#7055](https://github.com/gfx-rs/wgpu/pull/7055) and [#7222](https://github.com/gfx-rs/wgpu/pull/7222). +- Allows override-sized arrays to resolve to the same size without causing the type arena to panic. By @KentSlaney in [#7082](https://github.com/gfx-rs/wgpu/pull/7082). #### General diff --git a/naga/src/arena/mod.rs b/naga/src/arena/mod.rs index fa78332cf7..7a40b09b76 100644 --- a/naga/src/arena/mod.rs +++ b/naga/src/arena/mod.rs @@ -102,6 +102,18 @@ impl Arena { .map(|(i, v)| unsafe { (Handle::from_usize_unchecked(i), v) }) } + /// Returns an iterator over the items stored in this arena, returning both + /// the item's handle and a reference to it. + pub fn iter_mut_span( + &mut self, + ) -> impl DoubleEndedIterator, &mut T, &Span)> + ExactSizeIterator { + self.data + .iter_mut() + .zip(self.span_info.iter()) + .enumerate() + .map(|(i, (v, span))| unsafe { (Handle::from_usize_unchecked(i), v, span) }) + } + /// Drains the arena, returning an iterator over the items stored. pub fn drain(&mut self) -> impl DoubleEndedIterator, T, Span)> { let arena = core::mem::take(self); diff --git a/naga/src/back/glsl/mod.rs b/naga/src/back/glsl/mod.rs index 4c672c11b7..54f2b2e07f 100644 --- a/naga/src/back/glsl/mod.rs +++ b/naga/src/back/glsl/mod.rs @@ -539,6 +539,8 @@ pub enum Error { /// [`crate::Sampling::First`] is unsupported. #[error("`{:?}` sampling is unsupported", crate::Sampling::First)] FirstSamplingNotSupported, + #[error(transparent)] + ResolveArraySizeError(#[from] proc::ResolveArraySizeError), } /// Binary operation with a different logic on the GLSL side. @@ -612,10 +614,6 @@ impl<'a, W: Write> Writer<'a, W> { pipeline_options: &'a PipelineOptions, policies: proc::BoundsCheckPolicies, ) -> Result { - if !module.overrides.is_empty() { - return Err(Error::Override); - } - // Check if the requested version is supported if !options.version.is_supported() { log::error!("Version {}", options.version); @@ -1013,13 +1011,12 @@ impl<'a, W: Write> Writer<'a, W> { write!(self.out, "[")?; // Write the array size - // Writes nothing if `ArraySize::Dynamic` - match size { - crate::ArraySize::Constant(size) => { + // Writes nothing if `IndexableLength::Dynamic` + match size.resolve(self.module.to_ctx())? { + proc::IndexableLength::Known(size) => { write!(self.out, "{size}")?; } - crate::ArraySize::Pending(_) => unreachable!(), - crate::ArraySize::Dynamic => (), + proc::IndexableLength::Dynamic => (), } write!(self.out, "]")?; @@ -2759,7 +2756,9 @@ impl<'a, W: Write> Writer<'a, W> { write_expression(self, value)?; write!(self.out, ")")? } - _ => unreachable!(), + _ => { + return Err(Error::Override); + } } Ok(()) @@ -4574,12 +4573,8 @@ impl<'a, W: Write> Writer<'a, W> { write!(self.out, ")")?; } TypeInner::Array { base, size, .. } => { - let count = match size - .to_indexable_length(self.module) - .expect("Bad array size") - { + let count = match size.resolve(self.module.to_ctx())? { proc::IndexableLength::Known(count) => count, - proc::IndexableLength::Pending => unreachable!(), proc::IndexableLength::Dynamic => return Ok(()), }; self.write_type(base)?; diff --git a/naga/src/back/hlsl/conv.rs b/naga/src/back/hlsl/conv.rs index 6553745ac2..3b8fb8fb3d 100644 --- a/naga/src/back/hlsl/conv.rs +++ b/naga/src/back/hlsl/conv.rs @@ -54,7 +54,7 @@ impl crate::TypeInner { } } - pub(super) fn size_hlsl(&self, gctx: crate::proc::GlobalCtx) -> u32 { + pub(super) fn size_hlsl(&self, gctx: crate::proc::GlobalCtx) -> Result { match *self { Self::Matrix { columns, @@ -63,19 +63,18 @@ impl crate::TypeInner { } => { let stride = Alignment::from(rows) * scalar.width as u32; let last_row_size = rows as u32 * scalar.width as u32; - ((columns as u32 - 1) * stride) + last_row_size + Ok(((columns as u32 - 1) * stride) + last_row_size) } Self::Array { base, size, stride } => { - let count = match size { - crate::ArraySize::Constant(size) => size.get(), + let count = match size.resolve(gctx)? { + crate::proc::IndexableLength::Known(size) => size, // A dynamically-sized array has to have at least one element - crate::ArraySize::Pending(_) => unreachable!(), - crate::ArraySize::Dynamic => 1, + crate::proc::IndexableLength::Dynamic => 1, }; - let last_el_size = gctx.types[base].inner.size_hlsl(gctx); - ((count - 1) * stride) + last_el_size + let last_el_size = gctx.types[base].inner.size_hlsl(gctx)?; + Ok(((count - 1) * stride) + last_el_size) } - _ => self.size(gctx), + _ => Ok(self.size(gctx)), } } diff --git a/naga/src/back/hlsl/mod.rs b/naga/src/back/hlsl/mod.rs index c3d292e16b..a755814421 100644 --- a/naga/src/back/hlsl/mod.rs +++ b/naga/src/back/hlsl/mod.rs @@ -442,6 +442,8 @@ pub enum Error { Custom(String), #[error("overrides should not be present at this stage")] Override, + #[error(transparent)] + ResolveArraySizeError(#[from] proc::ResolveArraySizeError), } #[derive(PartialEq, Eq, Hash)] diff --git a/naga/src/back/hlsl/writer.rs b/naga/src/back/hlsl/writer.rs index 02b8e9aaf3..2ed1371a10 100644 --- a/naga/src/back/hlsl/writer.rs +++ b/naga/src/back/hlsl/writer.rs @@ -268,10 +268,6 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { module_info: &valid::ModuleInfo, fragment_entry_point: Option<&FragmentEntryPoint<'_>>, ) -> Result { - if !module.overrides.is_empty() { - return Err(Error::Override); - } - self.reset(module); // Write special constants, if needed @@ -1129,12 +1125,11 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { ) -> BackendResult { write!(self.out, "[")?; - match size { - crate::ArraySize::Constant(size) => { + match size.resolve(module.to_ctx())? { + proc::IndexableLength::Known(size) => { write!(self.out, "{size}")?; } - crate::ArraySize::Pending(_) => unreachable!(), - crate::ArraySize::Dynamic => unreachable!(), + proc::IndexableLength::Dynamic => unreachable!(), } write!(self.out, "]")?; @@ -1179,7 +1174,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { } } let ty_inner = &module.types[member.ty].inner; - last_offset = member.offset + ty_inner.size_hlsl(module.to_ctx()); + last_offset = member.offset + ty_inner.size_hlsl(module.to_ctx())?; // The indentation is only for readability write!(self.out, "{}", back::INDENT)?; @@ -2701,7 +2696,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { write_expression(self, value)?; write!(self.out, ").{number_of_components}")? } - _ => unreachable!(), + _ => { + return Err(Error::Override); + } } Ok(()) @@ -2971,7 +2968,6 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { index::IndexableLength::Known(limit) => { write!(self.out, "{}u", limit - 1)?; } - index::IndexableLength::Pending => unreachable!(), index::IndexableLength::Dynamic => unreachable!(), } write!(self.out, ")")?; diff --git a/naga/src/back/msl/mod.rs b/naga/src/back/msl/mod.rs index 86dcd58eb7..64fbfb9cf0 100644 --- a/naga/src/back/msl/mod.rs +++ b/naga/src/back/msl/mod.rs @@ -182,6 +182,8 @@ pub enum Error { Override, #[error("bitcasting to {0:?} is not supported")] UnsupportedBitCast(crate::TypeInner), + #[error(transparent)] + ResolveArraySizeError(#[from] crate::proc::ResolveArraySizeError), } #[derive(Clone, Debug, PartialEq, thiserror::Error)] diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index bd90b96b6b..c799df743e 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -1563,7 +1563,9 @@ impl Writer { put_expression(self, ctx, value)?; write!(self.out, ")")?; } - _ => unreachable!(), + _ => { + return Err(Error::Override); + } } Ok(()) @@ -2612,7 +2614,6 @@ impl Writer { self.out.write_str(") < ")?; match length { index::IndexableLength::Known(value) => write!(self.out, "{value}")?, - index::IndexableLength::Pending => unreachable!(), index::IndexableLength::Dynamic => { let global = context.function.originating_global(base).ok_or_else(|| { @@ -2749,7 +2750,7 @@ impl Writer { ) -> BackendResult { let accessing_wrapped_array = match *base_ty { crate::TypeInner::Array { - size: crate::ArraySize::Constant(_), + size: crate::ArraySize::Constant(_) | crate::ArraySize::Pending(_), .. } => true, _ => false, @@ -2777,7 +2778,6 @@ impl Writer { index::IndexableLength::Known(limit) => { write!(self.out, "{}u", limit - 1)?; } - index::IndexableLength::Pending => unreachable!(), index::IndexableLength::Dynamic => { let global = context.function.originating_global(base).ok_or_else(|| { Error::GenericValidation("Could not find originating global".into()) @@ -3795,10 +3795,6 @@ impl Writer { options: &Options, pipeline_options: &PipelineOptions, ) -> Result { - if !module.overrides.is_empty() { - return Err(Error::Override); - } - self.names.clear(); self.namer.reset( module, @@ -3995,8 +3991,8 @@ impl Writer { first_time: false, }; - match size { - crate::ArraySize::Constant(size) => { + match size.resolve(module.to_ctx())? { + proc::IndexableLength::Known(size) => { writeln!(self.out, "struct {name} {{")?; writeln!( self.out, @@ -4008,10 +4004,7 @@ impl Writer { )?; writeln!(self.out, "}};")?; } - crate::ArraySize::Pending(_) => { - unreachable!() - } - crate::ArraySize::Dynamic => { + proc::IndexableLength::Dynamic => { writeln!(self.out, "typedef {base_name} {name}[1];")?; } } @@ -6694,10 +6687,8 @@ mod workgroup_mem_init { writeln!(self.out, ", 0, {NAMESPACE}::memory_order_relaxed);")?; } crate::TypeInner::Array { base, size, .. } => { - let count = match size.to_indexable_length(module).expect("Bad array size") - { + let count = match size.resolve(module.to_ctx())? { proc::IndexableLength::Known(count) => count, - proc::IndexableLength::Pending => unreachable!(), proc::IndexableLength::Dynamic => unreachable!(), }; diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index 4a54da5b15..6c5be58520 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -85,7 +85,8 @@ pub fn process_overrides<'a>( // An iterator through the original overrides table, consumed in // approximate tandem with the global expressions. - let mut override_iter = module.overrides.drain(); + let mut overrides = mem::take(&mut module.overrides); + let mut override_iter = overrides.iter_mut_span(); // Do two things in tandem: // @@ -164,15 +165,26 @@ pub fn process_overrides<'a>( // Finish processing any overrides we didn't visit in the loop above. for entry in override_iter { - process_override( - entry, - pipeline_constants, - &mut module, - &mut override_map, - &adjusted_global_expressions, - &mut adjusted_constant_initializers, - &mut global_expression_kind_tracker, - )?; + match *entry.1 { + Override { name: Some(_), .. } | Override { id: Some(_), .. } => { + process_override( + entry, + pipeline_constants, + &mut module, + &mut override_map, + &adjusted_global_expressions, + &mut adjusted_constant_initializers, + &mut global_expression_kind_tracker, + )?; + } + Override { + init: Some(ref mut init), + .. + } => { + *init = adjusted_global_expressions[*init]; + } + _ => {} + } } // Update the initialization expression handles of all `Constant`s @@ -204,76 +216,17 @@ pub fn process_overrides<'a>( process_workgroup_size_override(&mut module, &adjusted_global_expressions, ep)?; } module.entry_points = entry_points; - - process_pending(&mut module, &override_map, &adjusted_global_expressions)?; + module.overrides = overrides; // Now that we've rewritten all the expressions, we need to // recompute their types and other metadata. For the time being, // do a full re-validation. let mut validator = Validator::new(ValidationFlags::all(), Capabilities::all()); - let module_info = validator.validate_no_overrides(&module)?; + let module_info = validator.validate_resolved_overrides(&module)?; Ok((Cow::Owned(module), Cow::Owned(module_info))) } -fn process_pending( - module: &mut Module, - override_map: &HandleVec>, - adjusted_global_expressions: &HandleVec>, -) -> Result<(), PipelineConstantError> { - for (handle, ty) in module.types.clone().iter() { - if let TypeInner::Array { - base, - size: crate::ArraySize::Pending(size), - stride, - } = ty.inner - { - let expr = match size { - crate::PendingArraySize::Expression(size_expr) => { - adjusted_global_expressions[size_expr] - } - crate::PendingArraySize::Override(size_override) => { - module.constants[override_map[size_override]].init - } - }; - let value = module - .to_ctx() - .eval_expr_to_u32(expr) - .map(|n| { - if n == 0 { - Err(PipelineConstantError::ValidationError( - WithSpan::new(ValidationError::ArraySizeError { handle: expr }) - .with_span( - module.global_expressions.get_span(expr), - "evaluated to zero", - ), - )) - } else { - Ok(core::num::NonZeroU32::new(n).unwrap()) - } - }) - .map_err(|_| { - PipelineConstantError::ValidationError( - WithSpan::new(ValidationError::ArraySizeError { handle: expr }) - .with_span(module.global_expressions.get_span(expr), "negative"), - ) - })??; - module.types.replace( - handle, - crate::Type { - name: None, - inner: TypeInner::Array { - base, - size: crate::ArraySize::Constant(value), - stride, - }, - }, - ); - } - } - Ok(()) -} - fn process_workgroup_size_override( module: &mut Module, adjusted_global_expressions: &HandleVec>, @@ -313,7 +266,7 @@ fn process_workgroup_size_override( /// /// Add the new `Constant` to `override_map` and `adjusted_constant_initializers`. fn process_override( - (old_h, r#override, span): (Handle, Override, Span), + (old_h, r#override, span): (Handle, &mut Override, &Span), pipeline_constants: &PipelineConstants, module: &mut Module, override_map: &mut HandleVec>, @@ -351,13 +304,14 @@ fn process_override( // Generate a new `Constant` to represent the override's value. let constant = Constant { - name: r#override.name, + name: r#override.name.clone(), ty: r#override.ty, init, }; - let h = module.constants.append(constant, span); + let h = module.constants.append(constant, *span); override_map.insert(old_h, h); adjusted_constant_initializers.insert(h); + r#override.init = Some(init); Ok(h) } diff --git a/naga/src/back/spv/index.rs b/naga/src/back/spv/index.rs index 5dc6174ece..45d628e0f7 100644 --- a/naga/src/back/spv/index.rs +++ b/naga/src/back/spv/index.rs @@ -268,13 +268,10 @@ impl BlockContext<'_> { block: &mut Block, ) -> Result, Error> { let sequence_ty = self.fun_info[sequence].ty.inner_with(&self.ir_module.types); - match sequence_ty.indexable_length(self.ir_module) { + match sequence_ty.indexable_length_resolved(self.ir_module) { Ok(crate::proc::IndexableLength::Known(known_length)) => { Ok(MaybeKnown::Known(known_length)) } - Ok(crate::proc::IndexableLength::Pending) => { - unreachable!() - } Ok(crate::proc::IndexableLength::Dynamic) => { let length_id = self.write_runtime_array_length(sequence, block)?; Ok(MaybeKnown::Computed(length_id)) diff --git a/naga/src/back/spv/mod.rs b/naga/src/back/spv/mod.rs index 145e499c50..3f9fcc98e8 100644 --- a/naga/src/back/spv/mod.rs +++ b/naga/src/back/spv/mod.rs @@ -75,6 +75,8 @@ pub enum Error { Validation(&'static str), #[error("overrides should not be present at this stage")] Override, + #[error(transparent)] + ResolveArraySizeError(#[from] crate::proc::ResolveArraySizeError), } #[derive(Default)] diff --git a/naga/src/back/spv/writer.rs b/naga/src/back/spv/writer.rs index c000343354..006d7da4d1 100644 --- a/naga/src/back/spv/writer.rs +++ b/naga/src/back/spv/writer.rs @@ -1277,10 +1277,10 @@ impl Writer { fn write_type_declaration_arena( &mut self, - arena: &UniqueArena, + module: &crate::Module, handle: Handle, ) -> Result { - let ty = &arena[handle]; + let ty = &module.types[handle]; // If it's a type that needs SPIR-V capabilities, request them now. // This needs to happen regardless of the LocalType lookup succeeding, // because some types which map to the same LocalType have different @@ -1313,24 +1313,26 @@ impl Writer { self.decorate(id, Decoration::ArrayStride, &[stride]); let type_id = self.get_handle_type_id(base); - match size { - crate::ArraySize::Constant(length) => { - let length_id = self.get_index_constant(length.get()); + match size.resolve(module.to_ctx())? { + crate::proc::IndexableLength::Known(length) => { + let length_id = self.get_index_constant(length); Instruction::type_array(id, type_id, length_id) } - crate::ArraySize::Pending(_) => unreachable!(), - crate::ArraySize::Dynamic => Instruction::type_runtime_array(id, type_id), + crate::proc::IndexableLength::Dynamic => { + Instruction::type_runtime_array(id, type_id) + } } } crate::TypeInner::BindingArray { base, size } => { let type_id = self.get_handle_type_id(base); - match size { - crate::ArraySize::Constant(length) => { - let length_id = self.get_index_constant(length.get()); + match size.resolve(module.to_ctx())? { + crate::proc::IndexableLength::Known(length) => { + let length_id = self.get_index_constant(length); Instruction::type_array(id, type_id, length_id) } - crate::ArraySize::Pending(_) => unreachable!(), - crate::ArraySize::Dynamic => Instruction::type_runtime_array(id, type_id), + crate::proc::IndexableLength::Dynamic => { + Instruction::type_runtime_array(id, type_id) + } } } crate::TypeInner::Struct { @@ -1340,7 +1342,7 @@ impl Writer { let mut has_runtime_array = false; let mut member_ids = Vec::with_capacity(members.len()); for (index, member) in members.iter().enumerate() { - let member_ty = &arena[member.ty]; + let member_ty = &module.types[member.ty]; match member_ty.inner { crate::TypeInner::Array { base: _, @@ -1351,7 +1353,7 @@ impl Writer { } _ => (), } - self.decorate_struct_member(id, index, member, arena)?; + self.decorate_struct_member(id, index, member, &module.types)?; let member_id = self.get_handle_type_id(member.ty); member_ids.push(member_id); } @@ -1600,7 +1602,9 @@ impl Writer { self.get_constant_composite(ty, component_ids) } - _ => unreachable!(), + _ => { + return Err(Error::Override); + } }; self.constant_ids[handle] = id; @@ -2302,7 +2306,7 @@ impl Writer { // write all types for (handle, _) in ir_module.types.iter() { - self.write_type_declaration_arena(&ir_module.types, handle)?; + self.write_type_declaration_arena(ir_module, handle)?; } // write all const-expressions as constants @@ -2422,10 +2426,6 @@ impl Writer { debug_info: &Option, words: &mut Vec, ) -> Result<(), Error> { - if !ir_module.overrides.is_empty() { - return Err(Error::Override); - } - self.reset(); // Try to find the entry point and corresponding index diff --git a/naga/src/compact/mod.rs b/naga/src/compact/mod.rs index ee3d5d3f60..97cc9a7b42 100644 --- a/naga/src/compact/mod.rs +++ b/naga/src/compact/mod.rs @@ -359,12 +359,7 @@ impl<'module> ModuleTracer<'module> { crate::TypeInner::Array { size, .. } | crate::TypeInner::BindingArray { size, .. } => match size { crate::ArraySize::Constant(_) | crate::ArraySize::Dynamic => None, - crate::ArraySize::Pending(pending) => match pending { - crate::PendingArraySize::Expression(handle) => Some(handle), - crate::PendingArraySize::Override(handle) => { - self.module.overrides[handle].init - } - }, + crate::ArraySize::Pending(handle) => self.module.overrides[handle].init, }, _ => None, }, @@ -517,12 +512,21 @@ fn type_expression_interdependence() { crate::Span::default(), ); let type_needs_expression = |module: &mut crate::Module, handle| { + let override_handle = module.overrides.append( + crate::Override { + name: None, + id: None, + ty: u32, + init: Some(handle), + }, + crate::Span::default(), + ); module.types.insert( crate::Type { name: None, inner: crate::TypeInner::Array { base: u32, - size: crate::ArraySize::Pending(crate::PendingArraySize::Expression(handle)), + size: crate::ArraySize::Pending(override_handle), stride: 4, }, }, @@ -654,7 +658,7 @@ fn array_length_override() { name: Some("array".to_string()), inner: crate::TypeInner::Array { base: ty_bool, - size: crate::ArraySize::Pending(crate::PendingArraySize::Override(o)), + size: crate::ArraySize::Pending(o), stride: 4, }, }, @@ -760,7 +764,7 @@ fn array_length_override_mutual() { name: Some("delicious_array".to_string()), inner: Ti::Array { base: ty_u32, - size: crate::ArraySize::Pending(crate::PendingArraySize::Override(second_override)), + size: crate::ArraySize::Pending(second_override), stride: 4, }, }, @@ -795,12 +799,21 @@ fn array_length_expression() { crate::Expression::Literal(crate::Literal::U32(1)), crate::Span::default(), ); + let override_one = module.overrides.append( + crate::Override { + name: None, + id: None, + ty: ty_u32, + init: Some(one), + }, + crate::Span::default(), + ); let _ty_array = module.types.insert( crate::Type { name: Some("array".to_string()), inner: crate::TypeInner::Array { base: ty_u32, - size: crate::ArraySize::Pending(crate::PendingArraySize::Expression(one)), + size: crate::ArraySize::Pending(override_one), stride: 4, }, }, diff --git a/naga/src/compact/types.rs b/naga/src/compact/types.rs index 2932568268..0a1db16f9f 100644 --- a/naga/src/compact/types.rs +++ b/naga/src/compact/types.rs @@ -32,19 +32,14 @@ impl TypeTracer<'_> { | Ti::BindingArray { base, size } => { self.types_used.insert(base); match size { - crate::ArraySize::Pending(pending) => match pending { - crate::PendingArraySize::Expression(expr) => { + crate::ArraySize::Pending(handle) => { + self.overrides_used.insert(handle); + let r#override = &self.overrides[handle]; + self.types_used.insert(r#override.ty); + if let Some(expr) = r#override.init { self.expressions_used.insert(expr); } - crate::PendingArraySize::Override(handle) => { - self.overrides_used.insert(handle); - let r#override = &self.overrides[handle]; - self.types_used.insert(r#override.ty); - if let Some(expr) = r#override.init { - self.expressions_used.insert(expr); - } - } - }, + } crate::ArraySize::Constant(_) | crate::ArraySize::Dynamic => {} } } @@ -94,14 +89,7 @@ impl ModuleMap { } => { adjust(base); match *size { - crate::ArraySize::Pending(crate::PendingArraySize::Expression( - ref mut size_expr, - )) => { - self.global_expressions.adjust(size_expr); - } - crate::ArraySize::Pending(crate::PendingArraySize::Override( - ref mut r#override, - )) => { + crate::ArraySize::Pending(ref mut r#override) => { self.overrides.adjust(r#override); } crate::ArraySize::Constant(_) | crate::ArraySize::Dynamic => {} diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index 010d6e19b3..385a59c5b0 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -3305,14 +3305,23 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { size_expr: Handle>, ctx: &mut ExpressionContext<'source, '_, '_>, span: Span, - ) -> Result> { + ) -> Result, Error<'source>> { let expr = self.expression(size_expr, ctx)?; match resolve_inner!(ctx, expr).scalar_kind().ok_or(0) { Ok(crate::ScalarKind::Sint) | Ok(crate::ScalarKind::Uint) => Ok({ if let crate::Expression::Override(handle) = ctx.module.global_expressions[expr] { - crate::PendingArraySize::Override(handle) + handle } else { - crate::PendingArraySize::Expression(expr) + let ty = ctx.register_type(expr)?; + ctx.module.overrides.append( + crate::Override { + name: None, + id: None, + ty, + init: Some(expr), + }, + span, + ) } }), _ => Err(Error::ExpectedConstExprConcreteIntegerScalar(span)), diff --git a/naga/src/lib.rs b/naga/src/lib.rs index 48156b977e..afb9f8d0e6 100644 --- a/naga/src/lib.rs +++ b/naga/src/lib.rs @@ -502,15 +502,6 @@ pub struct Scalar { pub width: Bytes, } -#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] -#[cfg_attr(feature = "serialize", derive(Serialize))] -#[cfg_attr(feature = "deserialize", derive(Deserialize))] -#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] -pub enum PendingArraySize { - Expression(Handle), - Override(Handle), -} - /// Size of an array. #[repr(u8)] #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] @@ -521,7 +512,7 @@ pub enum ArraySize { /// The array size is constant. Constant(core::num::NonZeroU32), /// The array size is an override-expression. - Pending(PendingArraySize), + Pending(Handle), /// The array size can change at runtime. Dynamic, } diff --git a/naga/src/proc/index.rs b/naga/src/proc/index.rs index ac2a6589c1..9f1c0ddb79 100644 --- a/naga/src/proc/index.rs +++ b/naga/src/proc/index.rs @@ -304,6 +304,19 @@ pub fn find_checked_indexes( /// matrices. It does not handle struct member indices; those never require /// run-time checks, so it's best to deal with them further up the call /// chain. +/// +/// This function assumes that any relevant overrides have fully-evaluated +/// constants as their values (as arranged by [`process_overrides`], for +/// example). +/// +/// [`process_overrides`]: crate::back::pipeline_constants::process_overrides +/// +/// # Panics +/// +/// - If `base` is not an indexable type, panic. +/// +/// - If `base` is an override-sized array, but the override's value is not a +/// fully-evaluated constant expression, panic. pub fn access_needs_check( base: Handle, mut index: GuardedIndex, @@ -315,7 +328,7 @@ pub fn access_needs_check( // Unwrap safety: `Err` here indicates unindexable base types and invalid // length constants, but `access_needs_check` is only used by back ends, so // validation should have caught those problems. - let length = base_inner.indexable_length(module).unwrap(); + let length = base_inner.indexable_length_resolved(module).unwrap(); index.try_resolve_to_constant(expressions, module); if let (&GuardedIndex::Known(index), &IndexableLength::Known(length)) = (&index, &length) { if index < length { @@ -357,8 +370,10 @@ impl GuardedIndex { pub enum IndexableLengthError { #[error("Type is not indexable, and has no length (validation error)")] TypeNotIndexable, - #[error("Array length constant {0:?} is invalid")] - InvalidArrayLength(Handle), + #[error(transparent)] + ResolveArraySizeError(#[from] super::ResolveArraySizeError), + #[error("Array size is still pending")] + Pending(crate::ArraySize), } impl crate::TypeInner { @@ -405,6 +420,72 @@ impl crate::TypeInner { }; Ok(IndexableLength::Known(known_length)) } + + /// Return the length of `self`, assuming overrides are yet to be supplied. + /// + /// Return the number of elements in `self`: + /// + /// - If `self` is a runtime-sized array, then return + /// [`IndexableLength::Dynamic`]. + /// + /// - If `self` is an override-sized array, then assume that override values + /// have not yet been supplied, and return [`IndexableLength::Dynamic`]. + /// + /// - Otherwise, the type simply tells us the length of `self`, so return + /// [`IndexableLength::Known`]. + /// + /// If `self` is not an indexable type at all, return an error. + /// + /// The difference between this and `indexable_length_resolved` is that we + /// treat override-sized arrays and dynamically-sized arrays both as + /// [`Dynamic`], on the assumption that our callers want to treat both cases + /// as "not yet possible to check". + /// + /// [`Dynamic`]: IndexableLength::Dynamic + pub fn indexable_length_pending( + &self, + module: &crate::Module, + ) -> Result { + let length = self.indexable_length(module); + if let Err(IndexableLengthError::Pending(_)) = length { + return Ok(IndexableLength::Dynamic); + } + length + } + + /// Return the length of `self`, assuming overrides have been resolved. + /// + /// Return the number of elements in `self`: + /// + /// - If `self` is a runtime-sized array, then return + /// [`IndexableLength::Dynamic`]. + /// + /// - If `self` is an override-sized array, then assume that the override's + /// value is a fully-evaluated constant expression, and return + /// [`IndexableLength::Known`]. Otherwise, return an error. + /// + /// - Otherwise, the type simply tells us the length of `self`, so return + /// [`IndexableLength::Known`]. + /// + /// If `self` is not an indexable type at all, return an error. + /// + /// The difference between this and `indexable_length_pending` is + /// that if `self` is override-sized, we require the override's + /// value to be known. + pub fn indexable_length_resolved( + &self, + module: &crate::Module, + ) -> Result { + let length = self.indexable_length(module); + + // If the length is override-based, then try to compute its value now. + if let Err(IndexableLengthError::Pending(size)) = length { + if let IndexableLength::Known(computed) = size.resolve(module.to_ctx())? { + return Ok(IndexableLength::Known(computed)); + } + } + length + } } /// The number of elements in an indexable type. @@ -416,8 +497,6 @@ pub enum IndexableLength { /// Values of this type always have the given number of elements. Known(u32), - Pending, - /// The number of elements is determined at runtime. Dynamic, } @@ -427,10 +506,10 @@ impl crate::ArraySize { self, _module: &crate::Module, ) -> Result { - Ok(match self { - Self::Constant(length) => IndexableLength::Known(length.get()), - Self::Pending(_) => IndexableLength::Pending, - Self::Dynamic => IndexableLength::Dynamic, - }) + match self { + Self::Constant(length) => Ok(IndexableLength::Known(length.get())), + Self::Pending(_) => Err(IndexableLengthError::Pending(self)), + Self::Dynamic => Ok(IndexableLength::Dynamic), + } } } diff --git a/naga/src/proc/mod.rs b/naga/src/proc/mod.rs index 0ba561027f..44c76b911d 100644 --- a/naga/src/proc/mod.rs +++ b/naga/src/proc/mod.rs @@ -19,6 +19,7 @@ pub use index::{BoundsCheckPolicies, BoundsCheckPolicy, IndexableLength, Indexab pub use layouter::{Alignment, LayoutError, LayoutErrorInner, Layouter, TypeLayout}; pub use namer::{EntryPointIndex, NameKey, Namer}; pub use terminator::ensure_block_returns; +use thiserror::Error; pub use typifier::{ResolveContext, ResolveError, TypeResolution}; impl From for super::Scalar { @@ -483,6 +484,47 @@ impl GlobalCtx<'_> { } } +#[derive(Error, Debug, Clone, Copy, PartialEq)] +pub enum ResolveArraySizeError { + #[error("array element count must be positive (> 0)")] + ExpectedPositiveArrayLength, + #[error("internal: array size override has not been resolved")] + NonConstArrayLength, +} + +impl crate::ArraySize { + /// Return the number of elements that `size` represents, if known at code generation time. + /// + /// If `size` is override-based, return an error unless the override's + /// initializer is a fully evaluated constant expression. You can call + /// [`pipeline_constants::process_overrides`] to supply values for a + /// module's overrides and ensure their initializers are fully evaluated, as + /// this function expects. + /// + /// [`pipeline_constants::process_overrides`]: crate::back::pipeline_constants::process_overrides + pub fn resolve(&self, gctx: GlobalCtx) -> Result { + match *self { + crate::ArraySize::Constant(length) => Ok(IndexableLength::Known(length.get())), + crate::ArraySize::Pending(handle) => { + let Some(expr) = gctx.overrides[handle].init else { + return Err(ResolveArraySizeError::NonConstArrayLength); + }; + let length = gctx.eval_expr_to_u32(expr).map_err(|err| match err { + U32EvalError::NonConst => ResolveArraySizeError::NonConstArrayLength, + U32EvalError::Negative => ResolveArraySizeError::ExpectedPositiveArrayLength, + })?; + + if length == 0 { + return Err(ResolveArraySizeError::ExpectedPositiveArrayLength); + } + + Ok(IndexableLength::Known(length)) + } + crate::ArraySize::Dynamic => Ok(IndexableLength::Dynamic), + } + } +} + /// Return an iterator over the individual components assembled by a /// `Compose` expression. /// diff --git a/naga/src/proc/type_methods.rs b/naga/src/proc/type_methods.rs index 3b9e9348a9..4eda07c9b7 100644 --- a/naga/src/proc/type_methods.rs +++ b/naga/src/proc/type_methods.rs @@ -135,7 +135,7 @@ impl crate::TypeInner { } /// Get the size of this type. - pub fn size(&self, _gctx: super::GlobalCtx) -> u32 { + pub fn size(&self, gctx: super::GlobalCtx) -> u32 { match *self { Self::Scalar(scalar) | Self::Atomic(scalar) => scalar.width as u32, Self::Vector { size, scalar } => size as u32 * scalar.width as u32, @@ -151,13 +151,13 @@ impl crate::TypeInner { size, stride, } => { - let count = match size { - crate::ArraySize::Constant(count) => count.get(), + let count = match size.resolve(gctx) { + Ok(crate::proc::IndexableLength::Known(count)) => count, // any struct member or array element needing a size at pipeline-creation time // must have a creation-fixed footprint - crate::ArraySize::Pending(_) => 0, + Err(_) => 0, // A dynamically-sized array has to have at least one element - crate::ArraySize::Dynamic => 1, + Ok(crate::proc::IndexableLength::Dynamic) => 1, }; count * stride } diff --git a/naga/src/valid/expression.rs b/naga/src/valid/expression.rs index 49b2fd27b2..ec16c6d958 100644 --- a/naga/src/valid/expression.rs +++ b/naga/src/valid/expression.rs @@ -223,7 +223,7 @@ impl super::Validator { crate::TypeInner::Scalar { .. } => {} _ => return Err(ConstExpressionError::InvalidSplatType(value)), }, - _ if global_expr_kind.is_const(handle) || !self.allow_overrides => { + _ if global_expr_kind.is_const(handle) || self.overrides_resolved => { return Err(ConstExpressionError::NonFullyEvaluatedConst) } // the constant evaluator will report errors about override-expressions @@ -285,11 +285,14 @@ impl super::Validator { .eval_expr_to_u32_from(index, &function.expressions) { Ok(value) => { + let length = if self.overrides_resolved { + base_type.indexable_length_resolved(module) + } else { + base_type.indexable_length_pending(module) + }?; // If we know both the length and the index, we can do the // bounds check now. - if let crate::proc::IndexableLength::Known(known_length) = - base_type.indexable_length(module)? - { + if let crate::proc::IndexableLength::Known(known_length) = length { if value >= known_length { return Err(ExpressionError::IndexOutOfBounds(base, value)); } diff --git a/naga/src/valid/handles.rs b/naga/src/valid/handles.rs index 93265e17a0..62f5d00bb6 100644 --- a/naga/src/valid/handles.rs +++ b/naga/src/valid/handles.rs @@ -323,15 +323,12 @@ impl super::Validator { | crate::TypeInner::BindingArray { base, size, .. } => { handle.check_dep(base)?; match size { - crate::ArraySize::Pending(pending) => match pending { - crate::PendingArraySize::Expression(expr) => Some(expr), - crate::PendingArraySize::Override(h) => { - Self::validate_override_handle(h, overrides)?; - let r#override = &overrides[h]; - handle.check_dep(r#override.ty)?; - r#override.init - } - }, + crate::ArraySize::Pending(h) => { + Self::validate_override_handle(h, overrides)?; + let r#override = &overrides[h]; + handle.check_dep(r#override.ty)?; + r#override.init + } crate::ArraySize::Constant(_) | crate::ArraySize::Dynamic => None, } } @@ -923,7 +920,7 @@ fn constant_deps() { #[test] fn array_size_deps() { use super::Validator; - use crate::{ArraySize, Expression, PendingArraySize, Scalar, Span, Type, TypeInner}; + use crate::{ArraySize, Expression, Override, Scalar, Span, Type, TypeInner}; let nowhere = Span::default(); @@ -939,12 +936,21 @@ fn array_size_deps() { let ex_zero = m .global_expressions .append(Expression::ZeroValue(ty_u32), nowhere); + let ty_handle = m.overrides.append( + Override { + name: None, + id: None, + ty: ty_u32, + init: Some(ex_zero), + }, + nowhere, + ); let ty_arr = m.types.insert( Type { name: Some("bad_array".to_string()), inner: TypeInner::Array { base: ty_u32, - size: ArraySize::Pending(PendingArraySize::Expression(ex_zero)), + size: ArraySize::Pending(ty_handle), stride: 4, }, }, @@ -963,7 +969,7 @@ fn array_size_deps() { #[test] fn array_size_override() { use super::Validator; - use crate::{ArraySize, Override, PendingArraySize, Scalar, Span, Type, TypeInner}; + use crate::{ArraySize, Override, Scalar, Span, Type, TypeInner}; let nowhere = Span::default(); @@ -983,7 +989,7 @@ fn array_size_override() { name: Some("bad_array".to_string()), inner: TypeInner::Array { base: ty_u32, - size: ArraySize::Pending(PendingArraySize::Override(bad_override)), + size: ArraySize::Pending(bad_override), stride: 4, }, }, @@ -996,7 +1002,7 @@ fn array_size_override() { #[test] fn override_init_deps() { use super::Validator; - use crate::{ArraySize, Expression, Override, PendingArraySize, Scalar, Span, Type, TypeInner}; + use crate::{ArraySize, Expression, Override, Scalar, Span, Type, TypeInner}; let nowhere = Span::default(); @@ -1026,7 +1032,7 @@ fn override_init_deps() { name: Some("bad_array".to_string()), inner: TypeInner::Array { base: ty_u32, - size: ArraySize::Pending(PendingArraySize::Override(r#override)), + size: ArraySize::Pending(r#override), stride: 4, }, }, diff --git a/naga/src/valid/mod.rs b/naga/src/valid/mod.rs index e7a50d928d..a2cec05c12 100644 --- a/naga/src/valid/mod.rs +++ b/naga/src/valid/mod.rs @@ -281,7 +281,10 @@ pub struct Validator { valid_expression_list: Vec>, valid_expression_set: HandleSet, override_ids: FastHashSet, - allow_overrides: bool, + + /// Treat overrides whose initializers are not fully-evaluated + /// constant expressions as errors. + overrides_resolved: bool, /// A checklist of expressions that must be visited by a specific kind of /// statement. @@ -332,6 +335,13 @@ pub enum OverrideError { TypeNotScalar, #[error("Override declarations are not allowed")] NotAllowed, + #[error("Override is uninitialized")] + UninitializedOverride, + #[error("Constant expression {handle:?} is invalid")] + ConstExpression { + handle: Handle, + source: ConstExpressionError, + }, } #[derive(Clone, Debug, thiserror::Error)] @@ -472,7 +482,7 @@ impl Validator { valid_expression_list: Vec::new(), valid_expression_set: HandleSet::new(), override_ids: FastHashSet::default(), - allow_overrides: true, + overrides_resolved: false, needs_visit: HandleSet::new(), } } @@ -532,16 +542,8 @@ impl Validator { gctx: crate::proc::GlobalCtx, mod_info: &ModuleInfo, ) -> Result<(), OverrideError> { - if !self.allow_overrides { - return Err(OverrideError::NotAllowed); - } - let o = &gctx.overrides[handle]; - if o.name.is_none() && o.id.is_none() { - return Err(OverrideError::MissingNameAndID); - } - if let Some(id) = o.id { if !self.override_ids.insert(id) { return Err(OverrideError::DuplicateID); @@ -570,6 +572,8 @@ impl Validator { if !decl_ty.equivalent(init_ty, gctx.types) { return Err(OverrideError::InvalidType); } + } else if self.overrides_resolved { + return Err(OverrideError::UninitializedOverride); } Ok(()) @@ -580,18 +584,22 @@ impl Validator { &mut self, module: &crate::Module, ) -> Result> { - self.allow_overrides = true; + self.overrides_resolved = false; self.validate_impl(module) } - /// Check the given module to be valid. + /// Check the given module to be valid, requiring overrides to be resolved. /// - /// With the additional restriction that overrides are not present. - pub fn validate_no_overrides( + /// This is the same as [`validate`], except that any override + /// whose value is not a fully-evaluated constant expression is + /// treated as an error. + /// + /// [`validate`]: Validator::validate + pub fn validate_resolved_overrides( &mut self, module: &crate::Module, ) -> Result> { - self.allow_overrides = false; + self.overrides_resolved = true; self.validate_impl(module) } @@ -634,20 +642,6 @@ impl Validator { } .with_span_handle(handle, &module.types) })?; - if !self.allow_overrides { - if let crate::TypeInner::Array { - size: crate::ArraySize::Pending(_), - .. - } = ty.inner - { - return Err((ValidationError::Type { - handle, - name: ty.name.clone().unwrap_or_default(), - source: TypeError::UnresolvedOverride(handle), - }) - .with_span_handle(handle, &module.types)); - } - } mod_info.type_flags.push(ty_info.flags); self.types[handle.index()] = ty_info; } @@ -702,7 +696,7 @@ impl Validator { source, } .with_span_handle(handle, &module.overrides) - })? + })?; } } diff --git a/naga/src/valid/type.rs b/naga/src/valid/type.rs index 5863eb813f..f24f6e5200 100644 --- a/naga/src/valid/type.rs +++ b/naga/src/valid/type.rs @@ -535,6 +535,7 @@ impl super::Validator { | TypeFlags::COPY | TypeFlags::HOST_SHAREABLE | TypeFlags::ARGUMENT + | TypeFlags::CONSTRUCTIBLE } crate::ArraySize::Dynamic => { // Non-SIZED types may only appear as the last element of a structure. diff --git a/naga/tests/validation.rs b/naga/tests/validation.rs index 330e082e95..5de3b592a2 100644 --- a/naga/tests/validation.rs +++ b/naga/tests/validation.rs @@ -307,7 +307,7 @@ fn main() {{ ); let module = naga::front::wgsl::parse_str(&source).unwrap(); let err = valid::Validator::new(Default::default(), valid::Capabilities::all()) - .validate_no_overrides(&module) + .validate(&module) .expect_err("module should be invalid"); assert_eq!(err.emit_to_string(&source), expected_err); } @@ -381,7 +381,7 @@ fn incompatible_interpolation_and_sampling_types() { for (invalid_source, invalid_module, interpolation, sampling, interpolate_attr) in invalid_cases { let err = valid::Validator::new(Default::default(), valid::Capabilities::all()) - .validate_no_overrides(&invalid_module) + .validate(&invalid_module) .expect_err(&format!( "module should be invalid for {interpolate_attr:?}" )); @@ -679,7 +679,7 @@ error: Entry point main at Compute is invalid for (source, expected_err) in cases { let module = naga::front::wgsl::parse_str(source).unwrap(); let err = valid::Validator::new(Default::default(), valid::Capabilities::all()) - .validate_no_overrides(&module) + .validate(&module) .expect_err("module should be invalid"); println!("{}", err.emit_to_string(source)); assert_eq!(err.emit_to_string(source), expected_err); diff --git a/tests/gpu-tests/shader/array_size_overrides.rs b/tests/gpu-tests/shader/array_size_overrides.rs index 2fd96f02a5..3b3cb6f603 100644 --- a/tests/gpu-tests/shader/array_size_overrides.rs +++ b/tests/gpu-tests/shader/array_size_overrides.rs @@ -4,13 +4,29 @@ use wgpu::{BufferDescriptor, BufferUsages, MapMode, PollType}; use wgpu_test::{fail_if, gpu_test, GpuTestConfiguration, TestParameters, TestingContext}; const SHADER: &str = r#" - override n = 8; + const testing_shared: array = array(8); + override n = testing_shared[0u]; + override m = testing_shared[0u]; var arr: array; @group(0) @binding(0) var output: array; + // When `n` is overridden to 14 above, it will generate the type `array` which + // already exists in the program because of variable declaration. Ensures naga does not panic + // when this happens. + // + // See https://github.com/gfx-rs/wgpu/issues/6722 for more info. + var testing0: array; + + // Tests whether two overrides that are initialized by the same expression + // crashes the unique types arena + // + // See https://github.com/gfx-rs/wgpu/pull/6787#pullrequestreview-2576905294 + var testing1: array; + var testing2: array; + @compute @workgroup_size(1) fn main() { // 1d spiral for (var i = 0; i < n - 2; i++) {