mirror of
https://github.com/jneug/zeichenmaschine.git
synced 2026-04-14 14:43:33 +02:00
Matric interface umbenannt
This commit is contained in:
@@ -28,7 +28,7 @@ dependencies {
|
|||||||
runtimeOnly 'com.googlecode.soundlibs:tritonus-share:0.3.7.4'
|
runtimeOnly 'com.googlecode.soundlibs:tritonus-share:0.3.7.4'
|
||||||
runtimeOnly 'com.googlecode.soundlibs:mp3spi:1.9.5.4'
|
runtimeOnly 'com.googlecode.soundlibs:mp3spi:1.9.5.4'
|
||||||
|
|
||||||
api 'colt:colt:1.2.0'
|
compileOnlyApi 'colt:colt:1.2.0'
|
||||||
|
|
||||||
testImplementation 'org.junit.jupiter:junit-jupiter-api:5.8.1'
|
testImplementation 'org.junit.jupiter:junit-jupiter-api:5.8.1'
|
||||||
testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine:5.8.1'
|
testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine:5.8.1'
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import java.util.function.DoubleUnaryOperator;
|
|||||||
|
|
||||||
// TODO: Move Math into Matrix class
|
// TODO: Move Math into Matrix class
|
||||||
// TODO: Implement support for optional sci libs
|
// TODO: Implement support for optional sci libs
|
||||||
public class DoubleMatrix implements Matrix {
|
public class DoubleMatrix implements MLMatrix {
|
||||||
|
|
||||||
private int columns, rows;
|
private int columns, rows;
|
||||||
|
|
||||||
@@ -33,7 +33,7 @@ public class DoubleMatrix implements Matrix {
|
|||||||
return rows;
|
return rows;
|
||||||
}
|
}
|
||||||
|
|
||||||
public double[][] getCoefficients() {
|
public double[][] coefficients() {
|
||||||
return coefficients;
|
return coefficients;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -41,27 +41,27 @@ public class DoubleMatrix implements Matrix {
|
|||||||
return coefficients[row][col];
|
return coefficients[row][col];
|
||||||
}
|
}
|
||||||
|
|
||||||
public Matrix set( int row, int col, double value ) {
|
public MLMatrix set( int row, int col, double value ) {
|
||||||
coefficients[row][col] = value;
|
coefficients[row][col] = value;
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Matrix initializeRandom() {
|
public MLMatrix initializeRandom() {
|
||||||
coefficients = MLMath.matrixApply(coefficients, (d) -> Constants.randomGaussian());
|
coefficients = MLMath.matrixApply(coefficients, (d) -> Constants.randomGaussian());
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Matrix initializeRandom( double lower, double upper ) {
|
public MLMatrix initializeRandom( double lower, double upper ) {
|
||||||
coefficients = MLMath.matrixApply(coefficients, (d) -> ((upper-lower) * (Constants.randomGaussian()+1) * .5) + lower);
|
coefficients = MLMath.matrixApply(coefficients, (d) -> ((upper-lower) * (Constants.randomGaussian()+1) * .5) + lower);
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Matrix initializeOne() {
|
public MLMatrix initializeOne() {
|
||||||
coefficients = MLMath.matrixApply(coefficients, (d) -> 1.0);
|
coefficients = MLMath.matrixApply(coefficients, (d) -> 1.0);
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Matrix initializeZero() {
|
public MLMatrix initializeZero() {
|
||||||
coefficients = MLMath.matrixApply(coefficients, (d) -> 0.0);
|
coefficients = MLMath.matrixApply(coefficients, (d) -> 0.0);
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
@@ -83,64 +83,64 @@ public class DoubleMatrix implements Matrix {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Matrix transpose() {
|
public MLMatrix transpose() {
|
||||||
coefficients = MLMath.matrixTranspose(coefficients);
|
coefficients = MLMath.matrixTranspose(coefficients);
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Matrix multiply( Matrix B ) {
|
public MLMatrix multiply( MLMatrix B ) {
|
||||||
coefficients = MLMath.matrixMultiply(coefficients, B.getCoefficients());
|
coefficients = MLMath.matrixMultiply(coefficients, B.coefficients());
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Matrix multiplyAddBias( Matrix B, Matrix C ) {
|
public MLMatrix multiplyAddBias( MLMatrix B, MLMatrix C ) {
|
||||||
double[] biases = Arrays.stream(C.getCoefficients()).mapToDouble((arr) -> arr[0]).toArray();
|
double[] biases = Arrays.stream(C.coefficients()).mapToDouble(( arr) -> arr[0]).toArray();
|
||||||
coefficients = MLMath.biasAdd(
|
coefficients = MLMath.biasAdd(
|
||||||
MLMath.matrixMultiply(coefficients, B.getCoefficients()),
|
MLMath.matrixMultiply(coefficients, B.coefficients()),
|
||||||
biases
|
biases
|
||||||
);
|
);
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Matrix multiplyLeft( Matrix B ) {
|
public MLMatrix multiplyLeft( MLMatrix B ) {
|
||||||
coefficients = MLMath.matrixMultiply(B.getCoefficients(), coefficients);
|
coefficients = MLMath.matrixMultiply(B.coefficients(), coefficients);
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Matrix add( Matrix B ) {
|
public MLMatrix add( MLMatrix B ) {
|
||||||
coefficients = MLMath.matrixAdd(coefficients, B.getCoefficients());
|
coefficients = MLMath.matrixAdd(coefficients, B.coefficients());
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Matrix sub( Matrix B ) {
|
public MLMatrix sub( MLMatrix B ) {
|
||||||
coefficients = MLMath.matrixSub(coefficients, B.getCoefficients());
|
coefficients = MLMath.matrixSub(coefficients, B.coefficients());
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Matrix scale( double scalar ) {
|
public MLMatrix scale( double scalar ) {
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Matrix scale( Matrix S ) {
|
public MLMatrix scale( MLMatrix S ) {
|
||||||
coefficients = MLMath.matrixScale(coefficients, S.getCoefficients());
|
coefficients = MLMath.matrixScale(coefficients, S.coefficients());
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Matrix apply( DoubleUnaryOperator op ) {
|
public MLMatrix apply( DoubleUnaryOperator op ) {
|
||||||
this.coefficients = MLMath.matrixApply(coefficients, op);
|
this.coefficients = MLMath.matrixApply(coefficients, op);
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Matrix duplicate() {
|
public MLMatrix duplicate() {
|
||||||
return new DoubleMatrix(MLMath.copyMatrix(coefficients));
|
return new DoubleMatrix(MLMath.copyMatrix(coefficients));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
52
src/main/java/schule/ngb/zm/ml/MLMatrix.java
Normal file
52
src/main/java/schule/ngb/zm/ml/MLMatrix.java
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
package schule.ngb.zm.ml;
|
||||||
|
|
||||||
|
import java.util.function.DoubleUnaryOperator;
|
||||||
|
|
||||||
|
public interface MLMatrix {
|
||||||
|
|
||||||
|
int columns();
|
||||||
|
|
||||||
|
int rows();
|
||||||
|
|
||||||
|
double[][] coefficients();
|
||||||
|
|
||||||
|
double get( int row, int col );
|
||||||
|
|
||||||
|
MLMatrix set( int row, int col, double value );
|
||||||
|
|
||||||
|
MLMatrix initializeRandom();
|
||||||
|
|
||||||
|
MLMatrix initializeRandom( double lower, double upper );
|
||||||
|
|
||||||
|
MLMatrix initializeOne();
|
||||||
|
|
||||||
|
MLMatrix initializeZero();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Erzeugt die transponierte Matrix zu dieser.
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
MLMatrix transpose();
|
||||||
|
|
||||||
|
MLMatrix multiply( MLMatrix B );
|
||||||
|
|
||||||
|
MLMatrix multiplyAddBias( MLMatrix B, MLMatrix C );
|
||||||
|
|
||||||
|
MLMatrix multiplyLeft( MLMatrix B );
|
||||||
|
|
||||||
|
MLMatrix add( MLMatrix B );
|
||||||
|
|
||||||
|
MLMatrix sub( MLMatrix B );
|
||||||
|
|
||||||
|
|
||||||
|
MLMatrix scale( double scalar );
|
||||||
|
|
||||||
|
MLMatrix scale( MLMatrix S );
|
||||||
|
|
||||||
|
MLMatrix apply( DoubleUnaryOperator op );
|
||||||
|
|
||||||
|
MLMatrix duplicate();
|
||||||
|
|
||||||
|
String toString();
|
||||||
|
|
||||||
|
}
|
||||||
@@ -1,48 +0,0 @@
|
|||||||
package schule.ngb.zm.ml;
|
|
||||||
|
|
||||||
import java.util.function.DoubleUnaryOperator;
|
|
||||||
|
|
||||||
public interface Matrix {
|
|
||||||
|
|
||||||
int columns();
|
|
||||||
|
|
||||||
int rows();
|
|
||||||
|
|
||||||
double[][] getCoefficients();
|
|
||||||
|
|
||||||
double get( int row, int col );
|
|
||||||
|
|
||||||
Matrix set( int row, int col, double value );
|
|
||||||
|
|
||||||
Matrix initializeRandom();
|
|
||||||
|
|
||||||
Matrix initializeRandom( double lower, double upper );
|
|
||||||
|
|
||||||
Matrix initializeOne();
|
|
||||||
|
|
||||||
Matrix initializeZero();
|
|
||||||
|
|
||||||
Matrix transpose();
|
|
||||||
|
|
||||||
Matrix multiply( Matrix B );
|
|
||||||
|
|
||||||
Matrix multiplyAddBias( Matrix B, Matrix C );
|
|
||||||
|
|
||||||
Matrix multiplyLeft( Matrix B );
|
|
||||||
|
|
||||||
Matrix add( Matrix B );
|
|
||||||
|
|
||||||
Matrix sub( Matrix B );
|
|
||||||
|
|
||||||
|
|
||||||
Matrix scale( double scalar );
|
|
||||||
|
|
||||||
Matrix scale( Matrix S );
|
|
||||||
|
|
||||||
Matrix apply( DoubleUnaryOperator op );
|
|
||||||
|
|
||||||
Matrix duplicate();
|
|
||||||
|
|
||||||
String toString();
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -5,8 +5,6 @@ import cern.colt.matrix.impl.DenseDoubleMatrix2D;
|
|||||||
import schule.ngb.zm.Constants;
|
import schule.ngb.zm.Constants;
|
||||||
import schule.ngb.zm.util.Log;
|
import schule.ngb.zm.util.Log;
|
||||||
|
|
||||||
import java.lang.reflect.Constructor;
|
|
||||||
import java.lang.reflect.InvocationTargetException;
|
|
||||||
import java.util.function.DoubleUnaryOperator;
|
import java.util.function.DoubleUnaryOperator;
|
||||||
|
|
||||||
public class MatrixFactory {
|
public class MatrixFactory {
|
||||||
@@ -17,7 +15,7 @@ public class MatrixFactory {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static final Matrix create( int rows, int cols ) {
|
public static final MLMatrix create( int rows, int cols ) {
|
||||||
try {
|
try {
|
||||||
return getMatrixType().getDeclaredConstructor(int.class, int.class).newInstance(rows, cols);
|
return getMatrixType().getDeclaredConstructor(int.class, int.class).newInstance(rows, cols);
|
||||||
} catch( Exception ex ) {
|
} catch( Exception ex ) {
|
||||||
@@ -26,7 +24,7 @@ public class MatrixFactory {
|
|||||||
return new DoubleMatrix(rows, cols);
|
return new DoubleMatrix(rows, cols);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static final Matrix create( double[][] values ) {
|
public static final MLMatrix create( double[][] values ) {
|
||||||
try {
|
try {
|
||||||
return getMatrixType().getDeclaredConstructor(double[][].class).newInstance((Object)values);
|
return getMatrixType().getDeclaredConstructor(double[][].class).newInstance((Object)values);
|
||||||
} catch( Exception ex ) {
|
} catch( Exception ex ) {
|
||||||
@@ -35,9 +33,9 @@ public class MatrixFactory {
|
|||||||
return new DoubleMatrix(values);
|
return new DoubleMatrix(values);
|
||||||
}
|
}
|
||||||
|
|
||||||
private static Class<? extends Matrix> matrixType = null;
|
private static Class<? extends MLMatrix> matrixType = null;
|
||||||
|
|
||||||
private static final Class<? extends Matrix> getMatrixType() {
|
private static final Class<? extends MLMatrix> getMatrixType() {
|
||||||
if( matrixType == null ) {
|
if( matrixType == null ) {
|
||||||
try {
|
try {
|
||||||
Class<?> clazz = Class.forName("cern.colt.matrix.impl.DenseDoubleMatrix2D", false, MatrixFactory.class.getClassLoader());
|
Class<?> clazz = Class.forName("cern.colt.matrix.impl.DenseDoubleMatrix2D", false, MatrixFactory.class.getClassLoader());
|
||||||
@@ -53,7 +51,7 @@ public class MatrixFactory {
|
|||||||
|
|
||||||
private static final Log LOG = Log.getLogger(MatrixFactory.class);
|
private static final Log LOG = Log.getLogger(MatrixFactory.class);
|
||||||
|
|
||||||
static class ColtMatrix implements Matrix {
|
static class ColtMatrix implements MLMatrix {
|
||||||
|
|
||||||
cern.colt.matrix.DoubleMatrix2D matrix;
|
cern.colt.matrix.DoubleMatrix2D matrix;
|
||||||
|
|
||||||
@@ -81,68 +79,68 @@ public class MatrixFactory {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Matrix set( int row, int col, double value ) {
|
public MLMatrix set( int row, int col, double value ) {
|
||||||
matrix.set(row, col, value);
|
matrix.set(row, col, value);
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public double[][] getCoefficients() {
|
public double[][] coefficients() {
|
||||||
return this.matrix.toArray();
|
return this.matrix.toArray();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Matrix initializeRandom() {
|
public MLMatrix initializeRandom() {
|
||||||
matrix.assign((d) -> Constants.randomGaussian());
|
matrix.assign((d) -> Constants.randomGaussian());
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Matrix initializeRandom( double lower, double upper ) {
|
public MLMatrix initializeRandom( double lower, double upper ) {
|
||||||
matrix.assign((d) -> ((upper-lower) * (Constants.randomGaussian()+1) * .5) + lower);
|
matrix.assign((d) -> ((upper-lower) * (Constants.randomGaussian()+1) * .5) + lower);
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Matrix initializeOne() {
|
public MLMatrix initializeOne() {
|
||||||
this.matrix.assign(1.0);
|
this.matrix.assign(1.0);
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Matrix initializeZero() {
|
public MLMatrix initializeZero() {
|
||||||
this.matrix.assign(0.0);
|
this.matrix.assign(0.0);
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Matrix apply( DoubleUnaryOperator op ) {
|
public MLMatrix apply( DoubleUnaryOperator op ) {
|
||||||
this.matrix.assign((d) -> op.applyAsDouble(d));
|
this.matrix.assign((d) -> op.applyAsDouble(d));
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Matrix transpose() {
|
public MLMatrix transpose() {
|
||||||
this.matrix = cern.colt.matrix.linalg.Algebra.DEFAULT.transpose(this.matrix);
|
this.matrix = cern.colt.matrix.linalg.Algebra.DEFAULT.transpose(this.matrix);
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Matrix multiply( Matrix B ) {
|
public MLMatrix multiply( MLMatrix B ) {
|
||||||
ColtMatrix CB = (ColtMatrix)B;
|
ColtMatrix CB = (ColtMatrix)B;
|
||||||
this.matrix = cern.colt.matrix.linalg.Algebra.DEFAULT.mult(matrix, CB.matrix);
|
this.matrix = cern.colt.matrix.linalg.Algebra.DEFAULT.mult(matrix, CB.matrix);
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Matrix multiplyLeft( Matrix B ) {
|
public MLMatrix multiplyLeft( MLMatrix B ) {
|
||||||
ColtMatrix CB = (ColtMatrix)B;
|
ColtMatrix CB = (ColtMatrix)B;
|
||||||
this.matrix = cern.colt.matrix.linalg.Algebra.DEFAULT.mult(CB.matrix, matrix);
|
this.matrix = cern.colt.matrix.linalg.Algebra.DEFAULT.mult(CB.matrix, matrix);
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Matrix multiplyAddBias( Matrix B, Matrix C ) {
|
public MLMatrix multiplyAddBias( MLMatrix B, MLMatrix C ) {
|
||||||
ColtMatrix CB = (ColtMatrix)B;
|
ColtMatrix CB = (ColtMatrix)B;
|
||||||
this.matrix = cern.colt.matrix.linalg.Algebra.DEFAULT.mult(matrix, CB.matrix);
|
this.matrix = cern.colt.matrix.linalg.Algebra.DEFAULT.mult(matrix, CB.matrix);
|
||||||
// TODO: add bias
|
// TODO: add bias
|
||||||
@@ -150,33 +148,33 @@ public class MatrixFactory {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Matrix add( Matrix B ) {
|
public MLMatrix add( MLMatrix B ) {
|
||||||
ColtMatrix CB = (ColtMatrix)B;
|
ColtMatrix CB = (ColtMatrix)B;
|
||||||
matrix.assign(CB.matrix, (d1,d2) -> d1+d2);
|
matrix.assign(CB.matrix, (d1,d2) -> d1+d2);
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Matrix sub( Matrix B ) {
|
public MLMatrix sub( MLMatrix B ) {
|
||||||
ColtMatrix CB = (ColtMatrix)B;
|
ColtMatrix CB = (ColtMatrix)B;
|
||||||
matrix.assign(CB.matrix, (d1,d2) -> d1-d2);
|
matrix.assign(CB.matrix, (d1,d2) -> d1-d2);
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Matrix scale( double scalar ) {
|
public MLMatrix scale( double scalar ) {
|
||||||
this.matrix.assign((d) -> d*scalar);
|
this.matrix.assign((d) -> d*scalar);
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Matrix scale( Matrix S ) {
|
public MLMatrix scale( MLMatrix S ) {
|
||||||
this.matrix.forEachNonZero((r, c, d) -> d * S.get(r, c));
|
this.matrix.forEachNonZero((r, c, d) -> d * S.get(r, c));
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Matrix duplicate() {
|
public MLMatrix duplicate() {
|
||||||
ColtMatrix newMatrix = new ColtMatrix(matrix.rows(), matrix.columns());
|
ColtMatrix newMatrix = new ColtMatrix(matrix.rows(), matrix.columns());
|
||||||
newMatrix.matrix.assign(this.matrix);
|
newMatrix.matrix.assign(this.matrix);
|
||||||
return newMatrix;
|
return newMatrix;
|
||||||
|
|||||||
@@ -107,7 +107,7 @@ public class NeuralNetwork {
|
|||||||
|
|
||||||
private NeuronLayer[] layers;
|
private NeuronLayer[] layers;
|
||||||
|
|
||||||
private Matrix output;
|
private MLMatrix output;
|
||||||
|
|
||||||
private double learningRate = 0.1;
|
private double learningRate = 0.1;
|
||||||
|
|
||||||
@@ -162,16 +162,16 @@ public class NeuralNetwork {
|
|||||||
this.learningRate = pLearningRate;
|
this.learningRate = pLearningRate;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Matrix getOutput() {
|
public MLMatrix getOutput() {
|
||||||
return output;
|
return output;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Matrix predict( double[][] inputs ) {
|
public MLMatrix predict( double[][] inputs ) {
|
||||||
//this.output = layers[1].apply(layers[0].apply(inputs));
|
//this.output = layers[1].apply(layers[0].apply(inputs));
|
||||||
return predict(MatrixFactory.create(inputs));
|
return predict(MatrixFactory.create(inputs));
|
||||||
}
|
}
|
||||||
|
|
||||||
public Matrix predict( Matrix inputs ) {
|
public MLMatrix predict( MLMatrix inputs ) {
|
||||||
this.output = layers[0].apply(inputs);
|
this.output = layers[0].apply(inputs);
|
||||||
return this.output;
|
return this.output;
|
||||||
}
|
}
|
||||||
@@ -180,7 +180,7 @@ public class NeuralNetwork {
|
|||||||
learn(MatrixFactory.create(expected));
|
learn(MatrixFactory.create(expected));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void learn( Matrix expected ) {
|
public void learn( MLMatrix expected ) {
|
||||||
layers[layers.length - 1].backprop(expected, learningRate);
|
layers[layers.length - 1].backprop(expected, learningRate);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,9 @@
|
|||||||
package schule.ngb.zm.ml;
|
package schule.ngb.zm.ml;
|
||||||
|
|
||||||
import java.util.Arrays;
|
|
||||||
import java.util.function.DoubleUnaryOperator;
|
import java.util.function.DoubleUnaryOperator;
|
||||||
import java.util.function.Function;
|
import java.util.function.Function;
|
||||||
|
|
||||||
public class NeuronLayer implements Function<Matrix, Matrix> {
|
public class NeuronLayer implements Function<MLMatrix, MLMatrix> {
|
||||||
|
|
||||||
/*public static NeuronLayer fromArray( double[][] weights ) {
|
/*public static NeuronLayer fromArray( double[][] weights ) {
|
||||||
NeuronLayer layer = new NeuronLayer(weights[0].length, weights.length);
|
NeuronLayer layer = new NeuronLayer(weights[0].length, weights.length);
|
||||||
@@ -29,15 +28,15 @@ public class NeuronLayer implements Function<Matrix, Matrix> {
|
|||||||
return layer;
|
return layer;
|
||||||
}*/
|
}*/
|
||||||
|
|
||||||
Matrix weights;
|
MLMatrix weights;
|
||||||
|
|
||||||
Matrix biases;
|
MLMatrix biases;
|
||||||
|
|
||||||
NeuronLayer previous, next;
|
NeuronLayer previous, next;
|
||||||
|
|
||||||
DoubleUnaryOperator activationFunction, activationFunctionDerivative;
|
DoubleUnaryOperator activationFunction, activationFunctionDerivative;
|
||||||
|
|
||||||
Matrix lastOutput, lastInput;
|
MLMatrix lastOutput, lastInput;
|
||||||
|
|
||||||
public NeuronLayer( int neurons, int inputs ) {
|
public NeuronLayer( int neurons, int inputs ) {
|
||||||
weights = MatrixFactory
|
weights = MatrixFactory
|
||||||
@@ -87,11 +86,11 @@ public class NeuronLayer implements Function<Matrix, Matrix> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public Matrix getWeights() {
|
public MLMatrix getWeights() {
|
||||||
return weights;
|
return weights;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Matrix getBiases() {
|
public MLMatrix getBiases() {
|
||||||
return biases;
|
return biases;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -103,15 +102,15 @@ public class NeuronLayer implements Function<Matrix, Matrix> {
|
|||||||
return weights.rows();
|
return weights.rows();
|
||||||
}
|
}
|
||||||
|
|
||||||
public Matrix getLastOutput() {
|
public MLMatrix getLastOutput() {
|
||||||
return lastOutput;
|
return lastOutput;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void setWeights( Matrix newWeights ) {
|
public void setWeights( MLMatrix newWeights ) {
|
||||||
weights = newWeights.duplicate();
|
weights = newWeights.duplicate();
|
||||||
}
|
}
|
||||||
|
|
||||||
public void adjustWeights( Matrix adjustment ) {
|
public void adjustWeights( MLMatrix adjustment ) {
|
||||||
weights.add(adjustment);
|
weights.add(adjustment);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -121,7 +120,7 @@ public class NeuronLayer implements Function<Matrix, Matrix> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Matrix apply( Matrix inputs ) {
|
public MLMatrix apply( MLMatrix inputs ) {
|
||||||
lastInput = inputs;
|
lastInput = inputs;
|
||||||
lastOutput = inputs
|
lastOutput = inputs
|
||||||
.multiplyAddBias(weights, biases)
|
.multiplyAddBias(weights, biases)
|
||||||
@@ -135,17 +134,17 @@ public class NeuronLayer implements Function<Matrix, Matrix> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public <V> Function<V, Matrix> compose( Function<? super V, ? extends Matrix> before ) {
|
public <V> Function<V, MLMatrix> compose( Function<? super V, ? extends MLMatrix> before ) {
|
||||||
return ( in ) -> apply(before.apply(in));
|
return ( in ) -> apply(before.apply(in));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public <V> Function<Matrix, V> andThen( Function<? super Matrix, ? extends V> after ) {
|
public <V> Function<MLMatrix, V> andThen( Function<? super MLMatrix, ? extends V> after ) {
|
||||||
return ( in ) -> after.apply(apply(in));
|
return ( in ) -> after.apply(apply(in));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void backprop( Matrix expected, double learningRate ) {
|
public void backprop( MLMatrix expected, double learningRate ) {
|
||||||
Matrix error, delta, adjustment;
|
MLMatrix error, delta, adjustment;
|
||||||
if( next == null ) {
|
if( next == null ) {
|
||||||
error = expected.duplicate().sub(lastOutput);
|
error = expected.duplicate().sub(lastOutput);
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -5,13 +5,10 @@ import org.junit.jupiter.api.Test;
|
|||||||
import schule.ngb.zm.util.Log;
|
import schule.ngb.zm.util.Log;
|
||||||
import schule.ngb.zm.util.Timer;
|
import schule.ngb.zm.util.Timer;
|
||||||
|
|
||||||
import java.io.File;
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.*;
|
|
||||||
|
|
||||||
class NeuralNetworkTest {
|
class NeuralNetworkTest {
|
||||||
|
|
||||||
@BeforeAll
|
@BeforeAll
|
||||||
@@ -79,7 +76,7 @@ class NeuralNetworkTest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// calculate predictions
|
// calculate predictions
|
||||||
Matrix predictions = net.predict(inputs);
|
MLMatrix predictions = net.predict(inputs);
|
||||||
for( int i = 0; i < 4; i++ ) {
|
for( int i = 0; i < 4; i++ ) {
|
||||||
int parsed_pred = predictions.get(i, 0) < 0.5 ? 0 : 1;
|
int parsed_pred = predictions.get(i, 0) < 0.5 ? 0 : 1;
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user