mirror of
https://github.com/jneug/zeichenmaschine.git
synced 2026-04-14 06:33:34 +02:00
Erste Implementation eines einfachen neuronalen Netzes
Vorbild zur Implementation: https://github.com/wheresvic/neuralnet
This commit is contained in:
152
src/schule/ngb/zm/ml/MLMath.java
Normal file
152
src/schule/ngb/zm/ml/MLMath.java
Normal file
@@ -0,0 +1,152 @@
|
||||
package schule.ngb.zm.ml;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.function.DoubleUnaryOperator;
|
||||
import java.util.stream.IntStream;
|
||||
|
||||
// See https://github.com/wheresvic/neuralnet
|
||||
/**
 * Static math helpers for the neural-network implementation: activation
 * functions, their derivatives, and basic matrix/vector operations on raw
 * {@code double[][]} arrays.
 *
 * <p>Convention: the derivative functions take the <em>output</em> of the
 * activation function, not the raw input — {@code NeuronLayer.backprop}
 * applies them to the cached layer output ({@code lastOutput}).</p>
 */
public final class MLMath {

	/**
	 * Logistic sigmoid activation: 1 / (1 + e^-x).
	 *
	 * @param x Raw neuron input.
	 * @return Activation value in the open interval (0, 1).
	 */
	public static double sigmoid( double x ) {
		return 1 / (1 + Math.exp(-x));
	}

	/**
	 * Derivative of the sigmoid expressed in terms of the sigmoid
	 * <em>output</em>: for y = sigmoid(in) the slope is y * (1 - y).
	 *
	 * @param x Output of {@link #sigmoid(double)}.
	 * @return Slope of the sigmoid at the point that produced {@code x}.
	 */
	public static double sigmoidDerivative( double x ) {
		return x * (1 - x);
	}

	/**
	 * Hyperbolic tangent activation.
	 *
	 * @param x Raw neuron input.
	 * @return Activation value in the open interval (-1, 1).
	 */
	public static double tanh( double x ) {
		return Math.tanh(x);
	}

	/**
	 * Derivative of tanh expressed in terms of the tanh <em>output</em>:
	 * for y = tanh(in) the slope is 1 - y².
	 *
	 * <p>Fix: previously computed {@code 1 - tanh(x) * tanh(x)}, which treats
	 * the argument as the raw input. That is inconsistent with
	 * {@link #sigmoidDerivative(double)} and wrong for backpropagation, where
	 * the derivative is applied to the already-activated layer output.</p>
	 *
	 * @param x Output of {@link #tanh(double)}.
	 * @return Slope of tanh at the point that produced {@code x}.
	 */
	public static double tanhDerivative( double x ) {
		return 1 - x * x;
	}

	/**
	 * Normalizes a vector so that its components sum to 1.
	 *
	 * @param vector Vector to normalize (not modified).
	 * @return New array with each component divided by the total sum; an
	 *     unscaled copy if the sum is 0 (avoids division by zero).
	 */
	public static double[] normalize( double[] vector ) {
		final double sum = Arrays.stream(vector).sum();
		if( sum == 0.0 ) {
			// Robustness fix: a zero-sum vector previously produced NaN/Inf.
			return Arrays.copyOf(vector, vector.length);
		}
		return Arrays.stream(vector).map(( d ) -> d / sum).toArray();
	}

	/**
	 * Computes the matrix product A × B.
	 *
	 * @param A Left matrix (a × b).
	 * @param B Right matrix (b × c).
	 * @return New (a × c) product matrix.
	 * @throws IllegalArgumentException if A's column count differs from B's
	 *     row count.
	 */
	public static double[][] matrixMultiply( double[][] A, double[][] B ) {
		int a = A.length, b = A[0].length, c = B[0].length;
		if( B.length != b ) {
			// Fix: report A's column count (b), not its row count (a).
			throw new IllegalArgumentException(
				String.format("Matrix A needs equal columns to matrix B rows. (Currently <%d> vs <%d>)", b, B.length)
			);
		}

		// Row-parallel triple loop expressed as streams.
		return IntStream.range(0, a).parallel().mapToObj(
			( i ) -> IntStream.range(0, c).mapToDouble(
				( j ) -> IntStream.range(0, b).mapToDouble(
					( k ) -> A[i][k] * B[k][j]
				).sum()
			).toArray()
		).toArray(double[][]::new);
	}

	/**
	 * Element-wise (Hadamard) product of two equally sized matrices.
	 *
	 * @param A First matrix.
	 * @param S Second matrix (scale factors).
	 * @return New matrix with entries A[i][j] * S[i][j].
	 * @throws IllegalArgumentException if the dimensions differ.
	 */
	public static double[][] matrixScale( final double[][] A, final double[][] S ) {
		if( A.length != S.length || A[0].length != S[0].length ) {
			throw new IllegalArgumentException("Matrices need to be same size.");
		}

		return IntStream.range(0, A.length).parallel().mapToObj(
			( i ) -> IntStream.range(0, A[i].length).mapToDouble(
				( j ) -> A[i][j] * S[i][j]
			).toArray()
		).toArray(double[][]::new);
	}

	/**
	 * Element-wise difference A - B of two equally sized matrices.
	 *
	 * @param A Minuend matrix.
	 * @param B Subtrahend matrix.
	 * @return New matrix with entries A[i][j] - B[i][j].
	 * @throws IllegalArgumentException if the dimensions differ.
	 */
	public static double[][] matrixSub( double[][] A, double[][] B ) {
		if( A.length != B.length || A[0].length != B[0].length ) {
			throw new IllegalArgumentException("Cannot subtract unequal matrices");
		}

		return IntStream.range(0, A.length).parallel().mapToObj(
			( i ) -> IntStream.range(0, A[i].length).mapToDouble(
				( j ) -> A[i][j] - B[i][j]
			).toArray()
		).toArray(double[][]::new);
	}

	/**
	 * Element-wise sum A + B of two equally sized matrices.
	 *
	 * @param A First summand matrix.
	 * @param B Second summand matrix.
	 * @return New matrix with entries A[i][j] + B[i][j].
	 * @throws IllegalArgumentException if the dimensions differ.
	 */
	public static double[][] matrixAdd( double[][] A, double[][] B ) {
		if( A.length != B.length || A[0].length != B[0].length ) {
			throw new IllegalArgumentException("Cannot add unequal matrices");
		}

		return IntStream.range(0, A.length).parallel().mapToObj(
			( i ) -> IntStream.range(0, A[i].length).mapToDouble(
				( j ) -> A[i][j] + B[i][j]
			).toArray()
		).toArray(double[][]::new);
	}

	/**
	 * Transposes a matrix.
	 *
	 * @param matrix Matrix to transpose (a × b, not modified).
	 * @return New (b × a) matrix with rows and columns swapped.
	 */
	public static double[][] matrixTranspose( double[][] matrix ) {
		int a = matrix.length, b = matrix[0].length;

		double[][] result = new double[b][a];
		for( int i = 0; i < a; i++ ) {
			for( int j = 0; j < b; j++ ) {
				result[j][i] = matrix[i][j];
			}
		}

		return result;
	}

	/**
	 * Applies an operator to every entry of a matrix.
	 *
	 * @param A Source matrix (not modified).
	 * @param op Operator applied to each entry.
	 * @return New matrix with entries op(A[i][j]).
	 */
	public static double[][] matrixApply( double[][] A, DoubleUnaryOperator op ) {
		return Arrays.stream(A).parallel().map(
			( arr ) -> Arrays.stream(arr).map(op).toArray()
		).toArray(double[][]::new);
	}

	/**
	 * Returns a deep copy of a matrix (row arrays are copied, not shared).
	 *
	 * @param matrix Matrix to copy.
	 * @return Independent copy of {@code matrix}.
	 */
	public static double[][] copyMatrix( double[][] matrix ) {
		double[][] result = new double[matrix.length][];
		for( int i = 0; i < matrix.length; i++ ) {
			result[i] = Arrays.copyOf(matrix[i], matrix[i].length);
		}
		return result;
	}

	/**
	 * Extracts the first column of a matrix as a vector.
	 *
	 * @param matrix Source matrix (typically n × 1).
	 * @return Array of the first entry of each row.
	 */
	public static double[] toVector( double[][] matrix ) {
		return Arrays.stream(matrix).mapToDouble(
			( arr ) -> arr[0]
		).toArray();
	}

	/**
	 * Converts a vector into a single-column matrix.
	 *
	 * @param vector Source vector.
	 * @return New (n × 1) matrix.
	 */
	public static double[][] toMatrix( double[] vector ) {
		return Arrays.stream(vector).mapToObj(
			( d ) -> new double[]{ d }
		).toArray(double[][]::new);
	}

	/**
	 * Cross-entropy cost between predicted activations A and expected
	 * outputs Y, averaged over the batch size.
	 *
	 * @param A Predicted values; must lie strictly in (0, 1) since the log
	 *     of A and of 1-A is taken.
	 * @param Y Expected values (typically 0 or 1).
	 * @param batch_size Number of samples to average over.
	 * @return Mean negative log-likelihood.
	 */
	public static double entropy( double[][] A, double[][] Y, int batch_size ) {
		// Single pass instead of materializing an intermediate matrix;
		// summation order (row-major) is unchanged.
		double sum = 0;
		for( int i = 0; i < A.length; i++ ) {
			for( int j = 0; j < A[i].length; j++ ) {
				sum += (Y[i][j] * Math.log(A[i][j])) + ((1 - Y[i][j]) * Math.log(1 - A[i][j]));
			}
		}
		return -sum / batch_size;
	}

	// Utility class: prevent instantiation.
	private MLMath() {
	}

}
|
||||
78
src/schule/ngb/zm/ml/Matrix.java
Normal file
78
src/schule/ngb/zm/ml/Matrix.java
Normal file
@@ -0,0 +1,78 @@
|
||||
package schule.ngb.zm.ml;
|
||||
|
||||
import java.util.Arrays;
|
||||
|
||||
public class Matrix {
|
||||
|
||||
private int columns, rows;
|
||||
|
||||
double[][] coefficients;
|
||||
|
||||
public Matrix( int rows, int cols ) {
|
||||
this.rows = rows;
|
||||
this.columns = cols;
|
||||
coefficients = new double[rows][cols];
|
||||
}
|
||||
|
||||
public Matrix( double[][] coefficients ) {
|
||||
this.coefficients = coefficients;
|
||||
this.rows = coefficients.length;
|
||||
this.columns = coefficients[0].length;
|
||||
}
|
||||
|
||||
public int getColumns() {
|
||||
return columns;
|
||||
}
|
||||
|
||||
public int getRows() {
|
||||
return rows;
|
||||
}
|
||||
|
||||
public double[][] getCoefficients() {
|
||||
return coefficients;
|
||||
}
|
||||
|
||||
public double get( int row, int col ) {
|
||||
return coefficients[row][col];
|
||||
}
|
||||
|
||||
public void initializeRandom() {
|
||||
coefficients = MLMath.matrixApply(coefficients, (d) -> Math.random());
|
||||
}
|
||||
|
||||
public void initializeRandom( double lower, double upper ) {
|
||||
coefficients = MLMath.matrixApply(coefficients, (d) -> ((upper-lower) * Math.random()) + lower);
|
||||
}
|
||||
|
||||
public void initializeIdentity() {
|
||||
initializeZero();
|
||||
for( int i = 0; i < Math.min(rows, columns); i++ ) {
|
||||
this.coefficients[i][i] = 1.0;
|
||||
}
|
||||
}
|
||||
|
||||
public void initializeOne() {
|
||||
coefficients = MLMath.matrixApply(coefficients, (d) -> 1.0);
|
||||
}
|
||||
|
||||
public void initializeZero() {
|
||||
coefficients = MLMath.matrixApply(coefficients, (d) -> 0.0);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
//return Arrays.deepToString(coefficients);
|
||||
StringBuilder sb = new StringBuilder();
|
||||
sb.append('[');
|
||||
sb.append('\n');
|
||||
for( int i = 0; i < coefficients.length; i++ ) {
|
||||
sb.append('\t');
|
||||
sb.append(Arrays.toString(coefficients[i]));
|
||||
sb.append('\n');
|
||||
}
|
||||
sb.append(']');
|
||||
|
||||
return sb.toString();
|
||||
}
|
||||
|
||||
}
|
||||
84
src/schule/ngb/zm/ml/NeuralNetwork.java
Normal file
84
src/schule/ngb/zm/ml/NeuralNetwork.java
Normal file
@@ -0,0 +1,84 @@
|
||||
package schule.ngb.zm.ml;
|
||||
|
||||
import schule.ngb.zm.util.Log;
|
||||
|
||||
public class NeuralNetwork {
|
||||
|
||||
private NeuronLayer[] layers;
|
||||
|
||||
private double[][] output;
|
||||
|
||||
private double learningRate = 0.1;
|
||||
|
||||
public NeuralNetwork( int inputs, int layer1, int outputs ) {
|
||||
this(new NeuronLayer(layer1, inputs), new NeuronLayer(outputs, layer1));
|
||||
}
|
||||
|
||||
public NeuralNetwork( int inputs, int layer1, int layer2, int outputs ) {
|
||||
this(new NeuronLayer(layer1, inputs), new NeuronLayer(layer2, layer1), new NeuronLayer(outputs, layer2));
|
||||
}
|
||||
|
||||
public NeuralNetwork( NeuronLayer layer1, NeuronLayer layer2 ) {
|
||||
this.layers = new NeuronLayer[2];
|
||||
this.layers[0] = layer1;
|
||||
this.layers[1] = layer2;
|
||||
layer1.connect(null, layer2);
|
||||
layer2.connect(layer1, null);
|
||||
}
|
||||
|
||||
public NeuralNetwork( NeuronLayer layer1, NeuronLayer layer2, NeuronLayer layer3 ) {
|
||||
this.layers = new NeuronLayer[3];
|
||||
this.layers[0] = layer1;
|
||||
this.layers[1] = layer2;
|
||||
this.layers[2] = layer3;
|
||||
layer1.connect(null, layer2);
|
||||
layer2.connect(layer1, layer3);
|
||||
layer3.connect(layer2, null);
|
||||
}
|
||||
|
||||
public int getLayerCount() {
|
||||
return layers.length;
|
||||
}
|
||||
|
||||
public NeuronLayer getLayer( int i ) {
|
||||
return layers[i - 1];
|
||||
}
|
||||
|
||||
public double getLearningRate() {
|
||||
return learningRate;
|
||||
}
|
||||
|
||||
public void setLearningRate( double pLearningRate ) {
|
||||
this.learningRate = pLearningRate;
|
||||
}
|
||||
|
||||
public double[][] getOutput() {
|
||||
return output;
|
||||
}
|
||||
|
||||
public double[][] predict( double[][] inputs ) {
|
||||
//this.output = layers[1].apply(layers[0].apply(inputs));
|
||||
this.output = layers[0].apply(inputs);
|
||||
return this.output;
|
||||
}
|
||||
|
||||
public void learn( double[][] expected ) {
|
||||
layers[layers.length-1].backprop(expected, learningRate);
|
||||
}
|
||||
|
||||
public void train( double[][] inputs, double[][] expected, int iterations/*, double minChange, int timeout */ ) {
|
||||
for( int i = 0; i < iterations; i++ ) {
|
||||
// pass the training set through the network
|
||||
predict(inputs);
|
||||
// start backpropagation through all layers
|
||||
learn(expected);
|
||||
|
||||
if( i % 10000 == 0 ) {
|
||||
LOG.trace("Training iteration %d of %d", i, iterations);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static final Log LOG = Log.getLogger(NeuralNetwork.class);
|
||||
|
||||
}
|
||||
127
src/schule/ngb/zm/ml/NeuronLayer.java
Normal file
127
src/schule/ngb/zm/ml/NeuronLayer.java
Normal file
@@ -0,0 +1,127 @@
|
||||
package schule.ngb.zm.ml;
|
||||
|
||||
import java.util.function.DoubleUnaryOperator;
|
||||
import java.util.function.Function;
|
||||
|
||||
/**
 * A single fully connected layer. Layers form a doubly linked chain via
 * {@code previous}/{@code next}; applying a layer forwards its result to the
 * next layer, so applying the first layer runs the whole network.
 */
public class NeuronLayer implements Function<double[][], double[][]> {

	// Weight matrix: one row per input, one column per neuron.
	Matrix weights;

	// Neighboring layers in the chain; null marks the input/output ends.
	NeuronLayer previous, next;

	// Activation function and its derivative. The derivative is applied to
	// the layer *output* in backprop() (matching MLMath.sigmoidDerivative,
	// which expects the sigmoid output as its argument).
	DoubleUnaryOperator activationFunction, activationFunctionDerivative;

	// Cached values of the last forward pass; required by backprop().
	double[][] lastOutput, lastInput;

	/**
	 * Creates a layer with random weights in [-1, 1) and sigmoid activation.
	 *
	 * @param neurons Number of neurons (output width).
	 * @param inputs Number of inputs per neuron (input width).
	 */
	public NeuronLayer( int neurons, int inputs ) {
		weights = new Matrix(inputs, neurons);
		weights.initializeRandom(-1, 1);

		activationFunction = MLMath::sigmoid;
		activationFunctionDerivative = MLMath::sigmoidDerivative;
	}

	/**
	 * Links this layer to its neighbors (either may be null).
	 */
	public void connect( NeuronLayer prev, NeuronLayer next ) {
		setPreviousLayer(prev);
		setNextLayer(next);
	}

	public NeuronLayer getPreviousLayer() {
		return previous;
	}

	public boolean hasPreviousLayer() {
		return previous != null;
	}

	/**
	 * Sets the previous layer and keeps its back-reference consistent.
	 */
	public void setPreviousLayer( NeuronLayer pPreviousLayer ) {
		this.previous = pPreviousLayer;
		if( pPreviousLayer != null ) {
			pPreviousLayer.next = this;
		}
	}

	public NeuronLayer getNextLayer() {
		return next;
	}

	public boolean hasNextLayer() {
		return next != null;
	}

	/**
	 * Sets the next layer and keeps its back-reference consistent.
	 */
	public void setNextLayer( NeuronLayer pNextLayer ) {
		this.next = pNextLayer;
		if( pNextLayer != null ) {
			pNextLayer.previous = this;
		}
	}

	public Matrix getWeights() {
		return weights;
	}

	// Number of neurons == number of weight columns.
	public int getNeuronCount() {
		return weights.coefficients[0].length;
	}

	// Number of inputs == number of weight rows.
	public int getInputCount() {
		return weights.coefficients.length;
	}

	public double[][] getLastOutput() {
		return lastOutput;
	}

	/**
	 * Replaces the weights with a deep copy of the given matrix.
	 */
	public void setWeights( double[][] newWeights ) {
		weights.coefficients = MLMath.copyMatrix(newWeights);
	}

	/**
	 * Adds the given adjustment element-wise to the weight matrix.
	 */
	public void adjustWeights( double[][] adjustment ) {
		weights.coefficients = MLMath.matrixAdd(weights.coefficients, adjustment);
	}

	@Override
	public String toString() {
		return weights.toString();
	}

	/**
	 * Forward pass: activation(inputs × weights), forwarded through the rest
	 * of the chain. Caches input and output for backprop().
	 *
	 * @param inputs One sample per row, one column per input.
	 * @return Output of the last layer in the chain.
	 */
	@Override
	public double[][] apply( double[][] inputs ) {
		lastInput = inputs;
		lastOutput = MLMath.matrixApply(MLMath.matrixMultiply(inputs, weights.coefficients), activationFunction);
		if( next != null ) {
			return next.apply(lastOutput);
		} else {
			return lastOutput;
		}
	}

	// NOTE(review): identical to Function's default compose implementation;
	// the override adds no behavior.
	@Override
	public <V> Function<V, double[][]> compose( Function<? super V, ? extends double[][]> before ) {
		return ( in ) -> apply(before.apply(in));
	}

	// NOTE(review): identical to Function's default andThen implementation;
	// the override adds no behavior.
	@Override
	public <V> Function<double[][], V> andThen( Function<? super double[][], ? extends V> after ) {
		return ( in ) -> after.apply(apply(in));
	}

	/**
	 * Backpropagation step. The meaning of {@code expected} depends on the
	 * layer's position: for the output layer (next == null) it holds the
	 * expected network outputs; for hidden layers it holds the delta passed
	 * down from the following layer.
	 *
	 * @param expected Expected outputs, or the next layer's delta (see above).
	 * @param learningRate Scale factor for the weight adjustment.
	 */
	public void backprop( double[][] expected, double learningRate ) {
		double[][] error, delta, adjustment;
		if( next == null ) {
			// Output layer: error is the difference to the expected output.
			error = MLMath.matrixSub(expected, this.lastOutput);
		} else {
			// Hidden layer: propagate the next layer's delta back through
			// its weights (delta × Wᵀ).
			error = MLMath.matrixMultiply(expected, MLMath.matrixTranspose(next.weights.coefficients));
		}

		// Derivative is evaluated on the cached *output* of this layer.
		delta = MLMath.matrixScale(error, MLMath.matrixApply(this.lastOutput,this.activationFunctionDerivative));
		// Recurse before adjusting: the previous layer reads this layer's
		// weights (as its `next.weights`) and must see the pre-update values.
		if( previous != null ) {
			previous.backprop(delta, learningRate);
		}

		adjustment = MLMath.matrixMultiply(MLMath.matrixTranspose(lastInput), delta);
		adjustment = MLMath.matrixApply(adjustment, ( x ) -> learningRate * x);
		this.adjustWeights(adjustment);
	}

}
|
||||
101
test/src/schule/ngb/zm/ml/MLMathTest.java
Normal file
101
test/src/schule/ngb/zm/ml/MLMathTest.java
Normal file
@@ -0,0 +1,101 @@
|
||||
package schule.ngb.zm.ml;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
class MLMathTest {
|
||||
|
||||
@Test
|
||||
void matrixMultiply() {
|
||||
double[][] A = new double[][]{
|
||||
{1.0, 2.0, 3.0},
|
||||
{-5.0, -4.0, -3.0},
|
||||
{0.0, -10.0, 10.0}
|
||||
};
|
||||
double[][] B = new double[][]{
|
||||
{0.0, 1.0},
|
||||
{2.0, -2.0},
|
||||
{5.0, -10.0}
|
||||
};
|
||||
double[][] result = MLMath.matrixMultiply(A, B);
|
||||
|
||||
assertNotNull(result);
|
||||
assertEquals(A.length, result.length);
|
||||
assertEquals(B[0].length, result[0].length);
|
||||
assertArrayEquals(new double[]{19.0, -33.0}, result[0]);
|
||||
assertArrayEquals(new double[]{-23.0, 33.0}, result[1]);
|
||||
assertArrayEquals(new double[]{30.0, -80.0}, result[2]);
|
||||
|
||||
assertThrowsExactly(IllegalArgumentException.class, () -> MLMath.matrixMultiply(B, A));
|
||||
}
|
||||
|
||||
@Test
|
||||
void matrixScale() {
|
||||
double[][] matrix = new double[][]{
|
||||
{1.0, 2.0, 3.0},
|
||||
{-5.0, -4.0, -3.0},
|
||||
{0.0, -10.0, 10.0}
|
||||
};
|
||||
double[][] scalars = new double[][]{
|
||||
{0.0, 1.0, -1.0},
|
||||
{2.0, -2.0, 10.0},
|
||||
{5.0, -10.0, 10.0}
|
||||
};
|
||||
double[][] result = MLMath.matrixScale(matrix, scalars);
|
||||
|
||||
assertNotNull(result);
|
||||
assertNotSame(matrix, result);
|
||||
assertArrayEquals(new double[]{0.0, 2.0, -3.0}, result[0]);
|
||||
assertArrayEquals(new double[]{-10.0, 8.0, -30.0}, result[1]);
|
||||
assertArrayEquals(new double[]{0.0, 100.0, 100.0}, result[2]);
|
||||
}
|
||||
|
||||
@Test
|
||||
void matrixApply() {
|
||||
double[][] matrix = new double[][]{
|
||||
{1.0, 2.0, 3.0},
|
||||
{-5.0, -4.0, -3.0},
|
||||
{0.0, -10.0, 10.0}
|
||||
};
|
||||
double[][] result = MLMath.matrixApply(matrix, (d) -> -1*d);
|
||||
|
||||
assertNotNull(result);
|
||||
assertNotSame(matrix, result);
|
||||
assertArrayEquals(new double[]{-1.0, -2.0, -3.0}, result[0]);
|
||||
assertArrayEquals(new double[]{5.0, 4.0, 3.0}, result[1]);
|
||||
assertArrayEquals(new double[]{-0.0, 10.0, -10.0}, result[2]);
|
||||
}
|
||||
|
||||
@Test
|
||||
void matrixSubtract() {
|
||||
}
|
||||
|
||||
@Test
|
||||
void matrixAdd() {
|
||||
}
|
||||
|
||||
@Test
|
||||
void matrixTranspose() {
|
||||
double[][] matrix = new double[][]{
|
||||
{1.0, 2.0, 3.0, 4.5},
|
||||
{-5.0, -4.0, -3.0, 2.1},
|
||||
{0.0, -10.0, 10.0, 0.9}
|
||||
};
|
||||
double[][] result = MLMath.matrixTranspose(matrix);
|
||||
|
||||
assertNotNull(result);
|
||||
assertEquals(4, result.length);
|
||||
assertEquals(3, result[0].length);
|
||||
|
||||
assertArrayEquals(new double[]{1.0, -5.0, 0.0}, result[0]);
|
||||
assertArrayEquals(new double[]{2.0, -4.0, -10.0}, result[1]);
|
||||
assertArrayEquals(new double[]{3.0, -3.0, 10.0}, result[2]);
|
||||
assertArrayEquals(new double[]{4.5, 2.1, 0.9}, result[3]);
|
||||
}
|
||||
|
||||
@Test
|
||||
void normalize() {
|
||||
}
|
||||
|
||||
}
|
||||
57
test/src/schule/ngb/zm/ml/MatrixTest.java
Normal file
57
test/src/schule/ngb/zm/ml/MatrixTest.java
Normal file
@@ -0,0 +1,57 @@
|
||||
package schule.ngb.zm.ml;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.Arrays;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
class MatrixTest {
|
||||
|
||||
@Test
|
||||
void initializeIdentity() {
|
||||
Matrix m = new Matrix(4, 4);
|
||||
m.initializeIdentity();
|
||||
|
||||
assertArrayEquals(new double[]{1.0, 0.0, 0.0, 0.0}, m.coefficients[0]);
|
||||
assertArrayEquals(new double[]{0.0, 1.0, 0.0, 0.0}, m.coefficients[1]);
|
||||
assertArrayEquals(new double[]{0.0, 0.0, 1.0, 0.0}, m.coefficients[2]);
|
||||
assertArrayEquals(new double[]{0.0, 0.0, 0.0, 1.0}, m.coefficients[3]);
|
||||
}
|
||||
|
||||
@Test
|
||||
void initializeOne() {
|
||||
Matrix m = new Matrix(4, 4);
|
||||
m.initializeOne();
|
||||
|
||||
double[] ones = new double[]{1.0, 1.0, 1.0, 1.0};
|
||||
assertArrayEquals(ones, m.coefficients[0]);
|
||||
assertArrayEquals(ones, m.coefficients[1]);
|
||||
assertArrayEquals(ones, m.coefficients[2]);
|
||||
assertArrayEquals(ones, m.coefficients[3]);
|
||||
}
|
||||
|
||||
@Test
|
||||
void initializeZero() {
|
||||
Matrix m = new Matrix(4, 4);
|
||||
m.initializeZero();
|
||||
|
||||
double[] zeros = new double[]{0.0, 0.0, 0.0, 0.0};
|
||||
assertArrayEquals(zeros, m.coefficients[0]);
|
||||
assertArrayEquals(zeros, m.coefficients[1]);
|
||||
assertArrayEquals(zeros, m.coefficients[2]);
|
||||
assertArrayEquals(zeros, m.coefficients[3]);
|
||||
}
|
||||
|
||||
@Test
|
||||
void initializeRandom() {
|
||||
Matrix m = new Matrix(4, 4);
|
||||
m.initializeRandom(-1, 1);
|
||||
|
||||
assertTrue(Arrays.stream(m.coefficients[0]).allMatch((d) -> -1.0 <= d && d < 1.0));
|
||||
assertTrue(Arrays.stream(m.coefficients[1]).allMatch((d) -> -1.0 <= d && d < 1.0));
|
||||
assertTrue(Arrays.stream(m.coefficients[2]).allMatch((d) -> -1.0 <= d && d < 1.0));
|
||||
assertTrue(Arrays.stream(m.coefficients[3]).allMatch((d) -> -1.0 <= d && d < 1.0));
|
||||
}
|
||||
|
||||
}
|
||||
171
test/src/schule/ngb/zm/ml/NeuralNetworkTest.java
Normal file
171
test/src/schule/ngb/zm/ml/NeuralNetworkTest.java
Normal file
@@ -0,0 +1,171 @@
|
||||
package schule.ngb.zm.ml;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import schule.ngb.zm.util.Log;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Random;
|
||||
|
||||
class NeuralNetworkTest {
|
||||
|
||||
@Test
|
||||
void learnCalc() {
|
||||
Log.enableGlobalDebugging();
|
||||
|
||||
int INPUT_SIZE = 50;
|
||||
int PREDICT_SIZE = 4;
|
||||
int TRAINING_CYCLES = 40000;
|
||||
CalcType OPERATION = CalcType.SUB;
|
||||
|
||||
// Create neural network with layer1: 4 neurones, layer2: 1 neuron
|
||||
NeuralNetwork net = new NeuralNetwork(2, 8, 4, 1);
|
||||
|
||||
List<TestData> trainingData = createTrainingSet(INPUT_SIZE, OPERATION);
|
||||
|
||||
double[][] inputs = new double[INPUT_SIZE][2];
|
||||
double[][] outputs = new double[INPUT_SIZE][1];
|
||||
for( int i = 0; i < trainingData.size(); i++ ) {
|
||||
inputs[i][0] = trainingData.get(i).a;
|
||||
inputs[i][1] = trainingData.get(i).b;
|
||||
outputs[i][0] = trainingData.get(i).result;
|
||||
}
|
||||
|
||||
System.out.println("Training the neural net to learn "+OPERATION+"...");
|
||||
net.train(inputs, outputs, TRAINING_CYCLES);
|
||||
System.out.println(" finished training");
|
||||
|
||||
for( int i = 1; i <= net.getLayerCount(); i++ ) {
|
||||
System.out.println("Layer " +i + " weights");
|
||||
System.out.println(net.getLayer(i).weights);
|
||||
}
|
||||
|
||||
// calculate the predictions on unknown data
|
||||
List<TestData> predictionSet = createTrainingSet(PREDICT_SIZE, OPERATION);
|
||||
for( TestData t : predictionSet ) {
|
||||
predict(t, net);
|
||||
}
|
||||
}
|
||||
|
||||
public static void predict( TestData data, NeuralNetwork net ) {
|
||||
double[][] testInput = new double[][]{{data.a, data.b}};
|
||||
net.predict(testInput);
|
||||
|
||||
// then
|
||||
System.out.printf(
|
||||
"Prediction on data (%.2f, %.2f) was %.4f, expected %.2f (of by %.4f)\n",
|
||||
data.a, data.b,
|
||||
net.getOutput()[0][0],
|
||||
data.result,
|
||||
net.getOutput()[0][0] - data.result
|
||||
);
|
||||
}
|
||||
|
||||
private List<TestData> createTrainingSet( int trainingSetSize, CalcType operation ) {
|
||||
Random random = new Random();
|
||||
List<TestData> tuples = new ArrayList<>();
|
||||
|
||||
for( int i = 0; i < trainingSetSize; i++ ) {
|
||||
double s1 = random.nextDouble() * 0.5;
|
||||
double s2 = random.nextDouble() * 0.5;
|
||||
|
||||
switch( operation ) {
|
||||
case ADD:
|
||||
tuples.add(new AddData(s1, s2));
|
||||
break;
|
||||
case SUB:
|
||||
tuples.add(new SubData(s1, s2));
|
||||
break;
|
||||
case MUL:
|
||||
tuples.add(new MulData(s1, s2));
|
||||
break;
|
||||
case DIV:
|
||||
tuples.add(new DivData(s1, s2));
|
||||
break;
|
||||
case MOD:
|
||||
tuples.add(new ModData(s1, s2));
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return tuples;
|
||||
}
|
||||
|
||||
|
||||
private static enum CalcType {
|
||||
ADD, SUB, MUL, DIV, MOD
|
||||
}
|
||||
|
||||
private static abstract class TestData {
|
||||
|
||||
double a;
|
||||
double b;
|
||||
double result;
|
||||
CalcType type;
|
||||
|
||||
TestData( double a, double b ) {
|
||||
this.a = a;
|
||||
this.b = b;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private static final class AddData extends TestData {
|
||||
|
||||
CalcType type = CalcType.ADD;
|
||||
|
||||
public AddData( double a, double b ) {
|
||||
super(a, b);
|
||||
result = a + b;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private static final class SubData extends TestData {
|
||||
|
||||
CalcType type = CalcType.SUB;
|
||||
|
||||
public SubData( double a, double b ) {
|
||||
super(a, b);
|
||||
result = a - b;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private static final class MulData extends TestData {
|
||||
|
||||
CalcType type = CalcType.MUL;
|
||||
|
||||
public MulData( double a, double b ) {
|
||||
super(a, b);
|
||||
result = a * b;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private static final class DivData extends TestData {
|
||||
|
||||
CalcType type = CalcType.DIV;
|
||||
|
||||
public DivData( double a, double b ) {
|
||||
super(a, b);
|
||||
if( b == 0.0 ) {
|
||||
b = .1;
|
||||
}
|
||||
result = a / b;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private static final class ModData extends TestData {
|
||||
|
||||
CalcType type = CalcType.MOD;
|
||||
|
||||
public ModData( double b, double a ) {
|
||||
super(b, a);
|
||||
result = a % b;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user