diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index 24c54ff8ab..67ab887285 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -1184,6 +1184,31 @@ impl Writer { Ok(()) } + /// Emit code for the sign(i32) expression. + /// + fn put_isign( + &mut self, + arg: Handle, + context: &ExpressionContext, + ) -> BackendResult { + write!(self.out, "{NAMESPACE}::select({NAMESPACE}::select(")?; + match context.resolve_type(arg) { + &crate::TypeInner::Vector { size, .. } => { + let size = back::vector_size_str(size); + write!(self.out, "int{size}(-1), int{size}(1)")?; + } + _ => { + write!(self.out, "-1, 1")?; + } + } + write!(self.out, ", (")?; + self.put_expression(arg, context, true)?; + write!(self.out, " > 0)), 0, (")?; + self.put_expression(arg, context, true)?; + write!(self.out, " == 0))")?; + Ok(()) + } + fn put_const_expression( &mut self, expr_handle: Handle, @@ -1647,8 +1672,9 @@ impl Writer { } => { use crate::MathFunction as Mf; - let scalar_argument = match *context.resolve_type(arg) { - crate::TypeInner::Scalar { .. } => true, + let arg_type = context.resolve_type(arg); + let scalar_argument = match arg_type { + &crate::TypeInner::Scalar { .. } => true, _ => false, }; @@ -1713,7 +1739,12 @@ impl Writer { Mf::Reflect => "reflect", Mf::Refract => "refract", // computational - Mf::Sign => "sign", + Mf::Sign => match arg_type.scalar_kind() { + Some(crate::ScalarKind::Sint) => { + return self.put_isign(arg, context); + } + _ => "sign", + }, Mf::Fma => "fma", Mf::Mix => "mix", Mf::Step => "step", @@ -2423,6 +2454,16 @@ impl Writer { crate::MathFunction::FindMsb => { self.need_bake_expressions.insert(arg); } + crate::MathFunction::Sign => { + // WGSL's `sign` function works also on signed ints, but Metal's only + // works on floating points, so we emit inline code for integer `sign` + // calls. But that code uses each argument 2 times (see `put_isign`), + // so to avoid duplicated evaluation, we must bake the argument. + let inner = context.resolve_type(expr_handle); + if inner.scalar_kind() == Some(crate::ScalarKind::Sint) { + self.need_bake_expressions.insert(arg); + } + } _ => {} } } diff --git a/src/valid/expression.rs b/src/valid/expression.rs index 548b6c1451..502ce7420a 100644 --- a/src/valid/expression.rs +++ b/src/valid/expression.rs @@ -976,7 +976,6 @@ impl super::Validator { | Mf::Log | Mf::Log2 | Mf::Length - | Mf::Sign | Mf::Sqrt | Mf::InverseSqrt => { if arg1_ty.is_some() | arg2_ty.is_some() | arg3_ty.is_some() { @@ -992,6 +991,22 @@ impl super::Validator { _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), } } + Mf::Sign => { + if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() { + return Err(ExpressionError::WrongArgumentCount(fun)); + } + match *arg_ty { + Ti::Scalar { + kind: Sk::Float | Sk::Sint, + .. + } + | Ti::Vector { + kind: Sk::Float | Sk::Sint, + .. + } => {} + _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), + } + } Mf::Atan2 | Mf::Pow | Mf::Distance | Mf::Step => { let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) { (Some(ty1), None, None) => ty1, diff --git a/tests/in/math-functions.wgsl b/tests/in/math-functions.wgsl index 408f8a74f8..d08e76e4f2 100644 --- a/tests/in/math-functions.wgsl +++ b/tests/in/math-functions.wgsl @@ -8,6 +8,10 @@ fn main() { let d = radians(v); let e = saturate(v); let g = refract(v, v, f); + let sign_a = sign(-1); + let sign_b = sign(vec4(-1)); + let sign_c = sign(-1.0); + let sign_d = sign(vec4(-1.0)); let const_dot = dot(vec2(), vec2()); let first_leading_bit_abs = firstLeadingBit(abs(0u)); let flb_a = firstLeadingBit(-1); diff --git a/tests/out/glsl/math-functions.main.Fragment.glsl b/tests/out/glsl/math-functions.main.Fragment.glsl index be81715ce1..37c072a6fa 100644 --- a/tests/out/glsl/math-functions.main.Fragment.glsl +++ b/tests/out/glsl/math-functions.main.Fragment.glsl @@ -62,6 +62,10 @@ void main() { vec4 d = radians(v); vec4 e = clamp(v, vec4(0.0), vec4(1.0)); vec4 g = refract(v, v, 1.0); + int sign_a = sign(-1); + ivec4 sign_b = sign(ivec4(-1)); + float sign_c = sign(-1.0); + vec4 sign_d = sign(vec4(-1.0)); int const_dot = ( + ivec2(0).x * ivec2(0).x + ivec2(0).y * ivec2(0).y); uint first_leading_bit_abs = uint(findMSB(uint(abs(int(0u))))); int flb_a = findMSB(-1); @@ -81,8 +85,8 @@ void main() { ivec2 ctz_h = ivec2(min(uvec2(findLSB(ivec2(1))), uvec2(32u))); int clz_a = (-1 < 0 ? 0 : 31 - findMSB(-1)); uint clz_b = uint(31 - findMSB(1u)); - ivec2 _e58 = ivec2(-1); - ivec2 clz_c = mix(ivec2(31) - findMSB(_e58), ivec2(0), lessThan(_e58, ivec2(0))); + ivec2 _e68 = ivec2(-1); + ivec2 clz_c = mix(ivec2(31) - findMSB(_e68), ivec2(0), lessThan(_e68, ivec2(0))); uvec2 clz_d = uvec2(ivec2(31) - findMSB(uvec2(1u))); float lde_a = ldexp(1.0, 2); vec2 lde_b = ldexp(vec2(1.0, 2.0), ivec2(3, 4)); diff --git a/tests/out/hlsl/math-functions.hlsl b/tests/out/hlsl/math-functions.hlsl index afdc6f4671..fc5cadb65e 100644 --- a/tests/out/hlsl/math-functions.hlsl +++ b/tests/out/hlsl/math-functions.hlsl @@ -72,6 +72,10 @@ void main() float4 d = radians(v); float4 e = saturate(v); float4 g = refract(v, v, 1.0); + int sign_a = sign(-1); + int4 sign_b = sign((-1).xxxx); + float sign_c = sign(-1.0); + float4 sign_d = sign((-1.0).xxxx); int const_dot = dot((int2)0, (int2)0); uint first_leading_bit_abs = firstbithigh(abs(0u)); int flb_a = asint(firstbithigh(-1)); @@ -91,8 +95,8 @@ void main() int2 ctz_h = asint(min((32u).xx, firstbitlow((1).xx))); int clz_a = (-1 < 0 ? 0 : 31 - asint(firstbithigh(-1))); uint clz_b = (31u - firstbithigh(1u)); - int2 _expr58 = (-1).xx; - int2 clz_c = (_expr58 < (0).xx ? (0).xx : (31).xx - asint(firstbithigh(_expr58))); + int2 _expr68 = (-1).xx; + int2 clz_c = (_expr68 < (0).xx ? (0).xx : (31).xx - asint(firstbithigh(_expr68))); uint2 clz_d = ((31u).xx - firstbithigh((1u).xx)); float lde_a = ldexp(1.0, 2); float2 lde_b = ldexp(float2(1.0, 2.0), int2(3, 4)); diff --git a/tests/out/msl/math-functions.msl b/tests/out/msl/math-functions.msl index 14824ec30f..dccb90ad6c 100644 --- a/tests/out/msl/math-functions.msl +++ b/tests/out/msl/math-functions.msl @@ -64,14 +64,19 @@ fragment void main_( metal::float4 d = ((v) * 0.017453292519943295474); metal::float4 e = metal::saturate(v); metal::float4 g = metal::refract(v, v, 1.0); + int sign_a = metal::select(metal::select(-1, 1, (-1 > 0)), 0, (-1 == 0)); + metal::int4 _e12 = metal::int4(-1); + metal::int4 sign_b = metal::select(metal::select(int4(-1), int4(1), (_e12 > 0)), 0, (_e12 == 0)); + float sign_c = metal::sign(-1.0); + metal::float4 sign_d = metal::sign(metal::float4(-1.0)); int const_dot = ( + metal::int2 {}.x * metal::int2 {}.x + metal::int2 {}.y * metal::int2 {}.y); - uint _e13 = metal::abs(0u); - uint first_leading_bit_abs = metal::select(31 - metal::clz(_e13), uint(-1), _e13 == 0 || _e13 == -1); + uint _e23 = metal::abs(0u); + uint first_leading_bit_abs = metal::select(31 - metal::clz(_e23), uint(-1), _e23 == 0 || _e23 == -1); int flb_a = metal::select(31 - metal::clz(metal::select(-1, ~-1, -1 < 0)), int(-1), -1 == 0 || -1 == -1); - metal::int2 _e18 = metal::int2(-1); - metal::int2 flb_b = metal::select(31 - metal::clz(metal::select(_e18, ~_e18, _e18 < 0)), int2(-1), _e18 == 0 || _e18 == -1); - metal::uint2 _e21 = metal::uint2(1u); - metal::uint2 flb_c = metal::select(31 - metal::clz(_e21), uint2(-1), _e21 == 0 || _e21 == -1); + metal::int2 _e28 = metal::int2(-1); + metal::int2 flb_b = metal::select(31 - metal::clz(metal::select(_e28, ~_e28, _e28 < 0)), int2(-1), _e28 == 0 || _e28 == -1); + metal::uint2 _e31 = metal::uint2(1u); + metal::uint2 flb_c = metal::select(31 - metal::clz(_e31), uint2(-1), _e31 == 0 || _e31 == -1); int ftb_a = (((metal::ctz(-1) + 1) % 33) - 1); uint ftb_b = (((metal::ctz(1u) + 1) % 33) - 1); metal::int2 ftb_c = (((metal::ctz(metal::int2(-1)) + 1) % 33) - 1); diff --git a/tests/out/spv/math-functions.spvasm b/tests/out/spv/math-functions.spvasm index 260c3b4bd4..7edb0e26b4 100644 --- a/tests/out/spv/math-functions.spvasm +++ b/tests/out/spv/math-functions.spvasm @@ -1,7 +1,7 @@ ; SPIR-V ; Version: 1.1 ; Generator: rspirv -; Bound: 127 +; Bound: 134 OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 @@ -32,116 +32,123 @@ OpMemberDecorate %13 1 Offset 16 %16 = OpTypeFunction %2 %17 = OpConstant %4 1.0 %18 = OpConstant %4 0.0 -%19 = OpConstantNull %5 -%20 = OpTypeInt 32 0 -%21 = OpConstant %20 0 -%22 = OpConstant %6 -1 -%23 = OpConstant %20 1 -%24 = OpConstant %6 0 -%25 = OpConstant %20 4294967295 -%26 = OpConstant %6 1 -%27 = OpConstant %6 2 -%28 = OpConstant %4 2.0 -%29 = OpConstant %6 3 -%30 = OpConstant %6 4 -%31 = OpConstant %4 1.5 -%39 = OpConstantComposite %3 %18 %18 %18 %18 -%40 = OpConstantComposite %3 %17 %17 %17 %17 -%43 = OpConstantNull %6 -%56 = OpTypeVector %20 2 -%66 = OpConstant %20 32 -%76 = OpConstantComposite %56 %66 %66 -%88 = OpConstant %6 31 -%94 = OpConstantComposite %5 %88 %88 +%19 = OpConstant %6 -1 +%20 = OpConstant %4 -1.0 +%21 = OpConstantNull %5 +%22 = OpTypeInt 32 0 +%23 = OpConstant %22 0 +%24 = OpConstant %22 1 +%25 = OpConstant %6 0 +%26 = OpConstant %22 4294967295 +%27 = OpConstant %6 1 +%28 = OpConstant %6 2 +%29 = OpConstant %4 2.0 +%30 = OpConstant %6 3 +%31 = OpConstant %6 4 +%32 = OpConstant %4 1.5 +%40 = OpConstantComposite %3 %18 %18 %18 %18 +%41 = OpConstantComposite %3 %17 %17 %17 %17 +%50 = OpConstantNull %6 +%63 = OpTypeVector %22 2 +%73 = OpConstant %22 32 +%83 = OpConstantComposite %63 %73 %73 +%95 = OpConstant %6 31 +%101 = OpConstantComposite %5 %95 %95 %15 = OpFunction %2 None %16 %14 = OpLabel -OpBranch %32 -%32 = OpLabel -%33 = OpCompositeConstruct %3 %18 %18 %18 %18 -%34 = OpExtInst %4 %1 Degrees %17 -%35 = OpExtInst %4 %1 Radians %17 -%36 = OpExtInst %3 %1 Degrees %33 -%37 = OpExtInst %3 %1 Radians %33 -%38 = OpExtInst %3 %1 FClamp %33 %39 %40 -%41 = OpExtInst %3 %1 Refract %33 %33 %17 -%44 = OpCompositeExtract %6 %19 0 -%45 = OpCompositeExtract %6 %19 0 -%46 = OpIMul %6 %44 %45 -%47 = OpIAdd %6 %43 %46 -%48 = OpCompositeExtract %6 %19 1 -%49 = OpCompositeExtract %6 %19 1 -%50 = OpIMul %6 %48 %49 -%42 = OpIAdd %6 %47 %50 -%51 = OpCopyObject %20 %21 -%52 = OpExtInst %20 %1 FindUMsb %51 -%53 = OpExtInst %6 %1 FindSMsb %22 -%54 = OpCompositeConstruct %5 %22 %22 -%55 = OpExtInst %5 %1 FindSMsb %54 -%57 = OpCompositeConstruct %56 %23 %23 -%58 = OpExtInst %56 %1 FindUMsb %57 -%59 = OpExtInst %6 %1 FindILsb %22 -%60 = OpExtInst %20 %1 FindILsb %23 -%61 = OpCompositeConstruct %5 %22 %22 -%62 = OpExtInst %5 %1 FindILsb %61 -%63 = OpCompositeConstruct %56 %23 %23 -%64 = OpExtInst %56 %1 FindILsb %63 -%67 = OpExtInst %20 %1 FindILsb %21 -%65 = OpExtInst %20 %1 UMin %66 %67 -%69 = OpExtInst %6 %1 FindILsb %24 -%68 = OpExtInst %6 %1 UMin %66 %69 -%71 = OpExtInst %20 %1 FindILsb %25 -%70 = OpExtInst %20 %1 UMin %66 %71 -%73 = OpExtInst %6 %1 FindILsb %22 -%72 = OpExtInst %6 %1 UMin %66 %73 -%74 = OpCompositeConstruct %56 %21 %21 -%77 = OpExtInst %56 %1 FindILsb %74 -%75 = OpExtInst %56 %1 UMin %76 %77 -%78 = OpCompositeConstruct %5 %24 %24 -%80 = OpExtInst %5 %1 FindILsb %78 -%79 = OpExtInst %5 %1 UMin %76 %80 -%81 = OpCompositeConstruct %56 %23 %23 -%83 = OpExtInst %56 %1 FindILsb %81 -%82 = OpExtInst %56 %1 UMin %76 %83 -%84 = OpCompositeConstruct %5 %26 %26 -%86 = OpExtInst %5 %1 FindILsb %84 -%85 = OpExtInst %5 %1 UMin %76 %86 -%89 = OpExtInst %6 %1 FindUMsb %22 -%87 = OpISub %6 %88 %89 -%91 = OpExtInst %6 %1 FindUMsb %23 -%90 = OpISub %20 %88 %91 -%92 = OpCompositeConstruct %5 %22 %22 -%95 = OpExtInst %5 %1 FindUMsb %92 -%93 = OpISub %5 %94 %95 -%96 = OpCompositeConstruct %56 %23 %23 -%98 = OpExtInst %5 %1 FindUMsb %96 -%97 = OpISub %56 %94 %98 -%99 = OpExtInst %4 %1 Ldexp %17 %27 -%100 = OpCompositeConstruct %7 %17 %28 -%101 = OpCompositeConstruct %5 %29 %30 -%102 = OpExtInst %7 %1 Ldexp %100 %101 -%103 = OpExtInst %8 %1 ModfStruct %31 -%104 = OpExtInst %8 %1 ModfStruct %31 -%105 = OpCompositeExtract %4 %104 0 -%106 = OpExtInst %8 %1 ModfStruct %31 -%107 = OpCompositeExtract %4 %106 1 -%108 = OpCompositeConstruct %7 %31 %31 -%109 = OpExtInst %9 %1 ModfStruct %108 -%110 = OpCompositeConstruct %3 %31 %31 %31 %31 -%111 = OpExtInst %10 %1 ModfStruct %110 -%112 = OpCompositeExtract %3 %111 1 -%113 = OpCompositeExtract %4 %112 0 -%114 = OpCompositeConstruct %7 %31 %31 -%115 = OpExtInst %9 %1 ModfStruct %114 -%116 = OpCompositeExtract %7 %115 0 -%117 = OpCompositeExtract %4 %116 1 -%118 = OpExtInst %11 %1 FrexpStruct %31 -%119 = OpExtInst %11 %1 FrexpStruct %31 +OpBranch %33 +%33 = OpLabel +%34 = OpCompositeConstruct %3 %18 %18 %18 %18 +%35 = OpExtInst %4 %1 Degrees %17 +%36 = OpExtInst %4 %1 Radians %17 +%37 = OpExtInst %3 %1 Degrees %34 +%38 = OpExtInst %3 %1 Radians %34 +%39 = OpExtInst %3 %1 FClamp %34 %40 %41 +%42 = OpExtInst %3 %1 Refract %34 %34 %17 +%43 = OpExtInst %6 %1 SSign %19 +%44 = OpCompositeConstruct %12 %19 %19 %19 %19 +%45 = OpExtInst %12 %1 SSign %44 +%46 = OpExtInst %4 %1 FSign %20 +%47 = OpCompositeConstruct %3 %20 %20 %20 %20 +%48 = OpExtInst %3 %1 FSign %47 +%51 = OpCompositeExtract %6 %21 0 +%52 = OpCompositeExtract %6 %21 0 +%53 = OpIMul %6 %51 %52 +%54 = OpIAdd %6 %50 %53 +%55 = OpCompositeExtract %6 %21 1 +%56 = OpCompositeExtract %6 %21 1 +%57 = OpIMul %6 %55 %56 +%49 = OpIAdd %6 %54 %57 +%58 = OpCopyObject %22 %23 +%59 = OpExtInst %22 %1 FindUMsb %58 +%60 = OpExtInst %6 %1 FindSMsb %19 +%61 = OpCompositeConstruct %5 %19 %19 +%62 = OpExtInst %5 %1 FindSMsb %61 +%64 = OpCompositeConstruct %63 %24 %24 +%65 = OpExtInst %63 %1 FindUMsb %64 +%66 = OpExtInst %6 %1 FindILsb %19 +%67 = OpExtInst %22 %1 FindILsb %24 +%68 = OpCompositeConstruct %5 %19 %19 +%69 = OpExtInst %5 %1 FindILsb %68 +%70 = OpCompositeConstruct %63 %24 %24 +%71 = OpExtInst %63 %1 FindILsb %70 +%74 = OpExtInst %22 %1 FindILsb %23 +%72 = OpExtInst %22 %1 UMin %73 %74 +%76 = OpExtInst %6 %1 FindILsb %25 +%75 = OpExtInst %6 %1 UMin %73 %76 +%78 = OpExtInst %22 %1 FindILsb %26 +%77 = OpExtInst %22 %1 UMin %73 %78 +%80 = OpExtInst %6 %1 FindILsb %19 +%79 = OpExtInst %6 %1 UMin %73 %80 +%81 = OpCompositeConstruct %63 %23 %23 +%84 = OpExtInst %63 %1 FindILsb %81 +%82 = OpExtInst %63 %1 UMin %83 %84 +%85 = OpCompositeConstruct %5 %25 %25 +%87 = OpExtInst %5 %1 FindILsb %85 +%86 = OpExtInst %5 %1 UMin %83 %87 +%88 = OpCompositeConstruct %63 %24 %24 +%90 = OpExtInst %63 %1 FindILsb %88 +%89 = OpExtInst %63 %1 UMin %83 %90 +%91 = OpCompositeConstruct %5 %27 %27 +%93 = OpExtInst %5 %1 FindILsb %91 +%92 = OpExtInst %5 %1 UMin %83 %93 +%96 = OpExtInst %6 %1 FindUMsb %19 +%94 = OpISub %6 %95 %96 +%98 = OpExtInst %6 %1 FindUMsb %24 +%97 = OpISub %22 %95 %98 +%99 = OpCompositeConstruct %5 %19 %19 +%102 = OpExtInst %5 %1 FindUMsb %99 +%100 = OpISub %5 %101 %102 +%103 = OpCompositeConstruct %63 %24 %24 +%105 = OpExtInst %5 %1 FindUMsb %103 +%104 = OpISub %63 %101 %105 +%106 = OpExtInst %4 %1 Ldexp %17 %28 +%107 = OpCompositeConstruct %7 %17 %29 +%108 = OpCompositeConstruct %5 %30 %31 +%109 = OpExtInst %7 %1 Ldexp %107 %108 +%110 = OpExtInst %8 %1 ModfStruct %32 +%111 = OpExtInst %8 %1 ModfStruct %32 +%112 = OpCompositeExtract %4 %111 0 +%113 = OpExtInst %8 %1 ModfStruct %32 +%114 = OpCompositeExtract %4 %113 1 +%115 = OpCompositeConstruct %7 %32 %32 +%116 = OpExtInst %9 %1 ModfStruct %115 +%117 = OpCompositeConstruct %3 %32 %32 %32 %32 +%118 = OpExtInst %10 %1 ModfStruct %117 +%119 = OpCompositeExtract %3 %118 1 %120 = OpCompositeExtract %4 %119 0 -%121 = OpExtInst %11 %1 FrexpStruct %31 -%122 = OpCompositeExtract %6 %121 1 -%123 = OpCompositeConstruct %3 %31 %31 %31 %31 -%124 = OpExtInst %13 %1 FrexpStruct %123 -%125 = OpCompositeExtract %12 %124 1 -%126 = OpCompositeExtract %6 %125 0 +%121 = OpCompositeConstruct %7 %32 %32 +%122 = OpExtInst %9 %1 ModfStruct %121 +%123 = OpCompositeExtract %7 %122 0 +%124 = OpCompositeExtract %4 %123 1 +%125 = OpExtInst %11 %1 FrexpStruct %32 +%126 = OpExtInst %11 %1 FrexpStruct %32 +%127 = OpCompositeExtract %4 %126 0 +%128 = OpExtInst %11 %1 FrexpStruct %32 +%129 = OpCompositeExtract %6 %128 1 +%130 = OpCompositeConstruct %3 %32 %32 %32 %32 +%131 = OpExtInst %13 %1 FrexpStruct %130 +%132 = OpCompositeExtract %12 %131 1 +%133 = OpCompositeExtract %6 %132 0 OpReturn OpFunctionEnd \ No newline at end of file diff --git a/tests/out/wgsl/math-functions.wgsl b/tests/out/wgsl/math-functions.wgsl index 149ebff8e0..0b20291f57 100644 --- a/tests/out/wgsl/math-functions.wgsl +++ b/tests/out/wgsl/math-functions.wgsl @@ -7,6 +7,10 @@ fn main() { let d = radians(v); let e = saturate(v); let g = refract(v, v, 1.0); + let sign_a = sign(-1); + let sign_b = sign(vec4(-1)); + let sign_c = sign(-1.0); + let sign_d = sign(vec4(-1.0)); let const_dot = dot(vec2(), vec2()); let first_leading_bit_abs = firstLeadingBit(abs(0u)); let flb_a = firstLeadingBit(-1);