Colt als optionale Abhängigkeit

DAs Anlernen des NN geht um den Faktor 20 schneller, wenn Colt benutzt wird.
This commit is contained in:
ngb 2022-07-19 20:05:37 +02:00
parent b79f26f51e
commit bf261b5e9b
9 changed files with 735 additions and 205 deletions

View File

@ -28,10 +28,13 @@ dependencies {
runtimeOnly 'com.googlecode.soundlibs:tritonus-share:0.3.7.4' runtimeOnly 'com.googlecode.soundlibs:tritonus-share:0.3.7.4'
runtimeOnly 'com.googlecode.soundlibs:mp3spi:1.9.5.4' runtimeOnly 'com.googlecode.soundlibs:mp3spi:1.9.5.4'
compileOnlyApi 'colt:colt:1.2.0' //compileOnlyApi 'colt:colt:1.2.0'
api 'colt:colt:1.2.0'
//api 'net.sourceforge.parallelcolt:parallelcolt:0.10.1'
testImplementation 'org.junit.jupiter:junit-jupiter-api:5.8.1' testImplementation 'org.junit.jupiter:junit-jupiter-api:5.8.1'
testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine:5.8.1' testImplementation 'org.junit.jupiter:junit-jupiter-params:5.8.1'
testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine:5.8.1'
} }
test { test {

View File

@ -1269,7 +1269,8 @@ public class Constants {
} }
/** /**
* Erzeugt eine Pseudozufallszahl nach einer Gaussverteilung. * Erzeugt eine Pseudozufallszahl zwischen -1 und 1 nach einer
* Normalverteilung mit Mittelwert 0 und Standardabweichung 1.
* *
* @return Eine Zufallszahl. * @return Eine Zufallszahl.
* @see Random#nextGaussian() * @see Random#nextGaussian()

View File

@ -4,10 +4,11 @@ import schule.ngb.zm.Constants;
import java.util.Arrays; import java.util.Arrays;
import java.util.function.DoubleUnaryOperator; import java.util.function.DoubleUnaryOperator;
import java.util.stream.IntStream;
// TODO: Move Math into Matrix class // TODO: Move Math into Matrix class
// TODO: Implement support for optional sci libs // TODO: Implement support for optional sci libs
public class DoubleMatrix implements MLMatrix { public final class DoubleMatrix implements MLMatrix {
private int columns, rows; private int columns, rows;
@ -22,7 +23,9 @@ public class DoubleMatrix implements MLMatrix {
public DoubleMatrix( double[][] coefficients ) { public DoubleMatrix( double[][] coefficients ) {
this.rows = coefficients.length; this.rows = coefficients.length;
this.columns = coefficients[0].length; this.columns = coefficients[0].length;
this.coefficients = coefficients; this.coefficients = Arrays.stream(coefficients)
.map(double[]::clone)
.toArray(double[][]::new);
} }
public int columns() { public int columns() {
@ -47,22 +50,133 @@ public class DoubleMatrix implements MLMatrix {
} }
public MLMatrix initializeRandom() { public MLMatrix initializeRandom() {
coefficients = MLMath.matrixApply(coefficients, (d) -> Constants.randomGaussian()); return initializeRandom(-1.0, 1.0);
return this;
} }
public MLMatrix initializeRandom( double lower, double upper ) { public MLMatrix initializeRandom( double lower, double upper ) {
coefficients = MLMath.matrixApply(coefficients, (d) -> ((upper-lower) * (Constants.randomGaussian()+1) * .5) + lower); applyInPlace((d) -> ((upper-lower) * Constants.random()) + lower);
return this; return this;
} }
public MLMatrix initializeOne() { public MLMatrix initializeOne() {
coefficients = MLMath.matrixApply(coefficients, (d) -> 1.0); applyInPlace((d) -> 1.0);
return this; return this;
} }
public MLMatrix initializeZero() { public MLMatrix initializeZero() {
coefficients = MLMath.matrixApply(coefficients, (d) -> 0.0); applyInPlace((d) -> 0.0);
return this;
}
@Override
public MLMatrix duplicate() {
return new DoubleMatrix(coefficients);
}
@Override
public MLMatrix multiplyTransposed( MLMatrix B ) {
return new DoubleMatrix(IntStream.range(0, rows).parallel().mapToObj(
( i ) -> IntStream.range(0, B.rows()).mapToDouble(
( j ) -> IntStream.range(0, columns).mapToDouble(
(k) -> coefficients[i][k]*B.get(j,k)
).sum()
).toArray()
).toArray(double[][]::new));
}
@Override
public MLMatrix multiplyAddBias( final MLMatrix B, final MLMatrix C ) {
return new DoubleMatrix(IntStream.range(0, rows).parallel().mapToObj(
( i ) -> IntStream.range(0, B.columns()).mapToDouble(
( j ) -> IntStream.range(0, columns).mapToDouble(
(k) -> coefficients[i][k]*B.get(k,j)
).sum() + C.get(0, j)
).toArray()
).toArray(double[][]::new));
}
@Override
public MLMatrix transposedMultiplyAndScale( final MLMatrix B, final double scalar ) {
return new DoubleMatrix(IntStream.range(0, columns).parallel().mapToObj(
( i ) -> IntStream.range(0, B.columns()).mapToDouble(
( j ) -> IntStream.range(0, rows).mapToDouble(
(k) -> coefficients[k][i]*B.get(k,j)*scalar
).sum()
).toArray()
).toArray(double[][]::new));
}
@Override
public MLMatrix add( MLMatrix B ) {
return new DoubleMatrix(IntStream.range(0, rows).parallel().mapToObj(
( i ) -> IntStream.range(0, columns).mapToDouble(
( j ) -> coefficients[i][j] + B.get(i, j)
).toArray()
).toArray(double[][]::new));
}
@Override
public MLMatrix addInPlace( MLMatrix B ) {
coefficients = IntStream.range(0, rows).parallel().mapToObj(
( i ) -> IntStream.range(0, columns).mapToDouble(
( j ) -> coefficients[i][j] + B.get(i, j)
).toArray()
).toArray(double[][]::new);
return this;
}
@Override
public MLMatrix sub( MLMatrix B ) {
return new DoubleMatrix(IntStream.range(0, rows).parallel().mapToObj(
( i ) -> IntStream.range(0, columns).mapToDouble(
( j ) -> coefficients[i][j] - B.get(i, j)
).toArray()
).toArray(double[][]::new));
}
@Override
public MLMatrix colSums() {
double[][] sums = new double[1][columns];
for( int c = 0; c < columns; c++ ) {
for( int r = 0; r < rows; r++ ) {
sums[0][c] += coefficients[r][c];
}
}
return new DoubleMatrix(sums);
}
@Override
public MLMatrix scaleInPlace( final double scalar ) {
coefficients = Arrays.stream(coefficients).parallel().map(
( arr ) -> Arrays.stream(arr).map(
(d) -> d * scalar
).toArray()
).toArray(double[][]::new);
return this;
}
@Override
public MLMatrix scaleInPlace( final MLMatrix S ) {
coefficients = IntStream.range(0, coefficients.length).parallel().mapToObj(
( i ) -> IntStream.range(0, coefficients[i].length).mapToDouble(
( j ) -> coefficients[i][j] * S.get(i, j)
).toArray()
).toArray(double[][]::new);
return this;
}
@Override
public MLMatrix apply( DoubleUnaryOperator op ) {
return new DoubleMatrix(Arrays.stream(coefficients).parallel().map(
( arr ) -> Arrays.stream(arr).map(op).toArray()
).toArray(double[][]::new));
}
@Override
public MLMatrix applyInPlace( DoubleUnaryOperator op ) {
this.coefficients = Arrays.stream(coefficients).parallel().map(
( arr ) -> Arrays.stream(arr).map(op).toArray()
).toArray(double[][]::new);
return this; return this;
} }
@ -82,66 +196,4 @@ public class DoubleMatrix implements MLMatrix {
return sb.toString(); return sb.toString();
} }
@Override
public MLMatrix transpose() {
coefficients = MLMath.matrixTranspose(coefficients);
return this;
}
@Override
public MLMatrix multiply( MLMatrix B ) {
coefficients = MLMath.matrixMultiply(coefficients, B.coefficients());
return this;
}
@Override
public MLMatrix multiplyAddBias( MLMatrix B, MLMatrix C ) {
double[] biases = Arrays.stream(C.coefficients()).mapToDouble(( arr) -> arr[0]).toArray();
coefficients = MLMath.biasAdd(
MLMath.matrixMultiply(coefficients, B.coefficients()),
biases
);
return this;
}
@Override
public MLMatrix multiplyLeft( MLMatrix B ) {
coefficients = MLMath.matrixMultiply(B.coefficients(), coefficients);
return this;
}
@Override
public MLMatrix add( MLMatrix B ) {
coefficients = MLMath.matrixAdd(coefficients, B.coefficients());
return this;
}
@Override
public MLMatrix sub( MLMatrix B ) {
coefficients = MLMath.matrixSub(coefficients, B.coefficients());
return this;
}
@Override
public MLMatrix scale( double scalar ) {
return this;
}
@Override
public MLMatrix scale( MLMatrix S ) {
coefficients = MLMath.matrixScale(coefficients, S.coefficients());
return this;
}
@Override
public MLMatrix apply( DoubleUnaryOperator op ) {
this.coefficients = MLMath.matrixApply(coefficients, op);
return this;
}
@Override
public MLMatrix duplicate() {
return new DoubleMatrix(MLMath.copyMatrix(coefficients));
}
} }

View File

@ -22,29 +22,123 @@ public interface MLMatrix {
MLMatrix initializeZero(); MLMatrix initializeZero();
//MLMatrix transpose();
//MLMatrix multiply( MLMatrix B );
/** /**
* Erzeugt die transponierte Matrix zu dieser. * Erzeugt eine neue Matrix <em>C</em> mit dem Ergebnis der Matrixoperation
* <pre>
* C = A.B + V
* </pre>
* wobei <em>A</em> dieses Matrixobjekt ist und {@code .} für die
* Matrixmultiplikation steht.
*
* @param B
* @param V
* @return * @return
*/ */
MLMatrix transpose(); MLMatrix multiplyAddBias( MLMatrix B, MLMatrix V );
MLMatrix multiply( MLMatrix B ); /**
* Erzeugt eine neue Matrix <em>C</em> mit dem Ergebnis der Matrixoperation
* <pre>
* C = A.t(B)
* </pre>
* wobei <em>A</em> dieses Matrixobjekt ist und {@code t(B)} für die
* Transposition der Matrix <em>B</em>> steht.
*
* @param B
* @return
*/
MLMatrix multiplyTransposed( MLMatrix B );
MLMatrix multiplyAddBias( MLMatrix B, MLMatrix C ); MLMatrix transposedMultiplyAndScale( MLMatrix B, double scalar );
MLMatrix multiplyLeft( MLMatrix B );
/**
* Erzeugt eine neue Matrix <em>C</em> mit dem Ergebnis der
* komponentenweisen Matrix-Addition
* <pre>
* C = A+B
* </pre>
* wobei <em>A</em> dieses Matrixobjekt ist. Für ein Element
* <em>C_ij</em> in <em>C</em> gilt
* <pre>
* C_ij = A_ij + B_ij
* </pre>
*
* @param B Die zweite Matrix.
* @return Ein neues Matrixobjekt mit dem Ergebnis.
*/
MLMatrix add( MLMatrix B ); MLMatrix add( MLMatrix B );
/**
* Setzt dies Matrix auf das Ergebnis der
* komponentenweisen Matrix-Addition
* <pre>
* A = A+B
* </pre>
* wobei <em>A</em> dieses Matrixobjekt ist. Für ein Element
* <em>A_ij</em> in <em>A</em> gilt
* <pre>
* A_ij = A_ij + B_ij
* </pre>
*
* @param B Die zweite Matrix.
* @return Diese Matrix selbst (method chaining).
*/
MLMatrix addInPlace( MLMatrix B );
/**
* Erzeugt eine neue Matrix <em>C</em> mit dem Ergebnis der
* komponentenweisen Matrix-Subtraktion
* <pre>
* C = A-B
* </pre>
* wobei <em>A</em> dieses Matrixobjekt ist. Für ein Element
* <em>C_ij</em> in <em>C</em> gilt
* <pre>
* C_ij = A_ij - B_ij
* </pre>
*
* @param B
* @return
*/
MLMatrix sub( MLMatrix B ); MLMatrix sub( MLMatrix B );
MLMatrix scaleInPlace( double scalar );
MLMatrix scale( double scalar ); MLMatrix scaleInPlace( MLMatrix S );
MLMatrix scale( MLMatrix S ); /**
* Berechnet eine neue Matrix mit nur einer Zeile, die die Spaltensummen
* dieser Matrix enthalten.
* @return
*/
MLMatrix colSums();
/**
* Endet die gegebene Funktion auf jeden Wert der Matrix an.
*
* @param op
* @return
*/
MLMatrix apply( DoubleUnaryOperator op ); MLMatrix apply( DoubleUnaryOperator op );
/**
* Endet die gegebene Funktion auf jeden Wert der Matrix an.
*
* @param op
* @return
*/
MLMatrix applyInPlace( DoubleUnaryOperator op );
/**
* Erzeugt eine neue Matrix mit denselben Dimenstionen und Koeffizienten wie
* diese Matrix.
*
* @return
*/
MLMatrix duplicate(); MLMatrix duplicate();
String toString(); String toString();

View File

@ -1,7 +1,6 @@
package schule.ngb.zm.ml; package schule.ngb.zm.ml;
import cern.colt.matrix.DoubleMatrix2D; import cern.colt.matrix.DoubleFactory2D;
import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import schule.ngb.zm.Constants; import schule.ngb.zm.Constants;
import schule.ngb.zm.util.Log; import schule.ngb.zm.util.Log;
@ -11,7 +10,7 @@ public class MatrixFactory {
public static void main( String[] args ) { public static void main( String[] args ) {
System.out.println( System.out.println(
MatrixFactory.create(new double[][]{ {1.0, 0.0}, {0.0, 1.0} }).toString() MatrixFactory.create(new double[][]{{1.0, 0.0}, {0.0, 1.0}}).toString()
); );
} }
@ -26,14 +25,14 @@ public class MatrixFactory {
public static final MLMatrix create( double[][] values ) { public static final MLMatrix create( double[][] values ) {
try { try {
return getMatrixType().getDeclaredConstructor(double[][].class).newInstance((Object)values); return getMatrixType().getDeclaredConstructor(double[][].class).newInstance((Object) values);
} catch( Exception ex ) { } catch( Exception ex ) {
LOG.error(ex, "Could not initialize matrix implementation for class <%s>. Using internal implementation.", matrixType); LOG.error(ex, "Could not initialize matrix implementation for class <%s>. Using internal implementation.", matrixType);
} }
return new DoubleMatrix(values); return new DoubleMatrix(values);
} }
private static Class<? extends MLMatrix> matrixType = null; static Class<? extends MLMatrix> matrixType = null;
private static final Class<? extends MLMatrix> getMatrixType() { private static final Class<? extends MLMatrix> getMatrixType() {
if( matrixType == null ) { if( matrixType == null ) {
@ -63,6 +62,10 @@ public class MatrixFactory {
matrix = new cern.colt.matrix.impl.DenseDoubleMatrix2D(rows, cols); matrix = new cern.colt.matrix.impl.DenseDoubleMatrix2D(rows, cols);
} }
public ColtMatrix( ColtMatrix matrix ) {
this.matrix = matrix.matrix.copy();
}
@Override @Override
public int columns() { public int columns() {
return matrix.columns(); return matrix.columns();
@ -91,13 +94,12 @@ public class MatrixFactory {
@Override @Override
public MLMatrix initializeRandom() { public MLMatrix initializeRandom() {
matrix.assign((d) -> Constants.randomGaussian()); return initializeRandom(-1.0, 1.0);
return this;
} }
@Override @Override
public MLMatrix initializeRandom( double lower, double upper ) { public MLMatrix initializeRandom( double lower, double upper ) {
matrix.assign((d) -> ((upper-lower) * (Constants.randomGaussian()+1) * .5) + lower); matrix.assign(( d ) -> ((upper - lower) * Constants.random()) + lower);
return this; return this;
} }
@ -114,72 +116,97 @@ public class MatrixFactory {
} }
@Override @Override
public MLMatrix apply( DoubleUnaryOperator op ) { public MLMatrix duplicate() {
this.matrix.assign((d) -> op.applyAsDouble(d)); ColtMatrix newMatrix = new ColtMatrix(matrix.rows(), matrix.columns());
return this; newMatrix.matrix.assign(this.matrix);
return newMatrix;
} }
@Override @Override
public MLMatrix transpose() { public MLMatrix multiplyTransposed( MLMatrix B ) {
this.matrix = cern.colt.matrix.linalg.Algebra.DEFAULT.transpose(this.matrix); ColtMatrix CB = (ColtMatrix) B;
return this; ColtMatrix newMatrix = new ColtMatrix(0, 0);
} newMatrix.matrix = matrix.zMult(CB.matrix, null, 1.0, 0.0, false, true);
return newMatrix;
@Override
public MLMatrix multiply( MLMatrix B ) {
ColtMatrix CB = (ColtMatrix)B;
this.matrix = cern.colt.matrix.linalg.Algebra.DEFAULT.mult(matrix, CB.matrix);
return this;
}
@Override
public MLMatrix multiplyLeft( MLMatrix B ) {
ColtMatrix CB = (ColtMatrix)B;
this.matrix = cern.colt.matrix.linalg.Algebra.DEFAULT.mult(CB.matrix, matrix);
return this;
} }
@Override @Override
public MLMatrix multiplyAddBias( MLMatrix B, MLMatrix C ) { public MLMatrix multiplyAddBias( MLMatrix B, MLMatrix C ) {
ColtMatrix CB = (ColtMatrix)B; ColtMatrix CB = (ColtMatrix) B;
this.matrix = cern.colt.matrix.linalg.Algebra.DEFAULT.mult(matrix, CB.matrix); ColtMatrix newMatrix = new ColtMatrix(0, 0);
// TODO: add bias newMatrix.matrix = DoubleFactory2D.dense.repeat(((ColtMatrix) C).matrix, rows(), 1);
return this; matrix.zMult(CB.matrix, newMatrix.matrix, 1.0, 1.0, false, false);
return newMatrix;
}
@Override
public MLMatrix transposedMultiplyAndScale( final MLMatrix B, final double scalar ) {
ColtMatrix CB = (ColtMatrix) B;
ColtMatrix newMatrix = new ColtMatrix(0, 0);
newMatrix.matrix = matrix.zMult(CB.matrix, null, scalar, 0.0, true, false);
return newMatrix;
} }
@Override @Override
public MLMatrix add( MLMatrix B ) { public MLMatrix add( MLMatrix B ) {
ColtMatrix CB = (ColtMatrix)B; ColtMatrix CB = (ColtMatrix) B;
matrix.assign(CB.matrix, (d1,d2) -> d1+d2); ColtMatrix newMatrix = new ColtMatrix(this);
newMatrix.matrix.assign(CB.matrix, ( d1, d2 ) -> d1 + d2);
return newMatrix;
}
@Override
public MLMatrix addInPlace( MLMatrix B ) {
ColtMatrix CB = (ColtMatrix) B;
matrix.assign(CB.matrix, ( d1, d2 ) -> d1 + d2);
return this; return this;
} }
@Override @Override
public MLMatrix sub( MLMatrix B ) { public MLMatrix sub( MLMatrix B ) {
ColtMatrix CB = (ColtMatrix)B; ColtMatrix CB = (ColtMatrix) B;
matrix.assign(CB.matrix, (d1,d2) -> d1-d2); ColtMatrix newMatrix = new ColtMatrix(this);
return this; newMatrix.matrix.assign(CB.matrix, ( d1, d2 ) -> d1 - d2);
}
@Override
public MLMatrix scale( double scalar ) {
this.matrix.assign((d) -> d*scalar);
return this;
}
@Override
public MLMatrix scale( MLMatrix S ) {
this.matrix.forEachNonZero((r, c, d) -> d * S.get(r, c));
return this;
}
@Override
public MLMatrix duplicate() {
ColtMatrix newMatrix = new ColtMatrix(matrix.rows(), matrix.columns());
newMatrix.matrix.assign(this.matrix);
return newMatrix; return newMatrix;
} }
@Override
public MLMatrix colSums() {
double[][] sums = new double[1][matrix.columns()];
for( int c = 0; c < matrix.columns(); c++ ) {
for( int r = 0; r < matrix.rows(); r++ ) {
sums[0][c] += matrix.getQuick(r, c);
}
}
return new ColtMatrix(sums);
}
@Override
public MLMatrix scaleInPlace( double scalar ) {
this.matrix.assign(( d ) -> d * scalar);
return this;
}
@Override
public MLMatrix scaleInPlace( MLMatrix S ) {
this.matrix.forEachNonZero(( r, c, d ) -> d * S.get(r, c));
return this;
}
@Override
public MLMatrix apply( DoubleUnaryOperator op ) {
ColtMatrix newMatrix = new ColtMatrix(matrix.rows(), matrix.columns());
newMatrix.matrix.assign(matrix);
newMatrix.matrix.assign(( d ) -> op.applyAsDouble(d));
return newMatrix;
}
@Override
public MLMatrix applyInPlace( DoubleUnaryOperator op ) {
this.matrix.assign(( d ) -> op.applyAsDouble(d));
return this;
}
@Override @Override
public String toString() { public String toString() {
return matrix.toString(); return matrix.toString();

View File

@ -41,10 +41,10 @@ public class NeuronLayer implements Function<MLMatrix, MLMatrix> {
public NeuronLayer( int neurons, int inputs ) { public NeuronLayer( int neurons, int inputs ) {
weights = MatrixFactory weights = MatrixFactory
.create(inputs, neurons) .create(inputs, neurons)
.initializeRandom(-1, 1); .initializeRandom();
biases = MatrixFactory biases = MatrixFactory
.create(neurons, 1) .create(1, neurons)
.initializeZero(); .initializeZero();
activationFunction = MLMath::sigmoid; activationFunction = MLMath::sigmoid;
@ -110,10 +110,6 @@ public class NeuronLayer implements Function<MLMatrix, MLMatrix> {
weights = newWeights.duplicate(); weights = newWeights.duplicate();
} }
public void adjustWeights( MLMatrix adjustment ) {
weights.add(adjustment);
}
@Override @Override
public String toString() { public String toString() {
return weights.toString() + "\n" + biases.toString(); return weights.toString() + "\n" + biases.toString();
@ -121,10 +117,10 @@ public class NeuronLayer implements Function<MLMatrix, MLMatrix> {
@Override @Override
public MLMatrix apply( MLMatrix inputs ) { public MLMatrix apply( MLMatrix inputs ) {
lastInput = inputs; lastInput = inputs.duplicate();
lastOutput = inputs lastOutput = inputs
.multiplyAddBias(weights, biases) .multiplyAddBias(weights, biases)
.apply(activationFunction); .applyInPlace(activationFunction);
if( next != null ) { if( next != null ) {
return next.apply(lastOutput); return next.apply(lastOutput);
@ -144,14 +140,16 @@ public class NeuronLayer implements Function<MLMatrix, MLMatrix> {
} }
public void backprop( MLMatrix expected, double learningRate ) { public void backprop( MLMatrix expected, double learningRate ) {
MLMatrix error, delta, adjustment; MLMatrix error, adjustment;
if( next == null ) { if( next == null ) {
error = expected.duplicate().sub(lastOutput); error = expected.sub(lastOutput);
} else { } else {
error = expected.duplicate().multiply(next.weights.transpose()); error = expected.multiplyTransposed(next.weights);
} }
error.scale(lastOutput.duplicate().apply(this.activationFunctionDerivative)); error.scaleInPlace(
lastOutput.apply(this.activationFunctionDerivative)
);
// Hier schon leraningRate anwenden? // Hier schon leraningRate anwenden?
// See https://towardsdatascience.com/understanding-and-implementing-neural-networks-in-java-from-scratch-61421bb6352c // See https://towardsdatascience.com/understanding-and-implementing-neural-networks-in-java-from-scratch-61421bb6352c
//delta = MLMath.matrixApply(delta, ( x ) -> learningRate * x); //delta = MLMath.matrixApply(delta, ( x ) -> learningRate * x);
@ -159,10 +157,14 @@ public class NeuronLayer implements Function<MLMatrix, MLMatrix> {
previous.backprop(error, learningRate); previous.backprop(error, learningRate);
} }
// biases = MLMath.biasAdjust(biases, MLMath.matrixApply(delta, ( x ) -> learningRate * x)); biases.addInPlace(
error.colSums().scaleInPlace(
-learningRate / (double) error.rows()
)
);
adjustment = lastInput.duplicate().transpose().multiply(error).apply((d) -> learningRate*d); adjustment = lastInput.transposedMultiplyAndScale(error, learningRate);
this.adjustWeights(adjustment); weights.addInPlace(adjustment);
} }
} }

View File

@ -1,46 +0,0 @@
package schule.ngb.zm.ml;
import org.junit.jupiter.api.Test;
import java.util.Arrays;
import static org.junit.jupiter.api.Assertions.*;
class DoubleMatrixTest {
@Test
void initializeOne() {
DoubleMatrix m = new DoubleMatrix(4, 4);
m.initializeOne();
double[] ones = new double[]{1.0, 1.0, 1.0, 1.0};
assertArrayEquals(ones, m.coefficients[0]);
assertArrayEquals(ones, m.coefficients[1]);
assertArrayEquals(ones, m.coefficients[2]);
assertArrayEquals(ones, m.coefficients[3]);
}
@Test
void initializeZero() {
DoubleMatrix m = new DoubleMatrix(4, 4);
m.initializeZero();
double[] zeros = new double[]{0.0, 0.0, 0.0, 0.0};
assertArrayEquals(zeros, m.coefficients[0]);
assertArrayEquals(zeros, m.coefficients[1]);
assertArrayEquals(zeros, m.coefficients[2]);
assertArrayEquals(zeros, m.coefficients[3]);
}
@Test
void initializeRandom() {
DoubleMatrix m = new DoubleMatrix(4, 4);
m.initializeRandom(-1, 1);
assertTrue(Arrays.stream(m.coefficients[0]).allMatch((d) -> -1.0 <= d && d < 1.0));
assertTrue(Arrays.stream(m.coefficients[1]).allMatch((d) -> -1.0 <= d && d < 1.0));
assertTrue(Arrays.stream(m.coefficients[2]).allMatch((d) -> -1.0 <= d && d < 1.0));
assertTrue(Arrays.stream(m.coefficients[3]).allMatch((d) -> -1.0 <= d && d < 1.0));
}
}

View File

@ -0,0 +1,375 @@
package schule.ngb.zm.ml;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.TestInfo;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import static org.junit.jupiter.api.Assertions.*;
class MLMatrixTest {
private TestInfo info;
@BeforeEach
void saveTestInfo( TestInfo info ) {
this.info = info;
}
@ParameterizedTest
@ValueSource( classes = {DoubleMatrix.class, MatrixFactory.ColtMatrix.class} )
void initializeOne( Class<? extends MLMatrix> mType ) {
MatrixFactory.matrixType = mType;
MLMatrix m = MatrixFactory.create(4, 4);
m.initializeOne();
assertEquals(mType, m.getClass());
for( int i = 0; i < m.rows(); i++ ) {
for( int j = 0; j < m.columns(); j++ ) {
assertEquals(1.0, m.get(i, j));
}
}
}
@ParameterizedTest
@ValueSource( classes = {DoubleMatrix.class, MatrixFactory.ColtMatrix.class} )
void initializeZero( Class<? extends MLMatrix> mType ) {
MatrixFactory.matrixType = mType;
MLMatrix m = MatrixFactory.create(4, 4);
m.initializeZero();
assertEquals(mType, m.getClass());
for( int i = 0; i < m.rows(); i++ ) {
for( int j = 0; j < m.columns(); j++ ) {
assertEquals(0.0, m.get(i, j));
}
}
}
@ParameterizedTest
@ValueSource( classes = {DoubleMatrix.class, MatrixFactory.ColtMatrix.class} )
void initializeRandom( Class<? extends MLMatrix> mType ) {
MatrixFactory.matrixType = mType;
MLMatrix m = MatrixFactory.create(4, 4);
m.initializeRandom();
assertEquals(mType, m.getClass());
for( int i = 0; i < m.rows(); i++ ) {
for( int j = 0; j < m.columns(); j++ ) {
double d = m.get(i, j);
assertTrue(-1.0 <= d && d < 1.0);
}
}
}
@ParameterizedTest
@ValueSource( classes = {DoubleMatrix.class, MatrixFactory.ColtMatrix.class} )
void multiplyTransposed( Class<? extends MLMatrix> mType ) {
MatrixFactory.matrixType = mType;
MLMatrix A = MatrixFactory.create(new double[][]{
{1.0, 2.0, 3.0, 4.0},
{1.0, 2.0, 3.0, 4.0},
{1.0, 2.0, 3.0, 4.0}
});
MLMatrix B = MatrixFactory.create(new double[][]{
{1, 3, 5, 7},
{2, 4, 6, 8}
});
MLMatrix C = A.multiplyTransposed(B);
assertEquals(mType, A.getClass());
assertEquals(mType, B.getClass());
assertEquals(mType, C.getClass());
assertEquals(3, C.rows());
assertEquals(2, C.columns());
for( int i = 0; i < C.rows(); i++ ) {
assertEquals(50.0, C.get(i, 0));
assertEquals(60.0, C.get(i, 1));
}
}
@ParameterizedTest
@ValueSource( classes = {DoubleMatrix.class, MatrixFactory.ColtMatrix.class} )
void multiplyAddBias( Class<? extends MLMatrix> mType ) {
MatrixFactory.matrixType = mType;
MLMatrix A = MatrixFactory.create(new double[][]{
{1.0, 2.0, 3.0, 4.0},
{1.0, 2.0, 3.0, 4.0},
{1.0, 2.0, 3.0, 4.0}
});
MLMatrix B = MatrixFactory.create(new double[][]{
{1.0, 2.0},
{3.0, 4.0},
{5.0, 6.0},
{7.0, 8.0}
});
MLMatrix V = MatrixFactory.create(new double[][]{
{1000.0, 2000.0}
});
MLMatrix C = A.multiplyAddBias(B, V);
assertEquals(mType, A.getClass());
assertEquals(mType, B.getClass());
assertEquals(mType, C.getClass());
assertEquals(mType, V.getClass());
assertEquals(3, C.rows());
assertEquals(2, C.columns());
for( int i = 0; i < C.rows(); i++ ) {
assertEquals(1050.0, C.get(i, 0));
assertEquals(2060.0, C.get(i, 1));
}
}
@ParameterizedTest
@ValueSource( classes = {DoubleMatrix.class, MatrixFactory.ColtMatrix.class} )
void transposedMultiplyAndScale( Class<? extends MLMatrix> mType ) {
MatrixFactory.matrixType = mType;
MLMatrix A = MatrixFactory.create(new double[][]{
{1, 1, 1},
{2, 2, 2},
{3, 3, 3},
{4, 4, 4}
});
MLMatrix B = MatrixFactory.create(new double[][]{
{1.0, 2.0},
{3.0, 4.0},
{5.0, 6.0},
{7.0, 8.0}
});
MLMatrix C = A.transposedMultiplyAndScale(B, 2.0);
assertEquals(mType, A.getClass());
assertEquals(mType, B.getClass());
assertEquals(mType, C.getClass());
assertEquals(3, C.rows());
assertEquals(2, C.columns());
for( int i = 0; i < C.rows(); i++ ) {
assertEquals(100.0, C.get(i, 0));
assertEquals(120.0, C.get(i, 1));
}
}
@ParameterizedTest
@ValueSource( classes = {DoubleMatrix.class, MatrixFactory.ColtMatrix.class} )
void apply( Class<? extends MLMatrix> mType ) {
MatrixFactory.matrixType = mType;
MLMatrix M = MatrixFactory.create(new double[][]{
{1, 1, 1},
{2, 2, 2},
{3, 3, 3},
{4, 4, 4}
});
MLMatrix R = M.apply(( d ) -> d * d);
assertEquals(mType, M.getClass());
assertEquals(mType, R.getClass());
assertNotSame(M, R);
for( int i = 0; i < M.rows(); i++ ) {
for( int j = 0; j < M.columns(); j++ ) {
assertEquals(
(i + 1) * (i + 1), R.get(i, j),
msg("(%d,%d)", "apply", i, j)
);
}
}
MLMatrix M2 = M.applyInPlace(( d ) -> d * d * d);
assertSame(M, M2);
for( int i = 0; i < M.rows(); i++ ) {
for( int j = 0; j < M.columns(); j++ ) {
assertEquals(
(i + 1) * (i + 1) * (i + 1), M.get(i, j),
msg("(%d,%d)", "applyInPlace", i, j)
);
}
}
}
@ParameterizedTest
@ValueSource( classes = {DoubleMatrix.class, MatrixFactory.ColtMatrix.class} )
void add( Class<? extends MLMatrix> mType ) {
MatrixFactory.matrixType = mType;
MLMatrix M = MatrixFactory.create(new double[][]{
{1, 1, 1},
{2, 2, 2},
{3, 3, 3},
{4, 4, 4}
});
MLMatrix R = M.add(M);
assertEquals(mType, M.getClass());
assertEquals(mType, R.getClass());
assertNotSame(M, R);
for( int i = 0; i < M.rows(); i++ ) {
for( int j = 0; j < M.columns(); j++ ) {
assertEquals(
(i + 1) + (i + 1), R.get(i, j),
msg("(%d,%d)", "add", i, j)
);
}
}
MLMatrix M2 = M.addInPlace(R);
assertSame(M, M2);
for( int i = 0; i < M.rows(); i++ ) {
for( int j = 0; j < M.columns(); j++ ) {
assertEquals(
(i + 1) + (i + 1) + (i + 1), M.get(i, j),
msg("(%d,%d)", "addInPlace", i, j)
);
}
}
}
@ParameterizedTest
@ValueSource( classes = {DoubleMatrix.class, MatrixFactory.ColtMatrix.class} )
void sub( Class<? extends MLMatrix> mType ) {
MatrixFactory.matrixType = mType;
MLMatrix M = MatrixFactory.create(new double[][]{
{1, 1, 1},
{2, 2, 2},
{3, 3, 3},
{4, 4, 4}
});
MLMatrix R = M.sub(M);
assertEquals(mType, M.getClass());
assertEquals(mType, R.getClass());
assertNotSame(M, R);
for( int i = 0; i < M.rows(); i++ ) {
for( int j = 0; j < M.columns(); j++ ) {
assertEquals(
0.0, R.get(i, j),
msg("(%d,%d)", "sub", i, j)
);
}
}
}
@ParameterizedTest
@ValueSource( classes = {DoubleMatrix.class, MatrixFactory.ColtMatrix.class} )
void colSums( Class<? extends MLMatrix> mType ) {
MatrixFactory.matrixType = mType;
MLMatrix M = MatrixFactory.create(new double[][]{
{1, 2, 3},
{1, 2, 3},
{1, 2, 3},
{1, 2, 3}
});
MLMatrix R = M.colSums();
assertEquals(mType, M.getClass());
assertEquals(mType, R.getClass());
assertNotSame(M, R);
assertEquals(1, R.rows());
assertEquals(3, R.columns());
for( int j = 0; j < M.columns(); j++ ) {
assertEquals(
(j+1)*4, R.get(0, j),
msg("(%d,%d)", "colSums", 0, j)
);
}
}
@ParameterizedTest
@ValueSource( classes = {DoubleMatrix.class, MatrixFactory.ColtMatrix.class} )
void duplicate( Class<? extends MLMatrix> mType ) {
MatrixFactory.matrixType = mType;
MLMatrix M = MatrixFactory.create(new double[][]{
{1, 2, 3},
{1, 2, 3},
{1, 2, 3},
{1, 2, 3}
});
MLMatrix R = M.duplicate();
assertEquals(mType, M.getClass());
assertEquals(mType, R.getClass());
assertNotSame(M, R);
for( int i = 0; i < M.rows(); i++ ) {
for( int j = 0; j < M.columns(); j++ ) {
assertEquals(
M.get(i, j), R.get(i, j),
msg("(%d,%d)", "duplicate", i, j)
);
}
}
}
@ParameterizedTest
@ValueSource( classes = {DoubleMatrix.class, MatrixFactory.ColtMatrix.class} )
void scale( Class<? extends MLMatrix> mType ) {
MatrixFactory.matrixType = mType;
MLMatrix M = MatrixFactory.create(new double[][]{
{1, 1, 1},
{2, 2, 2},
{3, 3, 3},
{4, 4, 4}
});
MLMatrix M2 = M.scaleInPlace(2.0);
assertEquals(mType, M.getClass());
assertEquals(mType, M2.getClass());
assertSame(M, M2);
for( int i = 0; i < M.rows(); i++ ) {
for( int j = 0; j < M.columns(); j++ ) {
assertEquals(
(i+1)*2.0, M2.get(i, j),
msg("(%d,%d)", "scaleInPlace", i, j)
);
}
}
MLMatrix M3 = M.scaleInPlace(M);
assertSame(M, M3);
for( int i = 0; i < M.rows(); i++ ) {
for( int j = 0; j < M.columns(); j++ ) {
assertEquals(
((i+1)*2.0)*((i+1)*2.0), M.get(i, j),
msg("(%d,%d)", "addInPlace", i, j)
);
}
}
}
private String msg( String msg, String methodName, Object... args ) {
String testName = this.info.getTestMethod().get().getName();
String className = MatrixFactory.matrixType.getSimpleName();
return String.format("[" + testName + "(" + className + ") " + methodName + "()] " + msg, args);
}
}

View File

@ -2,6 +2,7 @@ package schule.ngb.zm.ml;
import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import schule.ngb.zm.Constants;
import schule.ngb.zm.util.Log; import schule.ngb.zm.util.Log;
import schule.ngb.zm.util.Timer; import schule.ngb.zm.util.Timer;
@ -16,6 +17,13 @@ class NeuralNetworkTest {
Log.enableGlobalDebugging(); Log.enableGlobalDebugging();
} }
@BeforeAll
static void setupMatrixLibrary() {
Constants.setSeed(1001);
MatrixFactory.matrixType = MatrixFactory.ColtMatrix.class;
//MatrixFactory.matrixType = DoubleMatrix.class;
}
/*@Test /*@Test
void readWrite() { void readWrite() {
// XOR Dataset // XOR Dataset
@ -107,7 +115,7 @@ class NeuralNetworkTest {
for( int i = 0; i < trainingData.size(); i++ ) { for( int i = 0; i < trainingData.size(); i++ ) {
inputs[i][0] = trainingData.get(i).a; inputs[i][0] = trainingData.get(i).a;
inputs[i][1] = trainingData.get(i).b; inputs[i][1] = trainingData.get(i).b;
outputs[i][0] = trainingData.get(i).result; outputs[i][0] = trainingData.get(i).getResult();
} }
Timer timer = new Timer(); Timer timer = new Timer();
@ -139,8 +147,8 @@ class NeuralNetworkTest {
"Prediction on data (%.2f, %.2f) was %.4f, expected %.2f (of by %.4f)\n", "Prediction on data (%.2f, %.2f) was %.4f, expected %.2f (of by %.4f)\n",
data.a, data.b, data.a, data.b,
net.getOutput().get(0, 0), net.getOutput().get(0, 0),
data.result, data.getResult(),
net.getOutput().get(0, 0) - data.result net.getOutput().get(0, 0) - data.getResult()
); );
} }
@ -183,7 +191,6 @@ class NeuralNetworkTest {
double a; double a;
double b; double b;
double result;
CalcType type; CalcType type;
TestData( double a, double b ) { TestData( double a, double b ) {
@ -191,6 +198,8 @@ class NeuralNetworkTest {
this.b = b; this.b = b;
} }
abstract double getResult();
} }
private static final class AddData extends TestData { private static final class AddData extends TestData {
@ -199,7 +208,9 @@ class NeuralNetworkTest {
public AddData( double a, double b ) { public AddData( double a, double b ) {
super(a, b); super(a, b);
result = a + b; }
double getResult() {
return a+b;
} }
} }
@ -210,7 +221,9 @@ class NeuralNetworkTest {
public SubData( double a, double b ) { public SubData( double a, double b ) {
super(a, b); super(a, b);
result = a - b; }
double getResult() {
return a-b;
} }
} }
@ -221,7 +234,9 @@ class NeuralNetworkTest {
public MulData( double a, double b ) { public MulData( double a, double b ) {
super(a, b); super(a, b);
result = a * b; }
double getResult() {
return a*b;
} }
} }
@ -235,7 +250,9 @@ class NeuralNetworkTest {
if( b == 0.0 ) { if( b == 0.0 ) {
b = .1; b = .1;
} }
result = a / b; }
double getResult() {
return a/b;
} }
} }
@ -246,7 +263,12 @@ class NeuralNetworkTest {
public ModData( double b, double a ) { public ModData( double b, double a ) {
super(b, a); super(b, a);
result = a % b; if( b == 0.0 ) {
b = .1;
}
}
double getResult() {
return a%b;
} }
} }