-
Notifications
You must be signed in to change notification settings - Fork 0
/
Layer.java
31 lines (23 loc) · 863 Bytes
/
Layer.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import java.util.function.Function;
/**
 * Describes one layer: its size plus an activation function and the matching derivative function,
 * each mapping a {@link JMatrix} to a {@link JMatrix}.
 *
 * <p>When no activation is supplied the layer uses the identity ({@link #none}); when no derivative
 * is supplied it uses {@link #noneDerivative}.
 */
public class Layer {
  /** Size of this layer (number of units); always at least 1. */
  public int size;

  /** Activation function for this layer; {@link #none} (identity) if none was supplied. */
  public Function<JMatrix, JMatrix> activation;

  /** Derivative function paired with {@link #activation}; {@link #noneDerivative} if none was supplied. */
  public Function<JMatrix, JMatrix> derivative;

  /**
   * Creates a layer.
   *
   * @param size number of units in the layer; must be at least 1
   * @param activation activation function, or {@code null} to use the identity ({@link #none})
   * @param derivative derivative of {@code activation}, or {@code null} to use
   *     {@link #noneDerivative}. Note that the default is chosen independently of
   *     {@code activation}: passing a real activation with a {@code null} derivative yields the
   *     all-zero default, which is probably not what a caller wants.
   * @throws IllegalArgumentException if {@code size} is less than 1
   */
  public Layer(int size, Function<JMatrix, JMatrix> activation, Function<JMatrix, JMatrix> derivative) {
    if (size < 1) {
      throw new IllegalArgumentException("Layer size must be positive, got " + size);
    }
    this.size = size;
    // Fall back to bound method references so both functions are always non-null and callable.
    this.activation = activation != null ? activation : this::none;
    this.derivative = derivative != null ? derivative : this::noneDerivative;
  }

  /**
   * Identity activation: returns its argument unchanged (the same instance, not a copy).
   *
   * @param x input matrix
   * @return {@code x} itself
   */
  public JMatrix none(JMatrix x) {
    return x;
  }

  /**
   * Default derivative paired with {@link #none}: returns {@code JMatrix.zeros(height, width)} built
   * from {@code x}'s height and width.
   *
   * <p>NOTE(review): the mathematical derivative of the identity is 1 everywhere, not 0. If the
   * network multiplies this result into a gradient, a layer using the defaults will propagate no
   * gradient. Confirm whether zeros is intentional (e.g. only ever used for an input layer) or
   * whether this should return a matrix of ones.
   *
   * @param x matrix whose dimensions determine the result's dimensions
   * @return a matrix produced by {@code JMatrix.zeros} with {@code x}'s height and width
   */
  public JMatrix noneDerivative(JMatrix x) {
    return JMatrix.zeros(x.getHeight(), x.getWidth());
  }

  /**
   * Returns the activation function currently stored in {@link #activation}.
   *
   * @return the activation function (never {@code null} unless the public field was reassigned)
   */
  public Function<JMatrix, JMatrix> getActivation() {
    return activation;
  }

  /**
   * Returns the derivative function currently stored in {@link #derivative}.
   *
   * @return the derivative function (never {@code null} unless the public field was reassigned)
   */
  public Function<JMatrix, JMatrix> getDerivative() {
    return derivative;
  }
}