Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wgsl in fix sign #4

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 44 additions & 3 deletions src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1184,6 +1184,31 @@ impl<W: Write> Writer<W> {
Ok(())
}

/// Emit code for the sign(i32) expression.
///
fn put_isign(
&mut self,
arg: Handle<crate::Expression>,
context: &ExpressionContext,
) -> BackendResult {
write!(self.out, "{NAMESPACE}::select({NAMESPACE}::select(")?;
match context.resolve_type(arg) {
&crate::TypeInner::Vector { size, .. } => {
let size = back::vector_size_str(size);
write!(self.out, "int{size}(-1), int{size}(1)")?;
}
_ => {
write!(self.out, "-1, 1")?;
}
}
write!(self.out, ", (")?;
self.put_expression(arg, context, true)?;
write!(self.out, " > 0)), 0, (")?;
self.put_expression(arg, context, true)?;
write!(self.out, " == 0))")?;
Ok(())
}

fn put_const_expression(
&mut self,
expr_handle: Handle<crate::Expression>,
Expand Down Expand Up @@ -1647,8 +1672,9 @@ impl<W: Write> Writer<W> {
} => {
use crate::MathFunction as Mf;

let scalar_argument = match *context.resolve_type(arg) {
crate::TypeInner::Scalar { .. } => true,
let arg_type = context.resolve_type(arg);
let scalar_argument = match arg_type {
&crate::TypeInner::Scalar { .. } => true,
_ => false,
};

Expand Down Expand Up @@ -1713,7 +1739,12 @@ impl<W: Write> Writer<W> {
Mf::Reflect => "reflect",
Mf::Refract => "refract",
// computational
Mf::Sign => "sign",
Mf::Sign => match arg_type.scalar_kind() {
Some(crate::ScalarKind::Sint) => {
return self.put_isign(arg, context);
}
_ => "sign",
},
Mf::Fma => "fma",
Mf::Mix => "mix",
Mf::Step => "step",
Expand Down Expand Up @@ -2423,6 +2454,16 @@ impl<W: Write> Writer<W> {
crate::MathFunction::FindMsb => {
self.need_bake_expressions.insert(arg);
}
crate::MathFunction::Sign => {
// WGSL's `sign` function works also on signed ints, but Metal's only
// works on floating points, so we emit inline code for integer `sign`
// calls. But that code uses each argument 2 times (see `put_isign`),
// so to avoid duplicated evaluation, we must bake the argument.
let inner = context.resolve_type(expr_handle);
if inner.scalar_kind() == Some(crate::ScalarKind::Sint) {
self.need_bake_expressions.insert(arg);
}
}
_ => {}
}
}
Expand Down
17 changes: 16 additions & 1 deletion src/valid/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -976,7 +976,6 @@ impl super::Validator {
| Mf::Log
| Mf::Log2
| Mf::Length
| Mf::Sign
| Mf::Sqrt
| Mf::InverseSqrt => {
if arg1_ty.is_some() | arg2_ty.is_some() | arg3_ty.is_some() {
Expand All @@ -992,6 +991,22 @@ impl super::Validator {
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
}
}
Mf::Sign => {
if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
return Err(ExpressionError::WrongArgumentCount(fun));
}
match *arg_ty {
Ti::Scalar {
kind: Sk::Float | Sk::Sint,
..
}
| Ti::Vector {
kind: Sk::Float | Sk::Sint,
..
} => {}
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
}
}
Mf::Atan2 | Mf::Pow | Mf::Distance | Mf::Step => {
let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
(Some(ty1), None, None) => ty1,
Expand Down
4 changes: 4 additions & 0 deletions tests/in/math-functions.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ fn main() {
let d = radians(v);
let e = saturate(v);
let g = refract(v, v, f);
let sign_a = sign(-1);
let sign_b = sign(vec4(-1));
let sign_c = sign(-1.0);
let sign_d = sign(vec4(-1.0));
let const_dot = dot(vec2<i32>(), vec2<i32>());
let first_leading_bit_abs = firstLeadingBit(abs(0u));
let flb_a = firstLeadingBit(-1);
Expand Down
8 changes: 6 additions & 2 deletions tests/out/glsl/math-functions.main.Fragment.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ void main() {
vec4 d = radians(v);
vec4 e = clamp(v, vec4(0.0), vec4(1.0));
vec4 g = refract(v, v, 1.0);
int sign_a = sign(-1);
ivec4 sign_b = sign(ivec4(-1));
float sign_c = sign(-1.0);
vec4 sign_d = sign(vec4(-1.0));
int const_dot = ( + ivec2(0).x * ivec2(0).x + ivec2(0).y * ivec2(0).y);
uint first_leading_bit_abs = uint(findMSB(uint(abs(int(0u)))));
int flb_a = findMSB(-1);
Expand All @@ -81,8 +85,8 @@ void main() {
ivec2 ctz_h = ivec2(min(uvec2(findLSB(ivec2(1))), uvec2(32u)));
int clz_a = (-1 < 0 ? 0 : 31 - findMSB(-1));
uint clz_b = uint(31 - findMSB(1u));
ivec2 _e58 = ivec2(-1);
ivec2 clz_c = mix(ivec2(31) - findMSB(_e58), ivec2(0), lessThan(_e58, ivec2(0)));
ivec2 _e68 = ivec2(-1);
ivec2 clz_c = mix(ivec2(31) - findMSB(_e68), ivec2(0), lessThan(_e68, ivec2(0)));
uvec2 clz_d = uvec2(ivec2(31) - findMSB(uvec2(1u)));
float lde_a = ldexp(1.0, 2);
vec2 lde_b = ldexp(vec2(1.0, 2.0), ivec2(3, 4));
Expand Down
8 changes: 6 additions & 2 deletions tests/out/hlsl/math-functions.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ void main()
float4 d = radians(v);
float4 e = saturate(v);
float4 g = refract(v, v, 1.0);
int sign_a = sign(-1);
int4 sign_b = sign((-1).xxxx);
float sign_c = sign(-1.0);
float4 sign_d = sign((-1.0).xxxx);
int const_dot = dot((int2)0, (int2)0);
uint first_leading_bit_abs = firstbithigh(abs(0u));
int flb_a = asint(firstbithigh(-1));
Expand All @@ -91,8 +95,8 @@ void main()
int2 ctz_h = asint(min((32u).xx, firstbitlow((1).xx)));
int clz_a = (-1 < 0 ? 0 : 31 - asint(firstbithigh(-1)));
uint clz_b = (31u - firstbithigh(1u));
int2 _expr58 = (-1).xx;
int2 clz_c = (_expr58 < (0).xx ? (0).xx : (31).xx - asint(firstbithigh(_expr58)));
int2 _expr68 = (-1).xx;
int2 clz_c = (_expr68 < (0).xx ? (0).xx : (31).xx - asint(firstbithigh(_expr68)));
uint2 clz_d = ((31u).xx - firstbithigh((1u).xx));
float lde_a = ldexp(1.0, 2);
float2 lde_b = ldexp(float2(1.0, 2.0), int2(3, 4));
Expand Down
17 changes: 11 additions & 6 deletions tests/out/msl/math-functions.msl
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,19 @@ fragment void main_(
metal::float4 d = ((v) * 0.017453292519943295474);
metal::float4 e = metal::saturate(v);
metal::float4 g = metal::refract(v, v, 1.0);
int sign_a = metal::select(metal::select(-1, 1, (-1 > 0)), 0, (-1 == 0));
metal::int4 _e12 = metal::int4(-1);
metal::int4 sign_b = metal::select(metal::select(int4(-1), int4(1), (_e12 > 0)), 0, (_e12 == 0));
float sign_c = metal::sign(-1.0);
metal::float4 sign_d = metal::sign(metal::float4(-1.0));
int const_dot = ( + metal::int2 {}.x * metal::int2 {}.x + metal::int2 {}.y * metal::int2 {}.y);
uint _e13 = metal::abs(0u);
uint first_leading_bit_abs = metal::select(31 - metal::clz(_e13), uint(-1), _e13 == 0 || _e13 == -1);
uint _e23 = metal::abs(0u);
uint first_leading_bit_abs = metal::select(31 - metal::clz(_e23), uint(-1), _e23 == 0 || _e23 == -1);
int flb_a = metal::select(31 - metal::clz(metal::select(-1, ~-1, -1 < 0)), int(-1), -1 == 0 || -1 == -1);
metal::int2 _e18 = metal::int2(-1);
metal::int2 flb_b = metal::select(31 - metal::clz(metal::select(_e18, ~_e18, _e18 < 0)), int2(-1), _e18 == 0 || _e18 == -1);
metal::uint2 _e21 = metal::uint2(1u);
metal::uint2 flb_c = metal::select(31 - metal::clz(_e21), uint2(-1), _e21 == 0 || _e21 == -1);
metal::int2 _e28 = metal::int2(-1);
metal::int2 flb_b = metal::select(31 - metal::clz(metal::select(_e28, ~_e28, _e28 < 0)), int2(-1), _e28 == 0 || _e28 == -1);
metal::uint2 _e31 = metal::uint2(1u);
metal::uint2 flb_c = metal::select(31 - metal::clz(_e31), uint2(-1), _e31 == 0 || _e31 == -1);
int ftb_a = (((metal::ctz(-1) + 1) % 33) - 1);
uint ftb_b = (((metal::ctz(1u) + 1) % 33) - 1);
metal::int2 ftb_c = (((metal::ctz(metal::int2(-1)) + 1) % 33) - 1);
Expand Down
Loading
Loading