From ff82611078dc3caa16669bc1e9fa7643b0a2d23a Mon Sep 17 00:00:00 2001 From: michalwa Date: Fri, 8 Jun 2018 20:25:22 +0200 Subject: [PATCH 1/2] NeuralNetworkBuilder & ActivationFunctionFactory Refactoring --- .gitignore | 5 +- pom.xml | 6 ++ .../basicneuralnetwork/NeuralNetwork.java | 10 +--- .../NeuralNetworkBuilder.java | 58 +++++++++++++++++++ .../ActivationFunction.java | 7 +-- .../ActivationFunctionFactory.java | 26 +++------ .../ReLuActivationFunction.java | 6 +- .../SigmoidActivationFunction.java | 6 +- .../TanhActivationFunction.java | 8 ++- .../utilities/FileReaderAndWriter.java | 12 ++-- .../utilities/InterfaceAdapter.java | 9 +-- .../utilities/MatrixUtilities.java | 8 +-- .../NeuralNetworkBuilderTest.java | 36 ++++++++++++ 13 files changed, 147 insertions(+), 50 deletions(-) create mode 100644 src/main/java/basicneuralnetwork/NeuralNetworkBuilder.java create mode 100644 src/test/java/basicneuralnetwork/NeuralNetworkBuilderTest.java diff --git a/.gitignore b/.gitignore index c9c8d50..835ea08 100644 --- a/.gitignore +++ b/.gitignore @@ -1,8 +1,9 @@ -.idea +.* +*.iml target out artifacts nn_data.json src/main/java/basicneuralnetwork/Test.java src/main/java/basicneuralnetwork/activationfunctions/SoftmaxActivationFunction.java -nn_data.json \ No newline at end of file +nn_data.json diff --git a/pom.xml b/pom.xml index ca6108f..d47f959 100644 --- a/pom.xml +++ b/pom.xml @@ -51,6 +51,12 @@ gson 2.8.4 + + junit + junit + 4.12 + test + diff --git a/src/main/java/basicneuralnetwork/NeuralNetwork.java b/src/main/java/basicneuralnetwork/NeuralNetwork.java index 4115b2e..8ae8865 100644 --- a/src/main/java/basicneuralnetwork/NeuralNetwork.java +++ b/src/main/java/basicneuralnetwork/NeuralNetwork.java @@ -13,8 +13,6 @@ */ public class NeuralNetwork { - private ActivationFunctionFactory activationFunctionFactory = new ActivationFunctionFactory(); - private Random random = new Random(); // Dimensions of the neural network @@ -114,7 +112,7 @@ public double[] guess(double[] input) { throw new WrongDimensionException(input.length, inputNodes, "Input"); } else { // Get ActivationFunction-object from the map by key - ActivationFunction activationFunction = activationFunctionFactory.getActivationFunctionByKey(activationFunctionKey); + ActivationFunction activationFunction = ActivationFunctionFactory.createByName(activationFunctionKey); // Transform array to matrix SimpleMatrix output = MatrixUtilities.arrayToMatrix(input); @@ -134,7 +132,7 @@ public void train(double[] inputArray, double[] targetArray) { throw new WrongDimensionException(targetArray.length, outputNodes, "Output"); } else { // Get ActivationFunction-object from the map by key - ActivationFunction activationFunction = activationFunctionFactory.getActivationFunctionByKey(activationFunctionKey); + ActivationFunction activationFunction = ActivationFunctionFactory.createByName(activationFunctionKey); // Transform 2D array to matrix SimpleMatrix input = MatrixUtilities.arrayToMatrix(inputArray); @@ -270,10 +268,6 @@ public void setActivationFunction(String activationFunction) { this.activationFunctionKey = activationFunction; } - public void addActivationFunction(String key, ActivationFunction activationFunction){ - activationFunctionFactory.addActivationFunction(key, activationFunction); - } - public double getLearningRate() { return learningRate; } diff --git a/src/main/java/basicneuralnetwork/NeuralNetworkBuilder.java b/src/main/java/basicneuralnetwork/NeuralNetworkBuilder.java new file mode 100644 index 0000000..d1f202b --- /dev/null +++ b/src/main/java/basicneuralnetwork/NeuralNetworkBuilder.java @@ -0,0 +1,58 @@ +package basicneuralnetwork; + +/** + * Created by MichalWa on 08.06.18 + */ +public class NeuralNetworkBuilder { + + private int inputNodes = 0; + private int hiddenLayers = 0; + private int hiddenNodes = 0; + private int outputNodes = 0; + private String activationFunction = null; + private double learningRate = -1.0; + + public NeuralNetworkBuilder setInputNodes(int inputNodes) { + this.inputNodes = inputNodes; + return this; + } + + public NeuralNetworkBuilder setHiddenLayers(int hiddenLayers) { + this.hiddenLayers = hiddenLayers; + return this; + } + + public NeuralNetworkBuilder setHiddenNodes(int hiddenNodes) { + this.hiddenNodes = hiddenNodes; + return this; + } + + public NeuralNetworkBuilder setOutputNodes(int outputNodes) { + this.outputNodes = outputNodes; + return this; + } + + public NeuralNetworkBuilder setActivationFunction(String activationFunction) { + this.activationFunction = activationFunction; + return this; + } + + public NeuralNetworkBuilder setLearningRate(double learningRate) { + this.learningRate = learningRate; + return this; + } + + public NeuralNetwork create() { + if(inputNodes < 1) throw new IllegalStateException("There must be 1 or more input nodes."); + if(hiddenNodes < 1) throw new IllegalStateException("There must be 1 or more hidden nodes."); + if(outputNodes < 1) throw new IllegalStateException("There must be 1 or more output nodes"); + + NeuralNetwork nn = new NeuralNetwork(inputNodes, hiddenLayers, hiddenNodes, outputNodes); + + if(activationFunction != null) nn.setActivationFunction(activationFunction); + if(learningRate != -1.0) nn.setLearningRate(learningRate); + + return nn; + } + +} diff --git a/src/main/java/basicneuralnetwork/activationfunctions/ActivationFunction.java b/src/main/java/basicneuralnetwork/activationfunctions/ActivationFunction.java index 2097a05..17ccdc0 100644 --- a/src/main/java/basicneuralnetwork/activationfunctions/ActivationFunction.java +++ b/src/main/java/basicneuralnetwork/activationfunctions/ActivationFunction.java @@ -8,9 +8,9 @@ // This interface and it's methods have to be implemented in all ActivationFunction-classes public interface ActivationFunction { - String SIGMOID = "SIGMOID"; - String TANH = "TANH"; - String RELU = "RELU"; + String SIGMOID = SigmoidActivationFunction.NAME; + String TANH = TanhActivationFunction.NAME; + String RELU = ReLuActivationFunction.NAME; // Activation function SimpleMatrix applyActivationFunctionToMatrix(SimpleMatrix input); @@ -19,5 +19,4 @@ public interface ActivationFunction { SimpleMatrix applyDerivativeOfActivationFunctionToMatrix(SimpleMatrix input); String getName(); - } diff --git a/src/main/java/basicneuralnetwork/activationfunctions/ActivationFunctionFactory.java b/src/main/java/basicneuralnetwork/activationfunctions/ActivationFunctionFactory.java index 356ff7c..d9b221c 100644 --- a/src/main/java/basicneuralnetwork/activationfunctions/ActivationFunctionFactory.java +++ b/src/main/java/basicneuralnetwork/activationfunctions/ActivationFunctionFactory.java @@ -2,31 +2,21 @@ import java.util.HashMap; import java.util.Map; +import java.util.Optional; +import java.util.function.Supplier; /** * Created by KimFeichtinger on 04.05.18. */ public class ActivationFunctionFactory { - private Map activationFunctionMap = new HashMap<>(); - - public ActivationFunctionFactory () { - // Fill map with all the activation functions - ActivationFunction sigmoid = new SigmoidActivationFunction(); - activationFunctionMap.put(sigmoid.getName(), sigmoid); - - ActivationFunction tanh = new TanhActivationFunction(); - activationFunctionMap.put(tanh.getName(), tanh); - - ActivationFunction relu = new ReLuActivationFunction(); - activationFunctionMap.put(relu.getName(), relu); - } - - public ActivationFunction getActivationFunctionByKey (String activationFunctionKey) { - return activationFunctionMap.get(activationFunctionKey); + private static Map> factories = new HashMap<>(); + + public static ActivationFunction createByName(String name) { + return Optional.ofNullable(factories.get(name)).map(Supplier::get).orElse(null); } - public void addActivationFunction(String key, ActivationFunction activationFunction) { - activationFunctionMap.put(key, activationFunction); + public static void register(String key, Supplier factory) { + factories.put(key, factory); } } diff --git a/src/main/java/basicneuralnetwork/activationfunctions/ReLuActivationFunction.java b/src/main/java/basicneuralnetwork/activationfunctions/ReLuActivationFunction.java index 5cb7709..9e2a802 100644 --- a/src/main/java/basicneuralnetwork/activationfunctions/ReLuActivationFunction.java +++ b/src/main/java/basicneuralnetwork/activationfunctions/ReLuActivationFunction.java @@ -7,7 +7,11 @@ */ public class ReLuActivationFunction implements ActivationFunction { - private static final String NAME = "RELU"; + public static final String NAME = "relu"; + + static { + ActivationFunctionFactory.register(NAME, ReLuActivationFunction::new); + } public SimpleMatrix applyActivationFunctionToMatrix(SimpleMatrix input) { SimpleMatrix output = new SimpleMatrix(input.numRows(), input.numCols()); diff --git a/src/main/java/basicneuralnetwork/activationfunctions/SigmoidActivationFunction.java b/src/main/java/basicneuralnetwork/activationfunctions/SigmoidActivationFunction.java index 839c11e..117b363 100644 --- a/src/main/java/basicneuralnetwork/activationfunctions/SigmoidActivationFunction.java +++ b/src/main/java/basicneuralnetwork/activationfunctions/SigmoidActivationFunction.java @@ -7,7 +7,11 @@ */ public class SigmoidActivationFunction implements ActivationFunction { - private static final String NAME = "SIGMOID"; + public static final String NAME = "sigmoid"; + + static { + ActivationFunctionFactory.register(NAME, SigmoidActivationFunction::new); + } // Sigmoid public SimpleMatrix applyActivationFunctionToMatrix(SimpleMatrix input) { diff --git a/src/main/java/basicneuralnetwork/activationfunctions/TanhActivationFunction.java b/src/main/java/basicneuralnetwork/activationfunctions/TanhActivationFunction.java index b4e4c4d..8a54744 100644 --- a/src/main/java/basicneuralnetwork/activationfunctions/TanhActivationFunction.java +++ b/src/main/java/basicneuralnetwork/activationfunctions/TanhActivationFunction.java @@ -6,8 +6,12 @@ * Created by KimFeichtinger on 20.04.18. */ public class TanhActivationFunction implements ActivationFunction { - - private static final String NAME = "TANH"; + + public static final String NAME = "tanh"; + + static { + ActivationFunctionFactory.register(NAME, TanhActivationFunction::new); + } public SimpleMatrix applyActivationFunctionToMatrix(SimpleMatrix input) { SimpleMatrix output = new SimpleMatrix(input.numRows(), input.numCols()); diff --git a/src/main/java/basicneuralnetwork/utilities/FileReaderAndWriter.java b/src/main/java/basicneuralnetwork/utilities/FileReaderAndWriter.java index f3093f1..0d5e941 100644 --- a/src/main/java/basicneuralnetwork/utilities/FileReaderAndWriter.java +++ b/src/main/java/basicneuralnetwork/utilities/FileReaderAndWriter.java @@ -17,11 +17,12 @@ */ public class FileReaderAndWriter { - public static void writeToFile(NeuralNetwork nn){ + private static final Gson GSON = getGsonBuilder().create(); + + public static void writeToFile(NeuralNetwork nn) { try { FileWriter file = new FileWriter("nn_data.json"); - Gson gson = getGsonBuilder().create(); - String nnData = gson.toJson(nn); + String nnData = GSON.toJson(nn); file.write(nnData); file.flush(); @@ -34,9 +35,8 @@ public static NeuralNetwork readFromFile() { NeuralNetwork nn = null; try { - Gson gson = getGsonBuilder().create(); JsonReader jsonReader = new JsonReader(new FileReader("nn_data.json")); - nn = gson.fromJson(jsonReader, NeuralNetwork.class); + nn = GSON.fromJson(jsonReader, NeuralNetwork.class); } catch (IOException e) { e.printStackTrace(); } @@ -45,7 +45,7 @@ public static NeuralNetwork readFromFile() { } // Get a GsonBuilder-object with all the needed TypeAdapters added - private static GsonBuilder getGsonBuilder(){ + private static GsonBuilder getGsonBuilder() { GsonBuilder gsonBuilder = new GsonBuilder(); gsonBuilder.registerTypeAdapter(ActivationFunction.class, new InterfaceAdapter()); diff --git a/src/main/java/basicneuralnetwork/utilities/InterfaceAdapter.java b/src/main/java/basicneuralnetwork/utilities/InterfaceAdapter.java index 9b56441..44684c0 100644 --- a/src/main/java/basicneuralnetwork/utilities/InterfaceAdapter.java +++ b/src/main/java/basicneuralnetwork/utilities/InterfaceAdapter.java @@ -10,10 +10,11 @@ import java.lang.reflect.Type; -// This class is needed to make the interfaces/ abstract classes used in this project serializable/ deserializable -// so that they can be converted to JSON or from JSON by Google Gson-library -// The solution was found here: -// https://stackoverflow.com/questions/4795349/how-to-serialize-a-class-with-an-interface/9550086#9550086 +/** This class is needed to make the interfaces/ abstract classes used in this project serializable/ deserializable + * so that they can be converted to JSON or from JSON by Google Gson-library + * + * The solution was found here: + * (link) */ public class InterfaceAdapter implements JsonSerializer, JsonDeserializer { @Override diff --git a/src/main/java/basicneuralnetwork/utilities/MatrixUtilities.java b/src/main/java/basicneuralnetwork/utilities/MatrixUtilities.java index bd71708..fb26545 100644 --- a/src/main/java/basicneuralnetwork/utilities/MatrixUtilities.java +++ b/src/main/java/basicneuralnetwork/utilities/MatrixUtilities.java @@ -10,13 +10,13 @@ */ public class MatrixUtilities { - // Converts a 2D array into a SimpleMatrix + /** Converts a 2D array into a SimpleMatrix */ public static SimpleMatrix arrayToMatrix(double[] i) { double[][] input = {i}; return new SimpleMatrix(input).transpose(); } - // Converts a SimpleMatrix into a 2D array + /** Converts a SimpleMatrix into a 2D array */ public static double[][] matrixTo2DArray(SimpleMatrix i) { double[][] result = new double[i.numRows()][i.numCols()]; @@ -28,7 +28,7 @@ public static double[][] matrixTo2DArray(SimpleMatrix i) { return result; } - // Returns one specific column of a matrix as a 1D array + /** Returns one specific column of a matrix as a 1D array */ public static double[] getColumnFromMatrixAsArray(SimpleMatrix data, int column) { double[] result = new double[data.numRows()]; @@ -39,7 +39,7 @@ public static double[] getColumnFromMatrixAsArray(SimpleMatrix data, int column) return result; } - // Merge two matrices and return a new one + /** Merge two matrices and return a new one */ public static SimpleMatrix mergeMatrices(SimpleMatrix matrixA, SimpleMatrix matrixB, double probability) { if (matrixA.numCols() != matrixB.numCols() || matrixA.numRows() != matrixB.numRows()) { throw new WrongDimensionException(); diff --git a/src/test/java/basicneuralnetwork/NeuralNetworkBuilderTest.java b/src/test/java/basicneuralnetwork/NeuralNetworkBuilderTest.java new file mode 100644 index 0000000..1a95741 --- /dev/null +++ b/src/test/java/basicneuralnetwork/NeuralNetworkBuilderTest.java @@ -0,0 +1,36 @@ +package basicneuralnetwork; + +import basicneuralnetwork.activationfunctions.SigmoidActivationFunction; +import org.junit.Test; + +import static org.junit.Assert.*; + +public class NeuralNetworkBuilderTest { + + int inputNodes = 2; + int hiddenLayers = 3; + int hiddenNodes = 4; + int outputNodes = 5; + String activationFunction = SigmoidActivationFunction.NAME; + double learningRate = 0.5; + + @Test + public void builderTest() { + NeuralNetwork nn = new NeuralNetworkBuilder() + .setInputNodes(inputNodes) + .setHiddenLayers(hiddenLayers) + .setHiddenNodes(hiddenNodes) + .setOutputNodes(outputNodes) + .setActivationFunction(activationFunction) + .setLearningRate(learningRate) + .create(); + + assertEquals(inputNodes, nn.getInputNodes()); + assertEquals(hiddenLayers, nn.getHiddenLayers()); + assertEquals(hiddenNodes, nn.getHiddenNodes()); + assertEquals(outputNodes, nn.getOutputNodes()); + assertEquals(nn.getActivationFunctionName(), activationFunction); + assertEquals(nn.getLearningRate(), learningRate, 0.0); + } + +} \ No newline at end of file From 7d1480f66bde1db3581b7aa973c26545ea09da71 Mon Sep 17 00:00:00 2001 From: michalwa Date: Fri, 8 Jun 2018 20:34:00 +0200 Subject: [PATCH 2/2] Simplified activation functions --- .../ActivationFunction.java | 40 ++++++++++++---- .../ReLuActivationFunction.java | 44 ++++-------------- .../SigmoidActivationFunction.java | 46 ++++--------------- .../TanhActivationFunction.java | 45 ++++-------------- 4 files changed, 60 insertions(+), 115 deletions(-) diff --git a/src/main/java/basicneuralnetwork/activationfunctions/ActivationFunction.java b/src/main/java/basicneuralnetwork/activationfunctions/ActivationFunction.java index 17ccdc0..95e5e76 100644 --- a/src/main/java/basicneuralnetwork/activationfunctions/ActivationFunction.java +++ b/src/main/java/basicneuralnetwork/activationfunctions/ActivationFunction.java @@ -5,18 +5,38 @@ /** * Created by KimFeichtinger on 20.04.18. */ -// This interface and it's methods have to be implemented in all ActivationFunction-classes -public interface ActivationFunction { - - String SIGMOID = SigmoidActivationFunction.NAME; - String TANH = TanhActivationFunction.NAME; - String RELU = ReLuActivationFunction.NAME; +public abstract class ActivationFunction { // Activation function - SimpleMatrix applyActivationFunctionToMatrix(SimpleMatrix input); + public SimpleMatrix applyActivationFunctionToMatrix(SimpleMatrix input) { + SimpleMatrix output = new SimpleMatrix(input.numRows(), input.numCols()); + for (int i = 0; i < input.numRows(); i++) { + for (int j = 0; j < input.numCols(); j ++) { + double value = input.get(i, j); + output.set(i, j, apply(value)); + } + } + return output; + } // Derivative of activation function (not real derivative because Activation function has already been applied to the input) - SimpleMatrix applyDerivativeOfActivationFunctionToMatrix(SimpleMatrix input); - - String getName(); + public SimpleMatrix applyDerivativeOfActivationFunctionToMatrix(SimpleMatrix input) { + SimpleMatrix output = new SimpleMatrix(input.numRows(), input.numCols()); + for (int i = 0; i < input.numRows(); i++) { + for (int j = 0; j < input.numCols(); j ++) { + double value = input.get(i, j); + output.set(i, j, applyDerivative(value)); + } + } + return output; + } + + /** Applies the function to a single value */ + protected abstract double apply(double value); + + /** Applies the pseudo-derivative of the function to a single value */ + protected abstract double applyDerivative(double value); + + /** Returns the name of the function */ + public abstract String getName(); } diff --git a/src/main/java/basicneuralnetwork/activationfunctions/ReLuActivationFunction.java b/src/main/java/basicneuralnetwork/activationfunctions/ReLuActivationFunction.java index 9e2a802..70ca25c 100644 --- a/src/main/java/basicneuralnetwork/activationfunctions/ReLuActivationFunction.java +++ b/src/main/java/basicneuralnetwork/activationfunctions/ReLuActivationFunction.java @@ -1,50 +1,26 @@ package basicneuralnetwork.activationfunctions; -import org.ejml.simple.SimpleMatrix; - /** * Created by KimFeichtinger on 26.04.18. */ -public class ReLuActivationFunction implements ActivationFunction { +public class ReLuActivationFunction extends ActivationFunction { public static final String NAME = "relu"; static { ActivationFunctionFactory.register(NAME, ReLuActivationFunction::new); } - - public SimpleMatrix applyActivationFunctionToMatrix(SimpleMatrix input) { - SimpleMatrix output = new SimpleMatrix(input.numRows(), input.numCols()); - - for (int i = 0; i < input.numRows(); i++) { - // Column is always 0 because input has only one column - double value = input.get(i, 0); - double result = value > 0 ? value : 0; - - output.set(i, 0, result); - } - - // Formula: - // for input < 0: 0, else input - return output; + + @Override + protected double apply(double value) { + return value > 0 ? value : 0; } - - public SimpleMatrix applyDerivativeOfActivationFunctionToMatrix(SimpleMatrix input) { - SimpleMatrix output = new SimpleMatrix(input.numRows(), input.numCols()); - - for (int i = 0; i < input.numRows(); i++) { - // Column is always 0 because input has only one column - double value = input.get(i, 0); - double result = value > 0 ? 1 : 0; - - output.set(i, 0, result); - } - - // Formula: - // for input > 0: 1, else 0 - return output; + + @Override + protected double applyDerivative(double value) { + return value > 0 ? 1 : 0; } - + public String getName() { return NAME; } diff --git a/src/main/java/basicneuralnetwork/activationfunctions/SigmoidActivationFunction.java b/src/main/java/basicneuralnetwork/activationfunctions/SigmoidActivationFunction.java index 117b363..6da8d92 100644 --- a/src/main/java/basicneuralnetwork/activationfunctions/SigmoidActivationFunction.java +++ b/src/main/java/basicneuralnetwork/activationfunctions/SigmoidActivationFunction.java @@ -1,52 +1,26 @@ package basicneuralnetwork.activationfunctions; -import org.ejml.simple.SimpleMatrix; - /** * Created by KimFeichtinger on 20.04.18. */ -public class SigmoidActivationFunction implements ActivationFunction { +public class SigmoidActivationFunction extends ActivationFunction { public static final String NAME = "sigmoid"; static { ActivationFunctionFactory.register(NAME, SigmoidActivationFunction::new); } - - // Sigmoid - public SimpleMatrix applyActivationFunctionToMatrix(SimpleMatrix input) { - SimpleMatrix output = new SimpleMatrix(input.numRows(), input.numCols()); - - for (int i = 0; i < input.numRows(); i++) { - // Column is always 0 because input has only one column - double value = input.get(i, 0); - double result = 1 / (1 + Math.exp(-value)); - - output.set(i, 0, result); - } - - // Formula: - // 1 / (1 + Math.exp(-input)); - return output; + + @Override + protected double apply(double value) { + return 1 / (1 + Math.exp(-value)); } - - // Derivative of Sigmoid (not real derivative because Activation function has already been applied to the input) - public SimpleMatrix applyDerivativeOfActivationFunctionToMatrix(SimpleMatrix input) { - SimpleMatrix output = new SimpleMatrix(input.numRows(), input.numCols()); - - for (int i = 0; i < input.numRows(); i++) { - // Column is always 0 because input has only one column - double value = input.get(i, 0); - double result = value * (1 - value); - - output.set(i, 0, result); - } - - // Formula: - // input * (1 - input); - return output; + + @Override + protected double applyDerivative(double value) { + return value * (1 - value); } - + public String getName() { return NAME; } diff --git a/src/main/java/basicneuralnetwork/activationfunctions/TanhActivationFunction.java b/src/main/java/basicneuralnetwork/activationfunctions/TanhActivationFunction.java index 8a54744..2d87384 100644 --- a/src/main/java/basicneuralnetwork/activationfunctions/TanhActivationFunction.java +++ b/src/main/java/basicneuralnetwork/activationfunctions/TanhActivationFunction.java @@ -1,51 +1,26 @@ package basicneuralnetwork.activationfunctions; -import org.ejml.simple.SimpleMatrix; - /** * Created by KimFeichtinger on 20.04.18. */ -public class TanhActivationFunction implements ActivationFunction { +public class TanhActivationFunction extends ActivationFunction { public static final String NAME = "tanh"; static { ActivationFunctionFactory.register(NAME, TanhActivationFunction::new); } - - public SimpleMatrix applyActivationFunctionToMatrix(SimpleMatrix input) { - SimpleMatrix output = new SimpleMatrix(input.numRows(), input.numCols()); - - for (int i = 0; i < input.numRows(); i++) { - // Column is always 0 because input has only one column - double value = input.get(i, 0); - double result = Math.tanh(value); - - output.set(i, 0, result); - } - - // Formula: - // 2 * (1 / (1 + Math.exp(2 * -input))) - 1; - // Math.tanh(input); - return output; + + @Override + protected double apply(double value) { + return Math.tanh(value); } - - public SimpleMatrix applyDerivativeOfActivationFunctionToMatrix(SimpleMatrix input) { - SimpleMatrix output = new SimpleMatrix(input.numRows(), input.numCols()); - - for (int i = 0; i < input.numRows(); i++) { - // Column is always 0 because input has only one column - double value = input.get(i, 0); - double result = 1 - (value * value); - - output.set(i, 0, result); - } - - // Formula: - // 1 - (input * input); - return output; + + @Override + protected double applyDerivative(double value) { + return 1 - (value * value); } - + public String getName() { return NAME; }