USing Colt library as optional dependency

This commit is contained in:
ngb
2022-07-18 11:06:08 +02:00
parent 9a9a714050
commit 4c8e5c8939
8 changed files with 439 additions and 128 deletions

View File

@@ -28,6 +28,8 @@ dependencies {
runtimeOnly 'com.googlecode.soundlibs:tritonus-share:0.3.7.4' runtimeOnly 'com.googlecode.soundlibs:tritonus-share:0.3.7.4'
runtimeOnly 'com.googlecode.soundlibs:mp3spi:1.9.5.4' runtimeOnly 'com.googlecode.soundlibs:mp3spi:1.9.5.4'
api 'colt:colt:1.2.0'
testImplementation 'org.junit.jupiter:junit-jupiter-api:5.8.1' testImplementation 'org.junit.jupiter:junit-jupiter-api:5.8.1'
testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine:5.8.1' testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine:5.8.1'
} }

View File

@@ -0,0 +1,147 @@
package schule.ngb.zm.ml;
import schule.ngb.zm.Constants;
import java.util.Arrays;
import java.util.function.DoubleUnaryOperator;
// TODO: Move Math into Matrix class
// TODO: Implement support for optional sci libs
public class DoubleMatrix implements Matrix {
private int columns, rows;
double[][] coefficients;
public DoubleMatrix( int rows, int cols ) {
this.rows = rows;
this.columns = cols;
coefficients = new double[rows][cols];
}
public DoubleMatrix( double[][] coefficients ) {
this.rows = coefficients.length;
this.columns = coefficients[0].length;
this.coefficients = coefficients;
}
public int columns() {
return columns;
}
public int rows() {
return rows;
}
public double[][] getCoefficients() {
return coefficients;
}
public double get( int row, int col ) {
return coefficients[row][col];
}
public Matrix set( int row, int col, double value ) {
coefficients[row][col] = value;
return this;
}
public Matrix initializeRandom() {
coefficients = MLMath.matrixApply(coefficients, (d) -> Constants.randomGaussian());
return this;
}
public Matrix initializeRandom( double lower, double upper ) {
coefficients = MLMath.matrixApply(coefficients, (d) -> ((upper-lower) * (Constants.randomGaussian()+1) * .5) + lower);
return this;
}
public Matrix initializeOne() {
coefficients = MLMath.matrixApply(coefficients, (d) -> 1.0);
return this;
}
public Matrix initializeZero() {
coefficients = MLMath.matrixApply(coefficients, (d) -> 0.0);
return this;
}
@Override
public String toString() {
//return Arrays.deepToString(coefficients);
StringBuilder sb = new StringBuilder();
sb.append('[');
sb.append('\n');
for( int i = 0; i < coefficients.length; i++ ) {
sb.append('\t');
sb.append(Arrays.toString(coefficients[i]));
sb.append('\n');
}
sb.append(']');
return sb.toString();
}
@Override
public Matrix transpose() {
coefficients = MLMath.matrixTranspose(coefficients);
return this;
}
@Override
public Matrix multiply( Matrix B ) {
coefficients = MLMath.matrixMultiply(coefficients, B.getCoefficients());
return this;
}
@Override
public Matrix multiplyAddBias( Matrix B, Matrix C ) {
double[] biases = Arrays.stream(C.getCoefficients()).mapToDouble((arr) -> arr[0]).toArray();
coefficients = MLMath.biasAdd(
MLMath.matrixMultiply(coefficients, B.getCoefficients()),
biases
);
return this;
}
@Override
public Matrix multiplyLeft( Matrix B ) {
coefficients = MLMath.matrixMultiply(B.getCoefficients(), coefficients);
return this;
}
@Override
public Matrix add( Matrix B ) {
coefficients = MLMath.matrixAdd(coefficients, B.getCoefficients());
return this;
}
@Override
public Matrix sub( Matrix B ) {
coefficients = MLMath.matrixSub(coefficients, B.getCoefficients());
return this;
}
@Override
public Matrix scale( double scalar ) {
return this;
}
@Override
public Matrix scale( Matrix S ) {
coefficients = MLMath.matrixScale(coefficients, S.getCoefficients());
return this;
}
@Override
public Matrix apply( DoubleUnaryOperator op ) {
this.coefficients = MLMath.matrixApply(coefficients, op);
return this;
}
@Override
public Matrix duplicate() {
return new DoubleMatrix(MLMath.copyMatrix(coefficients));
}
}

View File

@@ -1,82 +1,48 @@
package schule.ngb.zm.ml; package schule.ngb.zm.ml;
import schule.ngb.zm.Constants; import java.util.function.DoubleUnaryOperator;
import java.util.Arrays; public interface Matrix {
// TODO: Move Math into Matrix class int columns();
// TODO: Implement support for optional sci libs
public class Matrix {
private int columns, rows; int rows();
double[][] coefficients; double[][] getCoefficients();
public Matrix( int rows, int cols ) { double get( int row, int col );
this.rows = rows;
this.columns = cols;
coefficients = new double[rows][cols];
}
public Matrix( double[][] coefficients ) { Matrix set( int row, int col, double value );
this.coefficients = coefficients;
this.rows = coefficients.length;
this.columns = coefficients[0].length;
}
public int getColumns() { Matrix initializeRandom();
return columns;
}
public int getRows() { Matrix initializeRandom( double lower, double upper );
return rows;
}
public double[][] getCoefficients() { Matrix initializeOne();
return coefficients;
}
public double get( int row, int col ) { Matrix initializeZero();
return coefficients[row][col];
}
public void initializeRandom() { Matrix transpose();
coefficients = MLMath.matrixApply(coefficients, (d) -> Constants.randomGaussian());
}
public void initializeRandom( double lower, double upper ) { Matrix multiply( Matrix B );
coefficients = MLMath.matrixApply(coefficients, (d) -> ((upper-lower) * (Constants.randomGaussian()+1) * .5) + lower);
}
public void initializeIdentity() { Matrix multiplyAddBias( Matrix B, Matrix C );
initializeZero();
for( int i = 0; i < Math.min(rows, columns); i++ ) {
this.coefficients[i][i] = 1.0;
}
}
public void initializeOne() { Matrix multiplyLeft( Matrix B );
coefficients = MLMath.matrixApply(coefficients, (d) -> 1.0);
}
public void initializeZero() { Matrix add( Matrix B );
coefficients = MLMath.matrixApply(coefficients, (d) -> 0.0);
}
@Override Matrix sub( Matrix B );
public String toString() {
//return Arrays.deepToString(coefficients);
StringBuilder sb = new StringBuilder();
sb.append('[');
sb.append('\n');
for( int i = 0; i < coefficients.length; i++ ) {
sb.append('\t');
sb.append(Arrays.toString(coefficients[i]));
sb.append('\n');
}
sb.append(']');
return sb.toString();
} Matrix scale( double scalar );
Matrix scale( Matrix S );
Matrix apply( DoubleUnaryOperator op );
Matrix duplicate();
String toString();
} }

View File

@@ -0,0 +1,192 @@
package schule.ngb.zm.ml;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import schule.ngb.zm.Constants;
import schule.ngb.zm.util.Log;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.util.function.DoubleUnaryOperator;
public class MatrixFactory {
public static void main( String[] args ) {
System.out.println(
MatrixFactory.create(new double[][]{ {1.0, 0.0}, {0.0, 1.0} }).toString()
);
}
public static final Matrix create( int rows, int cols ) {
try {
return getMatrixType().getDeclaredConstructor(int.class, int.class).newInstance(rows, cols);
} catch( Exception ex ) {
LOG.error(ex, "Could not initialize matrix implementation for class <%s>. Using internal implementation.", matrixType);
}
return new DoubleMatrix(rows, cols);
}
public static final Matrix create( double[][] values ) {
try {
return getMatrixType().getDeclaredConstructor(double[][].class).newInstance((Object)values);
} catch( Exception ex ) {
LOG.error(ex, "Could not initialize matrix implementation for class <%s>. Using internal implementation.", matrixType);
}
return new DoubleMatrix(values);
}
private static Class<? extends Matrix> matrixType = null;
private static final Class<? extends Matrix> getMatrixType() {
if( matrixType == null ) {
try {
Class<?> clazz = Class.forName("cern.colt.matrix.impl.DenseDoubleMatrix2D", false, MatrixFactory.class.getClassLoader());
matrixType = ColtMatrix.class;
LOG.info("Colt library found. Using <cern.colt.matrix.impl.DenseDoubleMatrix2D> as matrix implementation.");
} catch( ClassNotFoundException e ) {
LOG.info("Colt library not found. Falling back on internal implementation.");
matrixType = DoubleMatrix.class;
}
}
return matrixType;
}
private static final Log LOG = Log.getLogger(MatrixFactory.class);
static class ColtMatrix implements Matrix {
cern.colt.matrix.DoubleMatrix2D matrix;
public ColtMatrix( double[][] doubles ) {
matrix = new cern.colt.matrix.impl.DenseDoubleMatrix2D(doubles);
}
public ColtMatrix( int rows, int cols ) {
matrix = new cern.colt.matrix.impl.DenseDoubleMatrix2D(rows, cols);
}
@Override
public int columns() {
return matrix.columns();
}
@Override
public int rows() {
return matrix.rows();
}
@Override
public double get( int row, int col ) {
return matrix.get(row, col);
}
@Override
public Matrix set( int row, int col, double value ) {
matrix.set(row, col, value);
return this;
}
@Override
public double[][] getCoefficients() {
return this.matrix.toArray();
}
@Override
public Matrix initializeRandom() {
matrix.assign((d) -> Constants.randomGaussian());
return this;
}
@Override
public Matrix initializeRandom( double lower, double upper ) {
matrix.assign((d) -> ((upper-lower) * (Constants.randomGaussian()+1) * .5) + lower);
return this;
}
@Override
public Matrix initializeOne() {
this.matrix.assign(1.0);
return this;
}
@Override
public Matrix initializeZero() {
this.matrix.assign(0.0);
return this;
}
@Override
public Matrix apply( DoubleUnaryOperator op ) {
this.matrix.assign((d) -> op.applyAsDouble(d));
return this;
}
@Override
public Matrix transpose() {
this.matrix = cern.colt.matrix.linalg.Algebra.DEFAULT.transpose(this.matrix);
return this;
}
@Override
public Matrix multiply( Matrix B ) {
ColtMatrix CB = (ColtMatrix)B;
this.matrix = cern.colt.matrix.linalg.Algebra.DEFAULT.mult(matrix, CB.matrix);
return this;
}
@Override
public Matrix multiplyLeft( Matrix B ) {
ColtMatrix CB = (ColtMatrix)B;
this.matrix = cern.colt.matrix.linalg.Algebra.DEFAULT.mult(CB.matrix, matrix);
return this;
}
@Override
public Matrix multiplyAddBias( Matrix B, Matrix C ) {
ColtMatrix CB = (ColtMatrix)B;
this.matrix = cern.colt.matrix.linalg.Algebra.DEFAULT.mult(matrix, CB.matrix);
// TODO: add bias
return this;
}
@Override
public Matrix add( Matrix B ) {
ColtMatrix CB = (ColtMatrix)B;
matrix.assign(CB.matrix, (d1,d2) -> d1+d2);
return this;
}
@Override
public Matrix sub( Matrix B ) {
ColtMatrix CB = (ColtMatrix)B;
matrix.assign(CB.matrix, (d1,d2) -> d1-d2);
return this;
}
@Override
public Matrix scale( double scalar ) {
this.matrix.assign((d) -> d*scalar);
return this;
}
@Override
public Matrix scale( Matrix S ) {
this.matrix.forEachNonZero((r, c, d) -> d * S.get(r, c));
return this;
}
@Override
public Matrix duplicate() {
ColtMatrix newMatrix = new ColtMatrix(matrix.rows(), matrix.columns());
newMatrix.matrix.assign(this.matrix);
return newMatrix;
}
@Override
public String toString() {
return matrix.toString();
}
}
}

View File

@@ -23,13 +23,13 @@ public class NeuralNetwork {
for( int i = 0; i < layer.getInputCount(); i++ ) { for( int i = 0; i < layer.getInputCount(); i++ ) {
for( int j = 0; j < layer.getNeuronCount(); j++ ) { for( int j = 0; j < layer.getNeuronCount(); j++ ) {
out.print(layer.weights.coefficients[i][j]); //out.print(layer.weights.coefficients[i][j]);
out.print(' '); out.print(' ');
} }
out.println(); out.println();
} }
for( int j = 0; j < layer.getNeuronCount(); j++ ) { for( int j = 0; j < layer.getNeuronCount(); j++ ) {
out.print(layer.biases[j]); //out.print(layer.biases[j]);
out.print(' '); out.print(' ');
} }
out.println(); out.println();
@@ -56,13 +56,13 @@ public class NeuralNetwork {
for( int i = 0; i < inputs; i++ ) { for( int i = 0; i < inputs; i++ ) {
split = in.readLine().split(" "); split = in.readLine().split(" ");
for( int j = 0; j < neurons; j++ ) { for( int j = 0; j < neurons; j++ ) {
layer.weights.coefficients[i][j] = Double.parseDouble(split[j]); //layer.weights.coefficients[i][j] = Double.parseDouble(split[j]);
} }
} }
// Load Biases // Load Biases
split = in.readLine().split(" "); split = in.readLine().split(" ");
for( int j = 0; j < neurons; j++ ) { for( int j = 0; j < neurons; j++ ) {
layer.biases[j] = Double.parseDouble(split[j]); //layer.biases[j] = Double.parseDouble(split[j]);
} }
layers.add(layer); layers.add(layer);
@@ -107,7 +107,7 @@ public class NeuralNetwork {
private NeuronLayer[] layers; private NeuronLayer[] layers;
private double[][] output; private Matrix output;
private double learningRate = 0.1; private double learningRate = 0.1;
@@ -162,17 +162,25 @@ public class NeuralNetwork {
this.learningRate = pLearningRate; this.learningRate = pLearningRate;
} }
public double[][] getOutput() { public Matrix getOutput() {
return output; return output;
} }
public double[][] predict( double[][] inputs ) { public Matrix predict( double[][] inputs ) {
//this.output = layers[1].apply(layers[0].apply(inputs)); //this.output = layers[1].apply(layers[0].apply(inputs));
return predict(MatrixFactory.create(inputs));
}
public Matrix predict( Matrix inputs ) {
this.output = layers[0].apply(inputs); this.output = layers[0].apply(inputs);
return this.output; return this.output;
} }
public void learn( double[][] expected ) { public void learn( double[][] expected ) {
learn(MatrixFactory.create(expected));
}
public void learn( Matrix expected ) {
layers[layers.length - 1].backprop(expected, learningRate); layers[layers.length - 1].backprop(expected, learningRate);
} }

View File

@@ -4,9 +4,9 @@ import java.util.Arrays;
import java.util.function.DoubleUnaryOperator; import java.util.function.DoubleUnaryOperator;
import java.util.function.Function; import java.util.function.Function;
public class NeuronLayer implements Function<double[][], double[][]> { public class NeuronLayer implements Function<Matrix, Matrix> {
public static NeuronLayer fromArray( double[][] weights ) { /*public static NeuronLayer fromArray( double[][] weights ) {
NeuronLayer layer = new NeuronLayer(weights[0].length, weights.length); NeuronLayer layer = new NeuronLayer(weights[0].length, weights.length);
for( int i = 0; i < weights[0].length; i++ ) { for( int i = 0; i < weights[0].length; i++ ) {
for( int j = 0; j < weights.length; j++ ) { for( int j = 0; j < weights.length; j++ ) {
@@ -27,24 +27,26 @@ public class NeuronLayer implements Function<double[][], double[][]> {
layer.biases[j] = biases[j]; layer.biases[j] = biases[j];
} }
return layer; return layer;
} }*/
Matrix weights; Matrix weights;
double[] biases; Matrix biases;
NeuronLayer previous, next; NeuronLayer previous, next;
DoubleUnaryOperator activationFunction, activationFunctionDerivative; DoubleUnaryOperator activationFunction, activationFunctionDerivative;
double[][] lastOutput, lastInput; Matrix lastOutput, lastInput;
public NeuronLayer( int neurons, int inputs ) { public NeuronLayer( int neurons, int inputs ) {
weights = new Matrix(inputs, neurons); weights = MatrixFactory
weights.initializeRandom(-1, 1); .create(inputs, neurons)
.initializeRandom(-1, 1);
biases = new double[neurons]; biases = MatrixFactory
Arrays.fill(biases, 0.0); // TODO: Random? .create(neurons, 1)
.initializeZero();
activationFunction = MLMath::sigmoid; activationFunction = MLMath::sigmoid;
activationFunctionDerivative = MLMath::sigmoidDerivative; activationFunctionDerivative = MLMath::sigmoidDerivative;
@@ -89,41 +91,42 @@ public class NeuronLayer implements Function<double[][], double[][]> {
return weights; return weights;
} }
public Matrix getBiases() {
return biases;
}
public int getNeuronCount() { public int getNeuronCount() {
return weights.coefficients[0].length; return weights.columns();
} }
public int getInputCount() { public int getInputCount() {
return weights.coefficients.length; return weights.rows();
} }
public double[][] getLastOutput() { public Matrix getLastOutput() {
return lastOutput; return lastOutput;
} }
public void setWeights( double[][] newWeights ) { public void setWeights( Matrix newWeights ) {
weights.coefficients = MLMath.copyMatrix(newWeights); weights = newWeights.duplicate();
} }
public void adjustWeights( double[][] adjustment ) { public void adjustWeights( Matrix adjustment ) {
weights.coefficients = MLMath.matrixAdd(weights.coefficients, adjustment); weights.add(adjustment);
} }
@Override @Override
public String toString() { public String toString() {
return weights.toString() + "\n" + Arrays.toString(biases); return weights.toString() + "\n" + biases.toString();
} }
@Override @Override
public double[][] apply( double[][] inputs ) { public Matrix apply( Matrix inputs ) {
lastInput = inputs; lastInput = inputs;
lastOutput = MLMath.matrixApply( lastOutput = inputs
MLMath.biasAdd( .multiplyAddBias(weights, biases)
MLMath.matrixMultiply(inputs, weights.coefficients), .apply(activationFunction);
biases
),
activationFunction
);
if( next != null ) { if( next != null ) {
return next.apply(lastOutput); return next.apply(lastOutput);
} else { } else {
@@ -132,35 +135,34 @@ public class NeuronLayer implements Function<double[][], double[][]> {
} }
@Override @Override
public <V> Function<V, double[][]> compose( Function<? super V, ? extends double[][]> before ) { public <V> Function<V, Matrix> compose( Function<? super V, ? extends Matrix> before ) {
return ( in ) -> apply(before.apply(in)); return ( in ) -> apply(before.apply(in));
} }
@Override @Override
public <V> Function<double[][], V> andThen( Function<? super double[][], ? extends V> after ) { public <V> Function<Matrix, V> andThen( Function<? super Matrix, ? extends V> after ) {
return ( in ) -> after.apply(apply(in)); return ( in ) -> after.apply(apply(in));
} }
public void backprop( double[][] expected, double learningRate ) { public void backprop( Matrix expected, double learningRate ) {
double[][] error, delta, adjustment; Matrix error, delta, adjustment;
if( next == null ) { if( next == null ) {
error = MLMath.matrixSub(expected, this.lastOutput); error = expected.duplicate().sub(lastOutput);
} else { } else {
error = MLMath.matrixMultiply(expected, MLMath.matrixTranspose(next.weights.coefficients)); error = expected.duplicate().multiply(next.weights.transpose());
} }
delta = MLMath.matrixScale(error, MLMath.matrixApply(this.lastOutput, this.activationFunctionDerivative)); error.scale(lastOutput.duplicate().apply(this.activationFunctionDerivative));
// Hier schon leraningRate anwenden? // Hier schon leraningRate anwenden?
// See https://towardsdatascience.com/understanding-and-implementing-neural-networks-in-java-from-scratch-61421bb6352c // See https://towardsdatascience.com/understanding-and-implementing-neural-networks-in-java-from-scratch-61421bb6352c
//delta = MLMath.matrixApply(delta, ( x ) -> learningRate * x); //delta = MLMath.matrixApply(delta, ( x ) -> learningRate * x);
if( previous != null ) { if( previous != null ) {
previous.backprop(delta, learningRate); previous.backprop(error, learningRate);
} }
biases = MLMath.biasAdjust(biases, MLMath.matrixApply(delta, ( x ) -> learningRate * x)); // biases = MLMath.biasAdjust(biases, MLMath.matrixApply(delta, ( x ) -> learningRate * x));
adjustment = MLMath.matrixMultiply(MLMath.matrixTranspose(lastInput), delta); adjustment = lastInput.duplicate().transpose().multiply(error).apply((d) -> learningRate*d);
adjustment = MLMath.matrixApply(adjustment, ( x ) -> learningRate * x);
this.adjustWeights(adjustment); this.adjustWeights(adjustment);
} }

View File

@@ -6,22 +6,11 @@ import java.util.Arrays;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
class MatrixTest { class DoubleMatrixTest {
@Test
void initializeIdentity() {
Matrix m = new Matrix(4, 4);
m.initializeIdentity();
assertArrayEquals(new double[]{1.0, 0.0, 0.0, 0.0}, m.coefficients[0]);
assertArrayEquals(new double[]{0.0, 1.0, 0.0, 0.0}, m.coefficients[1]);
assertArrayEquals(new double[]{0.0, 0.0, 1.0, 0.0}, m.coefficients[2]);
assertArrayEquals(new double[]{0.0, 0.0, 0.0, 1.0}, m.coefficients[3]);
}
@Test @Test
void initializeOne() { void initializeOne() {
Matrix m = new Matrix(4, 4); DoubleMatrix m = new DoubleMatrix(4, 4);
m.initializeOne(); m.initializeOne();
double[] ones = new double[]{1.0, 1.0, 1.0, 1.0}; double[] ones = new double[]{1.0, 1.0, 1.0, 1.0};
@@ -33,7 +22,7 @@ class MatrixTest {
@Test @Test
void initializeZero() { void initializeZero() {
Matrix m = new Matrix(4, 4); DoubleMatrix m = new DoubleMatrix(4, 4);
m.initializeZero(); m.initializeZero();
double[] zeros = new double[]{0.0, 0.0, 0.0, 0.0}; double[] zeros = new double[]{0.0, 0.0, 0.0, 0.0};
@@ -45,7 +34,7 @@ class MatrixTest {
@Test @Test
void initializeRandom() { void initializeRandom() {
Matrix m = new Matrix(4, 4); DoubleMatrix m = new DoubleMatrix(4, 4);
m.initializeRandom(-1, 1); m.initializeRandom(-1, 1);
assertTrue(Arrays.stream(m.coefficients[0]).allMatch((d) -> -1.0 <= d && d < 1.0)); assertTrue(Arrays.stream(m.coefficients[0]).allMatch((d) -> -1.0 <= d && d < 1.0));

View File

@@ -3,6 +3,7 @@ package schule.ngb.zm.ml;
import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import schule.ngb.zm.util.Log; import schule.ngb.zm.util.Log;
import schule.ngb.zm.util.Timer;
import java.io.File; import java.io.File;
import java.util.ArrayList; import java.util.ArrayList;
@@ -18,7 +19,7 @@ class NeuralNetworkTest {
Log.enableGlobalDebugging(); Log.enableGlobalDebugging();
} }
@Test /*@Test
void readWrite() { void readWrite() {
// XOR Dataset // XOR Dataset
NeuralNetwork net = new NeuralNetwork(2, 4, 1); NeuralNetwork net = new NeuralNetwork(2, 4, 1);
@@ -53,7 +54,7 @@ class NeuralNetworkTest {
} }
assertArrayEquals(net.predict(inputs), net2.predict(inputs)); assertArrayEquals(net.predict(inputs), net2.predict(inputs));
} }*/
@Test @Test
void learnXor() { void learnXor() {
@@ -78,14 +79,14 @@ class NeuralNetworkTest {
} }
// calculate predictions // calculate predictions
double[][] predictions = net.predict(inputs); Matrix predictions = net.predict(inputs);
for( int i = 0; i < 4; i++ ) { for( int i = 0; i < 4; i++ ) {
int parsed_pred = predictions[i][0] < 0.5 ? 0 : 1; int parsed_pred = predictions.get(i, 0) < 0.5 ? 0 : 1;
System.out.printf( System.out.printf(
"{%.0f, %.0f} = %.4f (%d) -> %s\n", "{%.0f, %.0f} = %.4f (%d) -> %s\n",
inputs[i][0], inputs[i][1], inputs[i][0], inputs[i][1],
predictions[i][0], predictions.get(i, 0),
parsed_pred, parsed_pred,
parsed_pred == outputs[i][0] ? "correct" : "miss" parsed_pred == outputs[i][0] ? "correct" : "miss"
); );
@@ -112,9 +113,13 @@ class NeuralNetworkTest {
outputs[i][0] = trainingData.get(i).result; outputs[i][0] = trainingData.get(i).result;
} }
Timer timer = new Timer();
System.out.println("Training the neural net to learn "+OPERATION+"..."); System.out.println("Training the neural net to learn "+OPERATION+"...");
timer.start();
net.train(inputs, outputs, TRAINING_CYCLES); net.train(inputs, outputs, TRAINING_CYCLES);
System.out.println(" finished training"); timer.stop();
System.out.println(" finished training (" + timer.getMillis() + "ms)");
for( int i = 1; i <= net.getLayerCount(); i++ ) { for( int i = 1; i <= net.getLayerCount(); i++ ) {
System.out.println("Layer " +i + " weights"); System.out.println("Layer " +i + " weights");
@@ -136,9 +141,9 @@ class NeuralNetworkTest {
System.out.printf( System.out.printf(
"Prediction on data (%.2f, %.2f) was %.4f, expected %.2f (of by %.4f)\n", "Prediction on data (%.2f, %.2f) was %.4f, expected %.2f (of by %.4f)\n",
data.a, data.b, data.a, data.b,
net.getOutput()[0][0], net.getOutput().get(0, 0),
data.result, data.result,
net.getOutput()[0][0] - data.result net.getOutput().get(0, 0) - data.result
); );
} }