Skip to content

Commit

Permalink
[naga] Use const ctx instead of global ctx for type resolution (gfx-r…
Browse files Browse the repository at this point in the history
…s#6935)

Signed-off-by: sagudev <[email protected]>
  • Loading branch information
sagudev authored Feb 24, 2025
1 parent e95f6d6 commit 2f255ed
Show file tree
Hide file tree
Showing 8 changed files with 244 additions and 190 deletions.
8 changes: 4 additions & 4 deletions naga/src/front/wgsl/lower/construction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,7 @@ impl<'source> Lowerer<'source, '_> {
}
ast::ConstructorType::PartialVector { size } => Constructor::PartialVector { size },
ast::ConstructorType::Vector { size, ty, ty_span } => {
let ty = self.resolve_ast_type(ty, &mut ctx.as_global())?;
let ty = self.resolve_ast_type(ty, &mut ctx.as_const())?;
let scalar = match ctx.module.types[ty].inner {
crate::TypeInner::Scalar(sc) => sc,
_ => return Err(Error::UnknownScalarType(ty_span)),
Expand All @@ -596,7 +596,7 @@ impl<'source> Lowerer<'source, '_> {
ty,
ty_span,
} => {
let ty = self.resolve_ast_type(ty, &mut ctx.as_global())?;
let ty = self.resolve_ast_type(ty, &mut ctx.as_const())?;
let scalar = match ctx.module.types[ty].inner {
crate::TypeInner::Scalar(sc) => sc,
_ => return Err(Error::UnknownScalarType(ty_span)),
Expand All @@ -613,8 +613,8 @@ impl<'source> Lowerer<'source, '_> {
}
ast::ConstructorType::PartialArray => Constructor::PartialArray,
ast::ConstructorType::Array { base, size } => {
let base = self.resolve_ast_type(base, &mut ctx.as_global())?;
let size = self.array_size(size, &mut ctx.as_global())?;
let base = self.resolve_ast_type(base, &mut ctx.as_const())?;
let size = self.array_size(size, &mut ctx.as_const())?;

ctx.layouter.update(ctx.module.to_ctx()).unwrap();
let stride = ctx.layouter[base].to_stride();
Expand Down
93 changes: 49 additions & 44 deletions naga/src/front/wgsl/lower/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ impl<'a, 'temp> StatementContext<'a, 'temp, '_> {
}
}

#[allow(dead_code)]
fn as_global(&mut self) -> GlobalContext<'a, '_, '_> {
GlobalContext {
ast_expressions: self.ast_expressions,
Expand Down Expand Up @@ -468,29 +469,28 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
.map_err(|e| Error::ConstantEvaluatorError(e.into(), span))
}

fn const_access(&self, handle: Handle<crate::Expression>) -> Option<u32> {
fn const_eval_expr_to_u32(
&self,
handle: Handle<crate::Expression>,
) -> Result<u32, crate::proc::U32EvalError> {
match self.expr_type {
ExpressionContextType::Runtime(ref ctx) => {
if !ctx.local_expression_kind_tracker.is_const(handle) {
return None;
return Err(crate::proc::U32EvalError::NonConst);
}

self.module
.to_ctx()
.eval_expr_to_u32_from(handle, &ctx.function.expressions)
.ok()
}
ExpressionContextType::Constant(Some(ref ctx)) => {
assert!(ctx.local_expression_kind_tracker.is_const(handle));
self.module
.to_ctx()
.eval_expr_to_u32_from(handle, &ctx.function.expressions)
.ok()
}
ExpressionContextType::Constant(None) => {
self.module.to_ctx().eval_expr_to_u32(handle).ok()
}
ExpressionContextType::Override => None,
ExpressionContextType::Constant(None) => self.module.to_ctx().eval_expr_to_u32(handle),
ExpressionContextType::Override => Err(crate::proc::U32EvalError::NonConst),
}
}

Expand Down Expand Up @@ -1069,7 +1069,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
}
ast::GlobalDeclKind::Var(ref v) => {
let explicit_ty =
v.ty.map(|ast| self.resolve_ast_type(ast, &mut ctx))
v.ty.map(|ast| self.resolve_ast_type(ast, &mut ctx.as_const()))
.transpose()?;

let (ty, initializer) =
Expand Down Expand Up @@ -1102,7 +1102,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
let mut ectx = ctx.as_const();

let explicit_ty =
c.ty.map(|ast| self.resolve_ast_type(ast, &mut ectx.as_global()))
c.ty.map(|ast| self.resolve_ast_type(ast, &mut ectx))
.transpose()?;

let (ty, init) =
Expand All @@ -1123,7 +1123,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
}
ast::GlobalDeclKind::Override(ref o) => {
let explicit_ty =
o.ty.map(|ast| self.resolve_ast_type(ast, &mut ctx))
o.ty.map(|ast| self.resolve_ast_type(ast, &mut ctx.as_const()))
.transpose()?;

let mut ectx = ctx.as_override();
Expand Down Expand Up @@ -1165,7 +1165,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
let ty = self.resolve_named_ast_type(
alias.ty,
Some(alias.name.name.to_string()),
&mut ctx,
&mut ctx.as_const(),
)?;
ctx.globals
.insert(alias.name.name, LoweredGlobalDecl::Type(ty));
Expand Down Expand Up @@ -1263,7 +1263,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
.iter()
.enumerate()
.map(|(i, arg)| -> Result<_, Error<'_>> {
let ty = self.resolve_ast_type(arg.ty, ctx)?;
let ty = self.resolve_ast_type(arg.ty, &mut ctx.as_const())?;
let expr = expressions
.append(crate::Expression::FunctionArgument(i as u32), arg.name.span);
local_table.insert(arg.handle, Declared::Runtime(Typed::Plain(expr)));
Expand All @@ -1282,7 +1282,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
.result
.as_ref()
.map(|res| -> Result<_, Error<'_>> {
let ty = self.resolve_ast_type(res.ty, ctx)?;
let ty = self.resolve_ast_type(res.ty, &mut ctx.as_const())?;
Ok(crate::FunctionResult {
ty,
binding: self.binding(&res.binding, ty, ctx)?,
Expand Down Expand Up @@ -1440,9 +1440,10 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
// optimization.
ctx.local_expression_kind_tracker.force_non_const(value);

let explicit_ty =
l.ty.map(|ty| self.resolve_ast_type(ty, &mut ctx.as_global()))
.transpose()?;
let explicit_ty = l
.ty
.map(|ty| self.resolve_ast_type(ty, &mut ctx.as_const(block, &mut emitter)))
.transpose()?;

if let Some(ty) = explicit_ty {
let mut ctx = ctx.as_expression(block, &mut emitter);
Expand All @@ -1469,12 +1470,15 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
return Ok(());
}
ast::LocalDecl::Var(ref v) => {
let explicit_ty =
v.ty.map(|ast| self.resolve_ast_type(ast, &mut ctx.as_global()))
.transpose()?;

let mut emitter = Emitter::default();
emitter.start(&ctx.function.expressions);

let explicit_ty =
v.ty.map(|ast| {
self.resolve_ast_type(ast, &mut ctx.as_const(block, &mut emitter))
})
.transpose()?;

let mut ectx = ctx.as_expression(block, &mut emitter);
let (ty, initializer) =
self.type_and_init(v.name, v.init, explicit_ty, &mut ectx)?;
Expand Down Expand Up @@ -1533,11 +1537,15 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
let ectx = &mut ctx.as_const(block, &mut emitter);

let explicit_ty =
c.ty.map(|ast| self.resolve_ast_type(ast, &mut ectx.as_global()))
c.ty.map(|ast| self.resolve_ast_type(ast, &mut ectx.as_const()))
.transpose()?;

let (_ty, init) =
self.type_and_init(c.name, Some(c.init), explicit_ty, ectx)?;
let (_ty, init) = self.type_and_init(
c.name,
Some(c.init),
explicit_ty,
&mut ectx.as_const(),
)?;
let init = init.expect("Local const must have init");

block.extend(emitter.finish(&ctx.function.expressions));
Expand Down Expand Up @@ -1992,7 +2000,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
}
}

lowered_base.map(|base| match ctx.const_access(index) {
lowered_base.map(|base| match ctx.const_eval_expr_to_u32(index).ok() {
Some(index) => crate::Expression::AccessIndex { base, index },
None => crate::Expression::Access { base, index },
})
Expand Down Expand Up @@ -2069,7 +2077,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
}
ast::Expression::Bitcast { expr, to, ty_span } => {
let expr = self.expression(expr, ctx)?;
let to_resolved = self.resolve_ast_type(to, &mut ctx.as_global())?;
let to_resolved = self.resolve_ast_type(to, &mut ctx.as_const())?;

let element_scalar = match ctx.module.types[to_resolved].inner {
crate::TypeInner::Scalar(scalar) => scalar,
Expand Down Expand Up @@ -3051,7 +3059,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
let mut members = Vec::with_capacity(s.members.len());

for member in s.members.iter() {
let ty = self.resolve_ast_type(member.ty, ctx)?;
let ty = self.resolve_ast_type(member.ty, &mut ctx.as_const())?;

ctx.layouter.update(ctx.module.to_ctx()).unwrap();

Expand Down Expand Up @@ -3138,25 +3146,22 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
fn array_size(
&mut self,
size: ast::ArraySize<'source>,
ctx: &mut GlobalContext<'source, '_, '_>,
ctx: &mut ExpressionContext<'source, '_, '_>,
) -> Result<crate::ArraySize, Error<'source>> {
Ok(match size {
ast::ArraySize::Constant(expr) => {
let span = ctx.ast_expressions.get_span(expr);
let const_expr = self.expression(expr, &mut ctx.as_const());
match const_expr {
Ok(value) => {
let len =
ctx.module.to_ctx().eval_expr_to_u32(value).map_err(
|err| match err {
crate::proc::U32EvalError::NonConst => {
Error::ExpectedConstExprConcreteIntegerScalar(span)
}
crate::proc::U32EvalError::Negative => {
Error::ExpectedPositiveArrayLength(span)
}
},
)?;
let len = ctx.const_eval_expr_to_u32(value).map_err(|err| match err {
crate::proc::U32EvalError::NonConst => {
Error::ExpectedConstExprConcreteIntegerScalar(span)
}
crate::proc::U32EvalError::Negative => {
Error::ExpectedPositiveArrayLength(span)
}
})?;
let size =
NonZeroU32::new(len).ok_or(Error::ExpectedPositiveArrayLength(span))?;
crate::ArraySize::Constant(size)
Expand All @@ -3167,7 +3172,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
crate::proc::ConstantEvaluatorError::OverrideExpr => {
crate::ArraySize::Pending(self.array_size_override(
expr,
&mut ctx.as_override(),
&mut ctx.as_global().as_override(),
span,
)?)
}
Expand Down Expand Up @@ -3219,7 +3224,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
&mut self,
handle: Handle<ast::Type<'source>>,
name: Option<String>,
ctx: &mut GlobalContext<'source, '_, '_>,
ctx: &mut ExpressionContext<'source, '_, '_>,
) -> Result<Handle<crate::Type>, Error<'source>> {
let inner = match ctx.types[handle] {
ast::Type::Scalar(scalar) => scalar.to_inner_scalar(),
Expand Down Expand Up @@ -3257,7 +3262,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
crate::TypeInner::Pointer { base, space }
}
ast::Type::Array { base, size } => {
let base = self.resolve_ast_type(base, ctx)?;
let base = self.resolve_ast_type(base, &mut ctx.as_const())?;
let size = self.array_size(size, ctx)?;

ctx.layouter.update(ctx.module.to_ctx()).unwrap();
Expand Down Expand Up @@ -3297,14 +3302,14 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
}
};

Ok(ctx.ensure_type_exists(name, inner))
Ok(ctx.as_global().ensure_type_exists(name, inner))
}

/// Return a Naga `Handle<Type>` representing the front-end type `handle`.
fn resolve_ast_type(
&mut self,
handle: Handle<ast::Type<'source>>,
ctx: &mut GlobalContext<'source, '_, '_>,
ctx: &mut ExpressionContext<'source, '_, '_>,
) -> Result<Handle<crate::Type>, Error<'source>> {
self.resolve_named_ast_type(handle, None, ctx)
}
Expand Down
6 changes: 6 additions & 0 deletions naga/tests/in/const-exprs.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ fn main() {
splat_of_constant();
compose_of_constant();
compose_of_splat();
test_local_const();
}

// Swizzle the value of nested Compose expressions.
Expand Down Expand Up @@ -109,3 +110,8 @@ fn relational() {
var vec_all_false = all(vec4(vec3(vec2<bool>(), TRUE), false));
var vec_all_true = all(vec4(true));
}

fn test_local_const() {
const local_const = 2;
var arr: array<f32, local_const>;
}
6 changes: 6 additions & 0 deletions naga/tests/out/glsl/const-exprs.main.Compute.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ void compose_of_splat() {
return;
}

void test_local_const() {
float arr[2] = float[2](0.0, 0.0);
return;
}

uint map_texture_kind(int texture_kind) {
switch(texture_kind) {
case 0: {
Expand Down Expand Up @@ -115,6 +120,7 @@ void main() {
splat_of_constant();
compose_of_constant();
compose_of_splat();
test_local_const();
return;
}

8 changes: 8 additions & 0 deletions naga/tests/out/hlsl/const-exprs.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,13 @@ void compose_of_splat()
return;
}

void test_local_const()
{
float arr[2] = (float[2])0;

return;
}

uint map_texture_kind(int texture_kind)
{
switch(texture_kind) {
Expand Down Expand Up @@ -128,5 +135,6 @@ void main()
splat_of_constant();
compose_of_constant();
compose_of_splat();
test_local_const();
return;
}
10 changes: 10 additions & 0 deletions naga/tests/out/msl/const-exprs.msl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@

using metal::uint;

struct type_6 {
float inner[2];
};
constant uint TWO = 2u;
constant int THREE = 3;
constant bool TRUE = true;
Expand Down Expand Up @@ -76,6 +79,12 @@ void compose_of_splat(
return;
}

void test_local_const(
) {
type_6 arr = {};
return;
}

uint map_texture_kind(
int texture_kind
) {
Expand Down Expand Up @@ -125,5 +134,6 @@ kernel void main_(
splat_of_constant();
compose_of_constant();
compose_of_splat();
test_local_const();
return;
}
Loading

0 comments on commit 2f255ed

Please sign in to comment.