Skip to content

Commit

Permalink
Changed implementation of ReLU to the faster equivalent functionality…
Browse files Browse the repository at this point in the history
… from BitwiseReLU, and deleted BitwiseReLU. The only value where the two implementations diverge is double.NaN, and that value is avoided throughout the entire library.
  • Loading branch information
colgreen committed Mar 31, 2023
1 parent 3ec2787 commit 3424bba
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 87 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ public class ActivationFunctionsBenchmarks
static readonly IActivationFunction<double> __PolynomialApproximantSteep = new PolynomialApproximantSteep();
static readonly IActivationFunction<double> __QuadraticSigmoid = new QuadraticSigmoid();
static readonly IActivationFunction<double> __ReLU = new ReLU();
static readonly IActivationFunction<double> __BitwiseReLU = new BitwiseReLU();
static readonly IActivationFunction<double> __ScaledELU = new ScaledELU();
static readonly IActivationFunction<double> __SoftSignSteep = new SoftSignSteep();
static readonly IActivationFunction<double> __SReLU = new SReLU();
Expand Down Expand Up @@ -126,12 +125,6 @@ public void ReLU()
RunBenchmark(__ReLU);
}

[Benchmark]
public void BitwiseReLU()
{
RunBenchmark(__BitwiseReLU);
}

[Benchmark]
public void ScaledELU()
{
Expand Down
73 changes: 0 additions & 73 deletions src/SharpNeat/NeuralNets/Double/ActivationFunctions/BitwiseReLU.cs

This file was deleted.

26 changes: 20 additions & 6 deletions src/SharpNeat/NeuralNets/Double/ActivationFunctions/ReLU.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,31 @@ public sealed class ReLU : IActivationFunction<double>
/// <inheritdoc/>
public void Fn(ref double x)
{
    // Branch-free equivalent of:
    //
    //     x = x < 0.0 ? 0.0 : x;
    //
    // Avoiding the conditional branch gives faster performance than the naive form.

    // Reinterpret the double's raw bits as a signed 64 bit integer; the sign bit occupies
    // the high bit in both representations.
    long bits = BitConverter.DoubleToInt64Bits(x);

    // Arithmetic right shift by 63 smears the sign bit across the whole word:
    // all ones for a negative x, all zeros otherwise. Complementing that yields a mask
    // of all zeros (negative x) or all ones (non-negative x).
    long mask = ~(bits >> 63);

    // ANDing with the mask zeroes the value when x is negative and leaves it untouched
    // otherwise — i.e. the ReLU function, with no branch.
    x = BitConverter.Int64BitsToDouble(bits & mask);
}

/// <inheritdoc/>
public void Fn(ref double x, ref double y)
{
    // Branch-free ReLU: reinterpret x's bits as a signed long, build a mask that is all
    // zeros when x is negative (sign bit smeared by the arithmetic shift, then complemented)
    // and all ones otherwise, and AND it with the raw bits. x itself is left unmodified;
    // the result is written to y.
    long bits = Unsafe.As<double,long>(ref x);
    long mask = ~(bits >> 63);
    y = BitConverter.Int64BitsToDouble(bits & mask);
}

/// <inheritdoc/>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ public void TestMonotonicity()
AssertMonotonic(new PolynomialApproximantSteep(), true);
AssertMonotonic(new QuadraticSigmoid(), false);
AssertMonotonic(new ReLU(), false);
AssertMonotonic(new BitwiseReLU(), false);
AssertMonotonic(new ScaledELU(), true);
AssertMonotonic(new SoftSignSteep(), true);
AssertMonotonic(new SReLU(), true);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
using FluentAssertions;
using Xunit;

namespace SharpNeat.NeuralNets.Double.ActivationFunctions;

#pragma warning disable xUnit1025 // InlineData should be unique within the Theory it belongs to

/// <summary>
/// Unit tests for the <see cref="ReLU"/> activation function.
/// </summary>
public class ReLUTests
{
    // Inputs cover zero (both signed zeros), small and large magnitudes, the extreme
    // representable doubles, and the infinities. NaN is deliberately excluded: the
    // branch-free bitwise implementation diverges from the conditional form for NaN,
    // and NaN is avoided throughout the library.
    [Theory]
    [InlineData(0.0)]
    [InlineData(-0.0)]
    [InlineData(-0.000001)]
    [InlineData(+0.000001)]
    [InlineData(-0.1)]
    [InlineData(0.1)]
    [InlineData(-1.1)]
    [InlineData(1.1)]
    [InlineData(-1_000_000.0)]
    [InlineData(1_000_000.0)]
    [InlineData(double.Epsilon)]
    [InlineData(-double.Epsilon)]
    [InlineData(double.MinValue)]
    [InlineData(double.MaxValue)]
    [InlineData(double.PositiveInfinity)]
    [InlineData(double.NegativeInfinity)]
    public void Fn_VariousInputs_MatchesConditionalReLU(double x)
    {
        // Arrange.
        var relu = new ReLU();

        // Act.
        double actual = x;
        relu.Fn(ref actual);

        // Assert. The bitwise implementation must agree with the straightforward
        // conditional definition of ReLU for all non-NaN inputs.
        double expected = x < 0.0 ? 0.0 : x;
        actual.Should().Be(expected);
    }

    [Theory]
    [InlineData(-1.5)]
    [InlineData(0.0)]
    [InlineData(2.25)]
    [InlineData(double.NegativeInfinity)]
    public void Fn_TwoRefOverload_WritesResultToYAndLeavesXUnchanged(double x)
    {
        // Arrange.
        var relu = new ReLU();
        double input = x;
        double y = double.MaxValue; // sentinel; must be overwritten.

        // Act.
        relu.Fn(ref input, ref y);

        // Assert.
        double expected = x < 0.0 ? 0.0 : x;
        y.Should().Be(expected);
        input.Should().Be(x);
    }
}

0 comments on commit 3424bba

Please sign in to comment.