Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Code Style Improvements #2

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
.idea
.*
*.iml
target
out
artifacts
nn_data.json
src/main/java/basicneuralnetwork/Test.java
src/main/java/basicneuralnetwork/activationfunctions/SoftmaxActivationFunction.java
nn_data.json
nn_data.json
6 changes: 6 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@
<artifactId>gson</artifactId>
<version>2.8.4</version>
</dependency>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>4.12</version>
<scope>test</scope>
</dependency>

</dependencies>

Expand Down
10 changes: 2 additions & 8 deletions src/main/java/basicneuralnetwork/NeuralNetwork.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
*/
public class NeuralNetwork {

private ActivationFunctionFactory activationFunctionFactory = new ActivationFunctionFactory();

private Random random = new Random();

// Dimensions of the neural network
Expand Down Expand Up @@ -114,7 +112,7 @@ public double[] guess(double[] input) {
throw new WrongDimensionException(input.length, inputNodes, "Input");
} else {
// Get ActivationFunction-object from the map by key
ActivationFunction activationFunction = activationFunctionFactory.getActivationFunctionByKey(activationFunctionKey);
ActivationFunction activationFunction = ActivationFunctionFactory.createByName(activationFunctionKey);

// Transform array to matrix
SimpleMatrix output = MatrixUtilities.arrayToMatrix(input);
Expand All @@ -134,7 +132,7 @@ public void train(double[] inputArray, double[] targetArray) {
throw new WrongDimensionException(targetArray.length, outputNodes, "Output");
} else {
// Get ActivationFunction-object from the map by key
ActivationFunction activationFunction = activationFunctionFactory.getActivationFunctionByKey(activationFunctionKey);
ActivationFunction activationFunction = ActivationFunctionFactory.createByName(activationFunctionKey);

// Transform 2D array to matrix
SimpleMatrix input = MatrixUtilities.arrayToMatrix(inputArray);
Expand Down Expand Up @@ -270,10 +268,6 @@ public void setActivationFunction(String activationFunction) {
this.activationFunctionKey = activationFunction;
}

public void addActivationFunction(String key, ActivationFunction activationFunction){
activationFunctionFactory.addActivationFunction(key, activationFunction);
}

public double getLearningRate() {
return learningRate;
}
Expand Down
58 changes: 58 additions & 0 deletions src/main/java/basicneuralnetwork/NeuralNetworkBuilder.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package basicneuralnetwork;

/**
* Created by MichalWa on 08.06.18
*/
public class NeuralNetworkBuilder {

private int inputNodes = 0;
private int hiddenLayers = 0;
private int hiddenNodes = 0;
private int outputNodes = 0;
private String activationFunction = null;
private double learningRate = -1.0;

public NeuralNetworkBuilder setInputNodes(int inputNodes) {
this.inputNodes = inputNodes;
return this;
}

public NeuralNetworkBuilder setHiddenLayers(int hiddenLayers) {
this.hiddenLayers = hiddenLayers;
return this;
}

public NeuralNetworkBuilder setHiddenNodes(int hiddenNodes) {
this.hiddenNodes = hiddenNodes;
return this;
}

public NeuralNetworkBuilder setOutputNodes(int outputNodes) {
this.outputNodes = outputNodes;
return this;
}

public NeuralNetworkBuilder setActivationFunction(String activationFunction) {
this.activationFunction = activationFunction;
return this;
}

public NeuralNetworkBuilder setLearningRate(double learningRate) {
this.learningRate = learningRate;
return this;
}

public NeuralNetwork create() {
if(inputNodes < 1) throw new IllegalStateException("There must be 1 or more input nodes.");
if(hiddenNodes < 1) throw new IllegalStateException("There must be 1 or more hidden nodes.");
if(outputNodes < 1) throw new IllegalStateException("There must be 1 or more output nodes");

NeuralNetwork nn = new NeuralNetwork(inputNodes, hiddenLayers, hiddenNodes, outputNodes);

if(activationFunction != null) nn.setActivationFunction(activationFunction);
if(learningRate != -1.0) nn.setLearningRate(learningRate);

return nn;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,38 @@
/**
* Created by KimFeichtinger on 20.04.18.
*/
// This interface and it's methods have to be implemented in all ActivationFunction-classes
public interface ActivationFunction {

String SIGMOID = "SIGMOID";
String TANH = "TANH";
String RELU = "RELU";
public abstract class ActivationFunction {

// Activation function
SimpleMatrix applyActivationFunctionToMatrix(SimpleMatrix input);
public SimpleMatrix applyActivationFunctionToMatrix(SimpleMatrix input) {
SimpleMatrix output = new SimpleMatrix(input.numRows(), input.numCols());
for (int i = 0; i < input.numRows(); i++) {
for (int j = 0; j < input.numCols(); j ++) {
double value = input.get(i, j);
output.set(i, j, apply(value));
}
}
return output;
}

// Derivative of activation function (not real derivative because Activation function has already been applied to the input)
SimpleMatrix applyDerivativeOfActivationFunctionToMatrix(SimpleMatrix input);

String getName();

public SimpleMatrix applyDerivativeOfActivationFunctionToMatrix(SimpleMatrix input) {
SimpleMatrix output = new SimpleMatrix(input.numRows(), input.numCols());
for (int i = 0; i < input.numRows(); i++) {
for (int j = 0; j < input.numCols(); j ++) {
double value = input.get(i, j);
output.set(i, j, applyDerivative(value));
}
}
return output;
}

/** Applies the function to a single value */
protected abstract double apply(double value);

/** Applies the pseudo-derivative of the function to a single value */
protected abstract double applyDerivative(double value);

/** Returns the name of the function */
public abstract String getName();
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,21 @@

import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.function.Supplier;

/**
* Created by KimFeichtinger on 04.05.18.
*/
public class ActivationFunctionFactory {

private Map<String, ActivationFunction> activationFunctionMap = new HashMap<>();

public ActivationFunctionFactory () {
// Fill map with all the activation functions
ActivationFunction sigmoid = new SigmoidActivationFunction();
activationFunctionMap.put(sigmoid.getName(), sigmoid);

ActivationFunction tanh = new TanhActivationFunction();
activationFunctionMap.put(tanh.getName(), tanh);

ActivationFunction relu = new ReLuActivationFunction();
activationFunctionMap.put(relu.getName(), relu);
}

public ActivationFunction getActivationFunctionByKey (String activationFunctionKey) {
return activationFunctionMap.get(activationFunctionKey);
private static Map<String, Supplier<ActivationFunction>> factories = new HashMap<>();

public static ActivationFunction createByName(String name) {
return Optional.ofNullable(factories.get(name)).map(Supplier::get).orElse(null);
}

public void addActivationFunction(String key, ActivationFunction activationFunction) {
activationFunctionMap.put(key, activationFunction);
public static void register(String key, Supplier<ActivationFunction> factory) {
factories.put(key, factory);
}
}
Original file line number Diff line number Diff line change
@@ -1,46 +1,26 @@
package basicneuralnetwork.activationfunctions;

import org.ejml.simple.SimpleMatrix;

/**
* Created by KimFeichtinger on 26.04.18.
*/
public class ReLuActivationFunction implements ActivationFunction {

private static final String NAME = "RELU";

public SimpleMatrix applyActivationFunctionToMatrix(SimpleMatrix input) {
SimpleMatrix output = new SimpleMatrix(input.numRows(), input.numCols());
public class ReLuActivationFunction extends ActivationFunction {

for (int i = 0; i < input.numRows(); i++) {
// Column is always 0 because input has only one column
double value = input.get(i, 0);
double result = value > 0 ? value : 0;

output.set(i, 0, result);
}

// Formula:
// for input < 0: 0, else input
return output;
public static final String NAME = "relu";

static {
ActivationFunctionFactory.register(NAME, ReLuActivationFunction::new);
}

public SimpleMatrix applyDerivativeOfActivationFunctionToMatrix(SimpleMatrix input) {
SimpleMatrix output = new SimpleMatrix(input.numRows(), input.numCols());

for (int i = 0; i < input.numRows(); i++) {
// Column is always 0 because input has only one column
double value = input.get(i, 0);
double result = value > 0 ? 1 : 0;

output.set(i, 0, result);
}

// Formula:
// for input > 0: 1, else 0
return output;

@Override
protected double apply(double value) {
return value > 0 ? value : 0;
}


@Override
protected double applyDerivative(double value) {
return value > 0 ? 1 : 0;
}

public String getName() {
return NAME;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,48 +1,26 @@
package basicneuralnetwork.activationfunctions;

import org.ejml.simple.SimpleMatrix;

/**
* Created by KimFeichtinger on 20.04.18.
*/
public class SigmoidActivationFunction implements ActivationFunction {

private static final String NAME = "SIGMOID";

// Sigmoid
public SimpleMatrix applyActivationFunctionToMatrix(SimpleMatrix input) {
SimpleMatrix output = new SimpleMatrix(input.numRows(), input.numCols());
public class SigmoidActivationFunction extends ActivationFunction {

for (int i = 0; i < input.numRows(); i++) {
// Column is always 0 because input has only one column
double value = input.get(i, 0);
double result = 1 / (1 + Math.exp(-value));

output.set(i, 0, result);
}

// Formula:
// 1 / (1 + Math.exp(-input));
return output;
public static final String NAME = "sigmoid";

static {
ActivationFunctionFactory.register(NAME, SigmoidActivationFunction::new);
}

// Derivative of Sigmoid (not real derivative because Activation function has already been applied to the input)
public SimpleMatrix applyDerivativeOfActivationFunctionToMatrix(SimpleMatrix input) {
SimpleMatrix output = new SimpleMatrix(input.numRows(), input.numCols());

for (int i = 0; i < input.numRows(); i++) {
// Column is always 0 because input has only one column
double value = input.get(i, 0);
double result = value * (1 - value);

output.set(i, 0, result);
}

// Formula:
// input * (1 - input);
return output;

@Override
protected double apply(double value) {
return 1 / (1 + Math.exp(-value));
}


@Override
protected double applyDerivative(double value) {
return value * (1 - value);
}

public String getName() {
return NAME;
}
Expand Down
Loading