Colt als optionale Abhängigkeit
DAs Anlernen des NN geht um den Faktor 20 schneller, wenn Colt benutzt wird.
This commit is contained in:
parent
b79f26f51e
commit
bf261b5e9b
|
@ -28,10 +28,13 @@ dependencies {
|
|||
runtimeOnly 'com.googlecode.soundlibs:tritonus-share:0.3.7.4'
|
||||
runtimeOnly 'com.googlecode.soundlibs:mp3spi:1.9.5.4'
|
||||
|
||||
compileOnlyApi 'colt:colt:1.2.0'
|
||||
//compileOnlyApi 'colt:colt:1.2.0'
|
||||
api 'colt:colt:1.2.0'
|
||||
//api 'net.sourceforge.parallelcolt:parallelcolt:0.10.1'
|
||||
|
||||
testImplementation 'org.junit.jupiter:junit-jupiter-api:5.8.1'
|
||||
testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine:5.8.1'
|
||||
testImplementation 'org.junit.jupiter:junit-jupiter-params:5.8.1'
|
||||
testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine:5.8.1'
|
||||
}
|
||||
|
||||
test {
|
||||
|
|
|
@ -1269,7 +1269,8 @@ public class Constants {
|
|||
}
|
||||
|
||||
/**
|
||||
* Erzeugt eine Pseudozufallszahl nach einer Gaussverteilung.
|
||||
* Erzeugt eine Pseudozufallszahl zwischen -1 und 1 nach einer
|
||||
* Normalverteilung mit Mittelwert 0 und Standardabweichung 1.
|
||||
*
|
||||
* @return Eine Zufallszahl.
|
||||
* @see Random#nextGaussian()
|
||||
|
|
|
@ -4,10 +4,11 @@ import schule.ngb.zm.Constants;
|
|||
|
||||
import java.util.Arrays;
|
||||
import java.util.function.DoubleUnaryOperator;
|
||||
import java.util.stream.IntStream;
|
||||
|
||||
// TODO: Move Math into Matrix class
|
||||
// TODO: Implement support for optional sci libs
|
||||
public class DoubleMatrix implements MLMatrix {
|
||||
public final class DoubleMatrix implements MLMatrix {
|
||||
|
||||
private int columns, rows;
|
||||
|
||||
|
@ -22,7 +23,9 @@ public class DoubleMatrix implements MLMatrix {
|
|||
public DoubleMatrix( double[][] coefficients ) {
|
||||
this.rows = coefficients.length;
|
||||
this.columns = coefficients[0].length;
|
||||
this.coefficients = coefficients;
|
||||
this.coefficients = Arrays.stream(coefficients)
|
||||
.map(double[]::clone)
|
||||
.toArray(double[][]::new);
|
||||
}
|
||||
|
||||
public int columns() {
|
||||
|
@ -47,22 +50,133 @@ public class DoubleMatrix implements MLMatrix {
|
|||
}
|
||||
|
||||
public MLMatrix initializeRandom() {
|
||||
coefficients = MLMath.matrixApply(coefficients, (d) -> Constants.randomGaussian());
|
||||
return this;
|
||||
return initializeRandom(-1.0, 1.0);
|
||||
}
|
||||
|
||||
public MLMatrix initializeRandom( double lower, double upper ) {
|
||||
coefficients = MLMath.matrixApply(coefficients, (d) -> ((upper-lower) * (Constants.randomGaussian()+1) * .5) + lower);
|
||||
applyInPlace((d) -> ((upper-lower) * Constants.random()) + lower);
|
||||
return this;
|
||||
}
|
||||
|
||||
public MLMatrix initializeOne() {
|
||||
coefficients = MLMath.matrixApply(coefficients, (d) -> 1.0);
|
||||
applyInPlace((d) -> 1.0);
|
||||
return this;
|
||||
}
|
||||
|
||||
public MLMatrix initializeZero() {
|
||||
coefficients = MLMath.matrixApply(coefficients, (d) -> 0.0);
|
||||
applyInPlace((d) -> 0.0);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MLMatrix duplicate() {
|
||||
return new DoubleMatrix(coefficients);
|
||||
}
|
||||
|
||||
@Override
|
||||
public MLMatrix multiplyTransposed( MLMatrix B ) {
|
||||
return new DoubleMatrix(IntStream.range(0, rows).parallel().mapToObj(
|
||||
( i ) -> IntStream.range(0, B.rows()).mapToDouble(
|
||||
( j ) -> IntStream.range(0, columns).mapToDouble(
|
||||
(k) -> coefficients[i][k]*B.get(j,k)
|
||||
).sum()
|
||||
).toArray()
|
||||
).toArray(double[][]::new));
|
||||
}
|
||||
|
||||
@Override
|
||||
public MLMatrix multiplyAddBias( final MLMatrix B, final MLMatrix C ) {
|
||||
return new DoubleMatrix(IntStream.range(0, rows).parallel().mapToObj(
|
||||
( i ) -> IntStream.range(0, B.columns()).mapToDouble(
|
||||
( j ) -> IntStream.range(0, columns).mapToDouble(
|
||||
(k) -> coefficients[i][k]*B.get(k,j)
|
||||
).sum() + C.get(0, j)
|
||||
).toArray()
|
||||
).toArray(double[][]::new));
|
||||
}
|
||||
|
||||
@Override
|
||||
public MLMatrix transposedMultiplyAndScale( final MLMatrix B, final double scalar ) {
|
||||
return new DoubleMatrix(IntStream.range(0, columns).parallel().mapToObj(
|
||||
( i ) -> IntStream.range(0, B.columns()).mapToDouble(
|
||||
( j ) -> IntStream.range(0, rows).mapToDouble(
|
||||
(k) -> coefficients[k][i]*B.get(k,j)*scalar
|
||||
).sum()
|
||||
).toArray()
|
||||
).toArray(double[][]::new));
|
||||
}
|
||||
|
||||
@Override
|
||||
public MLMatrix add( MLMatrix B ) {
|
||||
return new DoubleMatrix(IntStream.range(0, rows).parallel().mapToObj(
|
||||
( i ) -> IntStream.range(0, columns).mapToDouble(
|
||||
( j ) -> coefficients[i][j] + B.get(i, j)
|
||||
).toArray()
|
||||
).toArray(double[][]::new));
|
||||
}
|
||||
|
||||
@Override
|
||||
public MLMatrix addInPlace( MLMatrix B ) {
|
||||
coefficients = IntStream.range(0, rows).parallel().mapToObj(
|
||||
( i ) -> IntStream.range(0, columns).mapToDouble(
|
||||
( j ) -> coefficients[i][j] + B.get(i, j)
|
||||
).toArray()
|
||||
).toArray(double[][]::new);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MLMatrix sub( MLMatrix B ) {
|
||||
return new DoubleMatrix(IntStream.range(0, rows).parallel().mapToObj(
|
||||
( i ) -> IntStream.range(0, columns).mapToDouble(
|
||||
( j ) -> coefficients[i][j] - B.get(i, j)
|
||||
).toArray()
|
||||
).toArray(double[][]::new));
|
||||
}
|
||||
|
||||
@Override
|
||||
public MLMatrix colSums() {
|
||||
double[][] sums = new double[1][columns];
|
||||
for( int c = 0; c < columns; c++ ) {
|
||||
for( int r = 0; r < rows; r++ ) {
|
||||
sums[0][c] += coefficients[r][c];
|
||||
}
|
||||
}
|
||||
return new DoubleMatrix(sums);
|
||||
}
|
||||
|
||||
@Override
|
||||
public MLMatrix scaleInPlace( final double scalar ) {
|
||||
coefficients = Arrays.stream(coefficients).parallel().map(
|
||||
( arr ) -> Arrays.stream(arr).map(
|
||||
(d) -> d * scalar
|
||||
).toArray()
|
||||
).toArray(double[][]::new);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MLMatrix scaleInPlace( final MLMatrix S ) {
|
||||
coefficients = IntStream.range(0, coefficients.length).parallel().mapToObj(
|
||||
( i ) -> IntStream.range(0, coefficients[i].length).mapToDouble(
|
||||
( j ) -> coefficients[i][j] * S.get(i, j)
|
||||
).toArray()
|
||||
).toArray(double[][]::new);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MLMatrix apply( DoubleUnaryOperator op ) {
|
||||
return new DoubleMatrix(Arrays.stream(coefficients).parallel().map(
|
||||
( arr ) -> Arrays.stream(arr).map(op).toArray()
|
||||
).toArray(double[][]::new));
|
||||
}
|
||||
|
||||
@Override
|
||||
public MLMatrix applyInPlace( DoubleUnaryOperator op ) {
|
||||
this.coefficients = Arrays.stream(coefficients).parallel().map(
|
||||
( arr ) -> Arrays.stream(arr).map(op).toArray()
|
||||
).toArray(double[][]::new);
|
||||
return this;
|
||||
}
|
||||
|
||||
|
@ -82,66 +196,4 @@ public class DoubleMatrix implements MLMatrix {
|
|||
return sb.toString();
|
||||
}
|
||||
|
||||
@Override
|
||||
public MLMatrix transpose() {
|
||||
coefficients = MLMath.matrixTranspose(coefficients);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MLMatrix multiply( MLMatrix B ) {
|
||||
coefficients = MLMath.matrixMultiply(coefficients, B.coefficients());
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MLMatrix multiplyAddBias( MLMatrix B, MLMatrix C ) {
|
||||
double[] biases = Arrays.stream(C.coefficients()).mapToDouble(( arr) -> arr[0]).toArray();
|
||||
coefficients = MLMath.biasAdd(
|
||||
MLMath.matrixMultiply(coefficients, B.coefficients()),
|
||||
biases
|
||||
);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MLMatrix multiplyLeft( MLMatrix B ) {
|
||||
coefficients = MLMath.matrixMultiply(B.coefficients(), coefficients);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MLMatrix add( MLMatrix B ) {
|
||||
coefficients = MLMath.matrixAdd(coefficients, B.coefficients());
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MLMatrix sub( MLMatrix B ) {
|
||||
coefficients = MLMath.matrixSub(coefficients, B.coefficients());
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MLMatrix scale( double scalar ) {
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MLMatrix scale( MLMatrix S ) {
|
||||
coefficients = MLMath.matrixScale(coefficients, S.coefficients());
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MLMatrix apply( DoubleUnaryOperator op ) {
|
||||
this.coefficients = MLMath.matrixApply(coefficients, op);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MLMatrix duplicate() {
|
||||
return new DoubleMatrix(MLMath.copyMatrix(coefficients));
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -22,29 +22,123 @@ public interface MLMatrix {
|
|||
|
||||
MLMatrix initializeZero();
|
||||
|
||||
//MLMatrix transpose();
|
||||
|
||||
//MLMatrix multiply( MLMatrix B );
|
||||
|
||||
/**
|
||||
* Erzeugt die transponierte Matrix zu dieser.
|
||||
* Erzeugt eine neue Matrix <em>C</em> mit dem Ergebnis der Matrixoperation
|
||||
* <pre>
|
||||
* C = A.B + V
|
||||
* </pre>
|
||||
* wobei <em>A</em> dieses Matrixobjekt ist und {@code .} für die
|
||||
* Matrixmultiplikation steht.
|
||||
*
|
||||
* @param B
|
||||
* @param V
|
||||
* @return
|
||||
*/
|
||||
MLMatrix transpose();
|
||||
MLMatrix multiplyAddBias( MLMatrix B, MLMatrix V );
|
||||
|
||||
MLMatrix multiply( MLMatrix B );
|
||||
/**
|
||||
* Erzeugt eine neue Matrix <em>C</em> mit dem Ergebnis der Matrixoperation
|
||||
* <pre>
|
||||
* C = A.t(B)
|
||||
* </pre>
|
||||
* wobei <em>A</em> dieses Matrixobjekt ist und {@code t(B)} für die
|
||||
* Transposition der Matrix <em>B</em>> steht.
|
||||
*
|
||||
* @param B
|
||||
* @return
|
||||
*/
|
||||
MLMatrix multiplyTransposed( MLMatrix B );
|
||||
|
||||
MLMatrix multiplyAddBias( MLMatrix B, MLMatrix C );
|
||||
|
||||
MLMatrix multiplyLeft( MLMatrix B );
|
||||
MLMatrix transposedMultiplyAndScale( MLMatrix B, double scalar );
|
||||
|
||||
/**
|
||||
* Erzeugt eine neue Matrix <em>C</em> mit dem Ergebnis der
|
||||
* komponentenweisen Matrix-Addition
|
||||
* <pre>
|
||||
* C = A+B
|
||||
* </pre>
|
||||
* wobei <em>A</em> dieses Matrixobjekt ist. Für ein Element
|
||||
* <em>C_ij</em> in <em>C</em> gilt
|
||||
* <pre>
|
||||
* C_ij = A_ij + B_ij
|
||||
* </pre>
|
||||
*
|
||||
* @param B Die zweite Matrix.
|
||||
* @return Ein neues Matrixobjekt mit dem Ergebnis.
|
||||
*/
|
||||
MLMatrix add( MLMatrix B );
|
||||
|
||||
/**
|
||||
* Setzt dies Matrix auf das Ergebnis der
|
||||
* komponentenweisen Matrix-Addition
|
||||
* <pre>
|
||||
* A = A+B
|
||||
* </pre>
|
||||
* wobei <em>A</em> dieses Matrixobjekt ist. Für ein Element
|
||||
* <em>A_ij</em> in <em>A</em> gilt
|
||||
* <pre>
|
||||
* A_ij = A_ij + B_ij
|
||||
* </pre>
|
||||
*
|
||||
* @param B Die zweite Matrix.
|
||||
* @return Diese Matrix selbst (method chaining).
|
||||
*/
|
||||
MLMatrix addInPlace( MLMatrix B );
|
||||
|
||||
/**
|
||||
* Erzeugt eine neue Matrix <em>C</em> mit dem Ergebnis der
|
||||
* komponentenweisen Matrix-Subtraktion
|
||||
* <pre>
|
||||
* C = A-B
|
||||
* </pre>
|
||||
* wobei <em>A</em> dieses Matrixobjekt ist. Für ein Element
|
||||
* <em>C_ij</em> in <em>C</em> gilt
|
||||
* <pre>
|
||||
* C_ij = A_ij - B_ij
|
||||
* </pre>
|
||||
*
|
||||
* @param B
|
||||
* @return
|
||||
*/
|
||||
MLMatrix sub( MLMatrix B );
|
||||
|
||||
MLMatrix scaleInPlace( double scalar );
|
||||
|
||||
MLMatrix scale( double scalar );
|
||||
MLMatrix scaleInPlace( MLMatrix S );
|
||||
|
||||
MLMatrix scale( MLMatrix S );
|
||||
/**
|
||||
* Berechnet eine neue Matrix mit nur einer Zeile, die die Spaltensummen
|
||||
* dieser Matrix enthalten.
|
||||
* @return
|
||||
*/
|
||||
MLMatrix colSums();
|
||||
|
||||
/**
|
||||
* Endet die gegebene Funktion auf jeden Wert der Matrix an.
|
||||
*
|
||||
* @param op
|
||||
* @return
|
||||
*/
|
||||
MLMatrix apply( DoubleUnaryOperator op );
|
||||
|
||||
/**
|
||||
* Endet die gegebene Funktion auf jeden Wert der Matrix an.
|
||||
*
|
||||
* @param op
|
||||
* @return
|
||||
*/
|
||||
MLMatrix applyInPlace( DoubleUnaryOperator op );
|
||||
|
||||
/**
|
||||
* Erzeugt eine neue Matrix mit denselben Dimenstionen und Koeffizienten wie
|
||||
* diese Matrix.
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
MLMatrix duplicate();
|
||||
|
||||
String toString();
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package schule.ngb.zm.ml;
|
||||
|
||||
import cern.colt.matrix.DoubleMatrix2D;
|
||||
import cern.colt.matrix.impl.DenseDoubleMatrix2D;
|
||||
import cern.colt.matrix.DoubleFactory2D;
|
||||
import schule.ngb.zm.Constants;
|
||||
import schule.ngb.zm.util.Log;
|
||||
|
||||
|
@ -11,7 +10,7 @@ public class MatrixFactory {
|
|||
|
||||
public static void main( String[] args ) {
|
||||
System.out.println(
|
||||
MatrixFactory.create(new double[][]{ {1.0, 0.0}, {0.0, 1.0} }).toString()
|
||||
MatrixFactory.create(new double[][]{{1.0, 0.0}, {0.0, 1.0}}).toString()
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -26,14 +25,14 @@ public class MatrixFactory {
|
|||
|
||||
public static final MLMatrix create( double[][] values ) {
|
||||
try {
|
||||
return getMatrixType().getDeclaredConstructor(double[][].class).newInstance((Object)values);
|
||||
return getMatrixType().getDeclaredConstructor(double[][].class).newInstance((Object) values);
|
||||
} catch( Exception ex ) {
|
||||
LOG.error(ex, "Could not initialize matrix implementation for class <%s>. Using internal implementation.", matrixType);
|
||||
}
|
||||
return new DoubleMatrix(values);
|
||||
}
|
||||
|
||||
private static Class<? extends MLMatrix> matrixType = null;
|
||||
static Class<? extends MLMatrix> matrixType = null;
|
||||
|
||||
private static final Class<? extends MLMatrix> getMatrixType() {
|
||||
if( matrixType == null ) {
|
||||
|
@ -63,6 +62,10 @@ public class MatrixFactory {
|
|||
matrix = new cern.colt.matrix.impl.DenseDoubleMatrix2D(rows, cols);
|
||||
}
|
||||
|
||||
public ColtMatrix( ColtMatrix matrix ) {
|
||||
this.matrix = matrix.matrix.copy();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int columns() {
|
||||
return matrix.columns();
|
||||
|
@ -91,13 +94,12 @@ public class MatrixFactory {
|
|||
|
||||
@Override
|
||||
public MLMatrix initializeRandom() {
|
||||
matrix.assign((d) -> Constants.randomGaussian());
|
||||
return this;
|
||||
return initializeRandom(-1.0, 1.0);
|
||||
}
|
||||
|
||||
@Override
|
||||
public MLMatrix initializeRandom( double lower, double upper ) {
|
||||
matrix.assign((d) -> ((upper-lower) * (Constants.randomGaussian()+1) * .5) + lower);
|
||||
matrix.assign(( d ) -> ((upper - lower) * Constants.random()) + lower);
|
||||
return this;
|
||||
}
|
||||
|
||||
|
@ -114,72 +116,97 @@ public class MatrixFactory {
|
|||
}
|
||||
|
||||
@Override
|
||||
public MLMatrix apply( DoubleUnaryOperator op ) {
|
||||
this.matrix.assign((d) -> op.applyAsDouble(d));
|
||||
return this;
|
||||
public MLMatrix duplicate() {
|
||||
ColtMatrix newMatrix = new ColtMatrix(matrix.rows(), matrix.columns());
|
||||
newMatrix.matrix.assign(this.matrix);
|
||||
return newMatrix;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MLMatrix transpose() {
|
||||
this.matrix = cern.colt.matrix.linalg.Algebra.DEFAULT.transpose(this.matrix);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MLMatrix multiply( MLMatrix B ) {
|
||||
ColtMatrix CB = (ColtMatrix)B;
|
||||
this.matrix = cern.colt.matrix.linalg.Algebra.DEFAULT.mult(matrix, CB.matrix);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MLMatrix multiplyLeft( MLMatrix B ) {
|
||||
ColtMatrix CB = (ColtMatrix)B;
|
||||
this.matrix = cern.colt.matrix.linalg.Algebra.DEFAULT.mult(CB.matrix, matrix);
|
||||
return this;
|
||||
public MLMatrix multiplyTransposed( MLMatrix B ) {
|
||||
ColtMatrix CB = (ColtMatrix) B;
|
||||
ColtMatrix newMatrix = new ColtMatrix(0, 0);
|
||||
newMatrix.matrix = matrix.zMult(CB.matrix, null, 1.0, 0.0, false, true);
|
||||
return newMatrix;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MLMatrix multiplyAddBias( MLMatrix B, MLMatrix C ) {
|
||||
ColtMatrix CB = (ColtMatrix)B;
|
||||
this.matrix = cern.colt.matrix.linalg.Algebra.DEFAULT.mult(matrix, CB.matrix);
|
||||
// TODO: add bias
|
||||
return this;
|
||||
ColtMatrix CB = (ColtMatrix) B;
|
||||
ColtMatrix newMatrix = new ColtMatrix(0, 0);
|
||||
newMatrix.matrix = DoubleFactory2D.dense.repeat(((ColtMatrix) C).matrix, rows(), 1);
|
||||
matrix.zMult(CB.matrix, newMatrix.matrix, 1.0, 1.0, false, false);
|
||||
return newMatrix;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MLMatrix transposedMultiplyAndScale( final MLMatrix B, final double scalar ) {
|
||||
ColtMatrix CB = (ColtMatrix) B;
|
||||
ColtMatrix newMatrix = new ColtMatrix(0, 0);
|
||||
newMatrix.matrix = matrix.zMult(CB.matrix, null, scalar, 0.0, true, false);
|
||||
return newMatrix;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MLMatrix add( MLMatrix B ) {
|
||||
ColtMatrix CB = (ColtMatrix)B;
|
||||
matrix.assign(CB.matrix, (d1,d2) -> d1+d2);
|
||||
ColtMatrix CB = (ColtMatrix) B;
|
||||
ColtMatrix newMatrix = new ColtMatrix(this);
|
||||
newMatrix.matrix.assign(CB.matrix, ( d1, d2 ) -> d1 + d2);
|
||||
return newMatrix;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MLMatrix addInPlace( MLMatrix B ) {
|
||||
ColtMatrix CB = (ColtMatrix) B;
|
||||
matrix.assign(CB.matrix, ( d1, d2 ) -> d1 + d2);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MLMatrix sub( MLMatrix B ) {
|
||||
ColtMatrix CB = (ColtMatrix)B;
|
||||
matrix.assign(CB.matrix, (d1,d2) -> d1-d2);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MLMatrix scale( double scalar ) {
|
||||
this.matrix.assign((d) -> d*scalar);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MLMatrix scale( MLMatrix S ) {
|
||||
this.matrix.forEachNonZero((r, c, d) -> d * S.get(r, c));
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MLMatrix duplicate() {
|
||||
ColtMatrix newMatrix = new ColtMatrix(matrix.rows(), matrix.columns());
|
||||
newMatrix.matrix.assign(this.matrix);
|
||||
ColtMatrix CB = (ColtMatrix) B;
|
||||
ColtMatrix newMatrix = new ColtMatrix(this);
|
||||
newMatrix.matrix.assign(CB.matrix, ( d1, d2 ) -> d1 - d2);
|
||||
return newMatrix;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MLMatrix colSums() {
|
||||
double[][] sums = new double[1][matrix.columns()];
|
||||
for( int c = 0; c < matrix.columns(); c++ ) {
|
||||
for( int r = 0; r < matrix.rows(); r++ ) {
|
||||
sums[0][c] += matrix.getQuick(r, c);
|
||||
}
|
||||
}
|
||||
return new ColtMatrix(sums);
|
||||
}
|
||||
|
||||
@Override
|
||||
public MLMatrix scaleInPlace( double scalar ) {
|
||||
this.matrix.assign(( d ) -> d * scalar);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MLMatrix scaleInPlace( MLMatrix S ) {
|
||||
this.matrix.forEachNonZero(( r, c, d ) -> d * S.get(r, c));
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MLMatrix apply( DoubleUnaryOperator op ) {
|
||||
ColtMatrix newMatrix = new ColtMatrix(matrix.rows(), matrix.columns());
|
||||
newMatrix.matrix.assign(matrix);
|
||||
newMatrix.matrix.assign(( d ) -> op.applyAsDouble(d));
|
||||
return newMatrix;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MLMatrix applyInPlace( DoubleUnaryOperator op ) {
|
||||
this.matrix.assign(( d ) -> op.applyAsDouble(d));
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return matrix.toString();
|
||||
|
|
|
@ -41,10 +41,10 @@ public class NeuronLayer implements Function<MLMatrix, MLMatrix> {
|
|||
public NeuronLayer( int neurons, int inputs ) {
|
||||
weights = MatrixFactory
|
||||
.create(inputs, neurons)
|
||||
.initializeRandom(-1, 1);
|
||||
.initializeRandom();
|
||||
|
||||
biases = MatrixFactory
|
||||
.create(neurons, 1)
|
||||
.create(1, neurons)
|
||||
.initializeZero();
|
||||
|
||||
activationFunction = MLMath::sigmoid;
|
||||
|
@ -110,10 +110,6 @@ public class NeuronLayer implements Function<MLMatrix, MLMatrix> {
|
|||
weights = newWeights.duplicate();
|
||||
}
|
||||
|
||||
public void adjustWeights( MLMatrix adjustment ) {
|
||||
weights.add(adjustment);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return weights.toString() + "\n" + biases.toString();
|
||||
|
@ -121,10 +117,10 @@ public class NeuronLayer implements Function<MLMatrix, MLMatrix> {
|
|||
|
||||
@Override
|
||||
public MLMatrix apply( MLMatrix inputs ) {
|
||||
lastInput = inputs;
|
||||
lastInput = inputs.duplicate();
|
||||
lastOutput = inputs
|
||||
.multiplyAddBias(weights, biases)
|
||||
.apply(activationFunction);
|
||||
.applyInPlace(activationFunction);
|
||||
|
||||
if( next != null ) {
|
||||
return next.apply(lastOutput);
|
||||
|
@ -144,14 +140,16 @@ public class NeuronLayer implements Function<MLMatrix, MLMatrix> {
|
|||
}
|
||||
|
||||
public void backprop( MLMatrix expected, double learningRate ) {
|
||||
MLMatrix error, delta, adjustment;
|
||||
MLMatrix error, adjustment;
|
||||
if( next == null ) {
|
||||
error = expected.duplicate().sub(lastOutput);
|
||||
error = expected.sub(lastOutput);
|
||||
} else {
|
||||
error = expected.duplicate().multiply(next.weights.transpose());
|
||||
error = expected.multiplyTransposed(next.weights);
|
||||
}
|
||||
|
||||
error.scale(lastOutput.duplicate().apply(this.activationFunctionDerivative));
|
||||
error.scaleInPlace(
|
||||
lastOutput.apply(this.activationFunctionDerivative)
|
||||
);
|
||||
// Hier schon leraningRate anwenden?
|
||||
// See https://towardsdatascience.com/understanding-and-implementing-neural-networks-in-java-from-scratch-61421bb6352c
|
||||
//delta = MLMath.matrixApply(delta, ( x ) -> learningRate * x);
|
||||
|
@ -159,10 +157,14 @@ public class NeuronLayer implements Function<MLMatrix, MLMatrix> {
|
|||
previous.backprop(error, learningRate);
|
||||
}
|
||||
|
||||
// biases = MLMath.biasAdjust(biases, MLMath.matrixApply(delta, ( x ) -> learningRate * x));
|
||||
biases.addInPlace(
|
||||
error.colSums().scaleInPlace(
|
||||
-learningRate / (double) error.rows()
|
||||
)
|
||||
);
|
||||
|
||||
adjustment = lastInput.duplicate().transpose().multiply(error).apply((d) -> learningRate*d);
|
||||
this.adjustWeights(adjustment);
|
||||
adjustment = lastInput.transposedMultiplyAndScale(error, learningRate);
|
||||
weights.addInPlace(adjustment);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -1,46 +0,0 @@
|
|||
package schule.ngb.zm.ml;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.Arrays;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
class DoubleMatrixTest {
|
||||
|
||||
@Test
|
||||
void initializeOne() {
|
||||
DoubleMatrix m = new DoubleMatrix(4, 4);
|
||||
m.initializeOne();
|
||||
|
||||
double[] ones = new double[]{1.0, 1.0, 1.0, 1.0};
|
||||
assertArrayEquals(ones, m.coefficients[0]);
|
||||
assertArrayEquals(ones, m.coefficients[1]);
|
||||
assertArrayEquals(ones, m.coefficients[2]);
|
||||
assertArrayEquals(ones, m.coefficients[3]);
|
||||
}
|
||||
|
||||
@Test
|
||||
void initializeZero() {
|
||||
DoubleMatrix m = new DoubleMatrix(4, 4);
|
||||
m.initializeZero();
|
||||
|
||||
double[] zeros = new double[]{0.0, 0.0, 0.0, 0.0};
|
||||
assertArrayEquals(zeros, m.coefficients[0]);
|
||||
assertArrayEquals(zeros, m.coefficients[1]);
|
||||
assertArrayEquals(zeros, m.coefficients[2]);
|
||||
assertArrayEquals(zeros, m.coefficients[3]);
|
||||
}
|
||||
|
||||
@Test
|
||||
void initializeRandom() {
|
||||
DoubleMatrix m = new DoubleMatrix(4, 4);
|
||||
m.initializeRandom(-1, 1);
|
||||
|
||||
assertTrue(Arrays.stream(m.coefficients[0]).allMatch((d) -> -1.0 <= d && d < 1.0));
|
||||
assertTrue(Arrays.stream(m.coefficients[1]).allMatch((d) -> -1.0 <= d && d < 1.0));
|
||||
assertTrue(Arrays.stream(m.coefficients[2]).allMatch((d) -> -1.0 <= d && d < 1.0));
|
||||
assertTrue(Arrays.stream(m.coefficients[3]).allMatch((d) -> -1.0 <= d && d < 1.0));
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,375 @@
|
|||
package schule.ngb.zm.ml;
|
||||
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.TestInfo;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.ValueSource;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
class MLMatrixTest {
|
||||
|
||||
private TestInfo info;
|
||||
|
||||
@BeforeEach
|
||||
void saveTestInfo( TestInfo info ) {
|
||||
this.info = info;
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ValueSource( classes = {DoubleMatrix.class, MatrixFactory.ColtMatrix.class} )
|
||||
void initializeOne( Class<? extends MLMatrix> mType ) {
|
||||
MatrixFactory.matrixType = mType;
|
||||
|
||||
MLMatrix m = MatrixFactory.create(4, 4);
|
||||
m.initializeOne();
|
||||
|
||||
assertEquals(mType, m.getClass());
|
||||
|
||||
for( int i = 0; i < m.rows(); i++ ) {
|
||||
for( int j = 0; j < m.columns(); j++ ) {
|
||||
assertEquals(1.0, m.get(i, j));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ValueSource( classes = {DoubleMatrix.class, MatrixFactory.ColtMatrix.class} )
|
||||
void initializeZero( Class<? extends MLMatrix> mType ) {
|
||||
MatrixFactory.matrixType = mType;
|
||||
|
||||
MLMatrix m = MatrixFactory.create(4, 4);
|
||||
m.initializeZero();
|
||||
|
||||
assertEquals(mType, m.getClass());
|
||||
|
||||
for( int i = 0; i < m.rows(); i++ ) {
|
||||
for( int j = 0; j < m.columns(); j++ ) {
|
||||
assertEquals(0.0, m.get(i, j));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ValueSource( classes = {DoubleMatrix.class, MatrixFactory.ColtMatrix.class} )
|
||||
void initializeRandom( Class<? extends MLMatrix> mType ) {
|
||||
MatrixFactory.matrixType = mType;
|
||||
|
||||
MLMatrix m = MatrixFactory.create(4, 4);
|
||||
m.initializeRandom();
|
||||
|
||||
assertEquals(mType, m.getClass());
|
||||
|
||||
for( int i = 0; i < m.rows(); i++ ) {
|
||||
for( int j = 0; j < m.columns(); j++ ) {
|
||||
double d = m.get(i, j);
|
||||
assertTrue(-1.0 <= d && d < 1.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ValueSource( classes = {DoubleMatrix.class, MatrixFactory.ColtMatrix.class} )
|
||||
void multiplyTransposed( Class<? extends MLMatrix> mType ) {
|
||||
MatrixFactory.matrixType = mType;
|
||||
|
||||
MLMatrix A = MatrixFactory.create(new double[][]{
|
||||
{1.0, 2.0, 3.0, 4.0},
|
||||
{1.0, 2.0, 3.0, 4.0},
|
||||
{1.0, 2.0, 3.0, 4.0}
|
||||
});
|
||||
MLMatrix B = MatrixFactory.create(new double[][]{
|
||||
{1, 3, 5, 7},
|
||||
{2, 4, 6, 8}
|
||||
});
|
||||
|
||||
MLMatrix C = A.multiplyTransposed(B);
|
||||
|
||||
assertEquals(mType, A.getClass());
|
||||
assertEquals(mType, B.getClass());
|
||||
assertEquals(mType, C.getClass());
|
||||
|
||||
assertEquals(3, C.rows());
|
||||
assertEquals(2, C.columns());
|
||||
for( int i = 0; i < C.rows(); i++ ) {
|
||||
assertEquals(50.0, C.get(i, 0));
|
||||
assertEquals(60.0, C.get(i, 1));
|
||||
}
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ValueSource( classes = {DoubleMatrix.class, MatrixFactory.ColtMatrix.class} )
|
||||
void multiplyAddBias( Class<? extends MLMatrix> mType ) {
|
||||
MatrixFactory.matrixType = mType;
|
||||
|
||||
MLMatrix A = MatrixFactory.create(new double[][]{
|
||||
{1.0, 2.0, 3.0, 4.0},
|
||||
{1.0, 2.0, 3.0, 4.0},
|
||||
{1.0, 2.0, 3.0, 4.0}
|
||||
});
|
||||
MLMatrix B = MatrixFactory.create(new double[][]{
|
||||
{1.0, 2.0},
|
||||
{3.0, 4.0},
|
||||
{5.0, 6.0},
|
||||
{7.0, 8.0}
|
||||
});
|
||||
MLMatrix V = MatrixFactory.create(new double[][]{
|
||||
{1000.0, 2000.0}
|
||||
});
|
||||
|
||||
MLMatrix C = A.multiplyAddBias(B, V);
|
||||
|
||||
assertEquals(mType, A.getClass());
|
||||
assertEquals(mType, B.getClass());
|
||||
assertEquals(mType, C.getClass());
|
||||
assertEquals(mType, V.getClass());
|
||||
|
||||
assertEquals(3, C.rows());
|
||||
assertEquals(2, C.columns());
|
||||
for( int i = 0; i < C.rows(); i++ ) {
|
||||
assertEquals(1050.0, C.get(i, 0));
|
||||
assertEquals(2060.0, C.get(i, 1));
|
||||
}
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ValueSource( classes = {DoubleMatrix.class, MatrixFactory.ColtMatrix.class} )
|
||||
void transposedMultiplyAndScale( Class<? extends MLMatrix> mType ) {
|
||||
MatrixFactory.matrixType = mType;
|
||||
|
||||
MLMatrix A = MatrixFactory.create(new double[][]{
|
||||
{1, 1, 1},
|
||||
{2, 2, 2},
|
||||
{3, 3, 3},
|
||||
{4, 4, 4}
|
||||
});
|
||||
MLMatrix B = MatrixFactory.create(new double[][]{
|
||||
{1.0, 2.0},
|
||||
{3.0, 4.0},
|
||||
{5.0, 6.0},
|
||||
{7.0, 8.0}
|
||||
});
|
||||
|
||||
MLMatrix C = A.transposedMultiplyAndScale(B, 2.0);
|
||||
|
||||
assertEquals(mType, A.getClass());
|
||||
assertEquals(mType, B.getClass());
|
||||
assertEquals(mType, C.getClass());
|
||||
|
||||
assertEquals(3, C.rows());
|
||||
assertEquals(2, C.columns());
|
||||
for( int i = 0; i < C.rows(); i++ ) {
|
||||
assertEquals(100.0, C.get(i, 0));
|
||||
assertEquals(120.0, C.get(i, 1));
|
||||
}
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ValueSource( classes = {DoubleMatrix.class, MatrixFactory.ColtMatrix.class} )
|
||||
void apply( Class<? extends MLMatrix> mType ) {
|
||||
MatrixFactory.matrixType = mType;
|
||||
|
||||
MLMatrix M = MatrixFactory.create(new double[][]{
|
||||
{1, 1, 1},
|
||||
{2, 2, 2},
|
||||
{3, 3, 3},
|
||||
{4, 4, 4}
|
||||
});
|
||||
|
||||
MLMatrix R = M.apply(( d ) -> d * d);
|
||||
|
||||
assertEquals(mType, M.getClass());
|
||||
assertEquals(mType, R.getClass());
|
||||
assertNotSame(M, R);
|
||||
|
||||
for( int i = 0; i < M.rows(); i++ ) {
|
||||
for( int j = 0; j < M.columns(); j++ ) {
|
||||
assertEquals(
|
||||
(i + 1) * (i + 1), R.get(i, j),
|
||||
msg("(%d,%d)", "apply", i, j)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
MLMatrix M2 = M.applyInPlace(( d ) -> d * d * d);
|
||||
assertSame(M, M2);
|
||||
for( int i = 0; i < M.rows(); i++ ) {
|
||||
for( int j = 0; j < M.columns(); j++ ) {
|
||||
assertEquals(
|
||||
(i + 1) * (i + 1) * (i + 1), M.get(i, j),
|
||||
msg("(%d,%d)", "applyInPlace", i, j)
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ValueSource( classes = {DoubleMatrix.class, MatrixFactory.ColtMatrix.class} )
|
||||
void add( Class<? extends MLMatrix> mType ) {
|
||||
MatrixFactory.matrixType = mType;
|
||||
|
||||
MLMatrix M = MatrixFactory.create(new double[][]{
|
||||
{1, 1, 1},
|
||||
{2, 2, 2},
|
||||
{3, 3, 3},
|
||||
{4, 4, 4}
|
||||
});
|
||||
|
||||
MLMatrix R = M.add(M);
|
||||
|
||||
assertEquals(mType, M.getClass());
|
||||
assertEquals(mType, R.getClass());
|
||||
assertNotSame(M, R);
|
||||
|
||||
for( int i = 0; i < M.rows(); i++ ) {
|
||||
for( int j = 0; j < M.columns(); j++ ) {
|
||||
assertEquals(
|
||||
(i + 1) + (i + 1), R.get(i, j),
|
||||
msg("(%d,%d)", "add", i, j)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
MLMatrix M2 = M.addInPlace(R);
|
||||
assertSame(M, M2);
|
||||
for( int i = 0; i < M.rows(); i++ ) {
|
||||
for( int j = 0; j < M.columns(); j++ ) {
|
||||
assertEquals(
|
||||
(i + 1) + (i + 1) + (i + 1), M.get(i, j),
|
||||
msg("(%d,%d)", "addInPlace", i, j)
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ValueSource( classes = {DoubleMatrix.class, MatrixFactory.ColtMatrix.class} )
|
||||
void sub( Class<? extends MLMatrix> mType ) {
|
||||
MatrixFactory.matrixType = mType;
|
||||
|
||||
MLMatrix M = MatrixFactory.create(new double[][]{
|
||||
{1, 1, 1},
|
||||
{2, 2, 2},
|
||||
{3, 3, 3},
|
||||
{4, 4, 4}
|
||||
});
|
||||
|
||||
MLMatrix R = M.sub(M);
|
||||
|
||||
assertEquals(mType, M.getClass());
|
||||
assertEquals(mType, R.getClass());
|
||||
assertNotSame(M, R);
|
||||
|
||||
for( int i = 0; i < M.rows(); i++ ) {
|
||||
for( int j = 0; j < M.columns(); j++ ) {
|
||||
assertEquals(
|
||||
0.0, R.get(i, j),
|
||||
msg("(%d,%d)", "sub", i, j)
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ValueSource( classes = {DoubleMatrix.class, MatrixFactory.ColtMatrix.class} )
|
||||
void colSums( Class<? extends MLMatrix> mType ) {
|
||||
MatrixFactory.matrixType = mType;
|
||||
|
||||
MLMatrix M = MatrixFactory.create(new double[][]{
|
||||
{1, 2, 3},
|
||||
{1, 2, 3},
|
||||
{1, 2, 3},
|
||||
{1, 2, 3}
|
||||
});
|
||||
|
||||
MLMatrix R = M.colSums();
|
||||
|
||||
assertEquals(mType, M.getClass());
|
||||
assertEquals(mType, R.getClass());
|
||||
assertNotSame(M, R);
|
||||
|
||||
assertEquals(1, R.rows());
|
||||
assertEquals(3, R.columns());
|
||||
for( int j = 0; j < M.columns(); j++ ) {
|
||||
assertEquals(
|
||||
(j+1)*4, R.get(0, j),
|
||||
msg("(%d,%d)", "colSums", 0, j)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ParameterizedTest
|
||||
@ValueSource( classes = {DoubleMatrix.class, MatrixFactory.ColtMatrix.class} )
|
||||
void duplicate( Class<? extends MLMatrix> mType ) {
|
||||
MatrixFactory.matrixType = mType;
|
||||
|
||||
MLMatrix M = MatrixFactory.create(new double[][]{
|
||||
{1, 2, 3},
|
||||
{1, 2, 3},
|
||||
{1, 2, 3},
|
||||
{1, 2, 3}
|
||||
});
|
||||
|
||||
MLMatrix R = M.duplicate();
|
||||
|
||||
assertEquals(mType, M.getClass());
|
||||
assertEquals(mType, R.getClass());
|
||||
assertNotSame(M, R);
|
||||
|
||||
for( int i = 0; i < M.rows(); i++ ) {
|
||||
for( int j = 0; j < M.columns(); j++ ) {
|
||||
assertEquals(
|
||||
M.get(i, j), R.get(i, j),
|
||||
msg("(%d,%d)", "duplicate", i, j)
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ValueSource( classes = {DoubleMatrix.class, MatrixFactory.ColtMatrix.class} )
|
||||
void scale( Class<? extends MLMatrix> mType ) {
|
||||
MatrixFactory.matrixType = mType;
|
||||
|
||||
MLMatrix M = MatrixFactory.create(new double[][]{
|
||||
{1, 1, 1},
|
||||
{2, 2, 2},
|
||||
{3, 3, 3},
|
||||
{4, 4, 4}
|
||||
});
|
||||
|
||||
MLMatrix M2 = M.scaleInPlace(2.0);
|
||||
|
||||
assertEquals(mType, M.getClass());
|
||||
assertEquals(mType, M2.getClass());
|
||||
assertSame(M, M2);
|
||||
|
||||
for( int i = 0; i < M.rows(); i++ ) {
|
||||
for( int j = 0; j < M.columns(); j++ ) {
|
||||
assertEquals(
|
||||
(i+1)*2.0, M2.get(i, j),
|
||||
msg("(%d,%d)", "scaleInPlace", i, j)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
MLMatrix M3 = M.scaleInPlace(M);
|
||||
assertSame(M, M3);
|
||||
for( int i = 0; i < M.rows(); i++ ) {
|
||||
for( int j = 0; j < M.columns(); j++ ) {
|
||||
assertEquals(
|
||||
((i+1)*2.0)*((i+1)*2.0), M.get(i, j),
|
||||
msg("(%d,%d)", "addInPlace", i, j)
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private String msg( String msg, String methodName, Object... args ) {
|
||||
String testName = this.info.getTestMethod().get().getName();
|
||||
String className = MatrixFactory.matrixType.getSimpleName();
|
||||
return String.format("[" + testName + "(" + className + ") " + methodName + "()] " + msg, args);
|
||||
}
|
||||
|
||||
}
|
|
@ -2,6 +2,7 @@ package schule.ngb.zm.ml;
|
|||
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import schule.ngb.zm.Constants;
|
||||
import schule.ngb.zm.util.Log;
|
||||
import schule.ngb.zm.util.Timer;
|
||||
|
||||
|
@ -16,6 +17,13 @@ class NeuralNetworkTest {
|
|||
Log.enableGlobalDebugging();
|
||||
}
|
||||
|
||||
@BeforeAll
|
||||
static void setupMatrixLibrary() {
|
||||
Constants.setSeed(1001);
|
||||
MatrixFactory.matrixType = MatrixFactory.ColtMatrix.class;
|
||||
//MatrixFactory.matrixType = DoubleMatrix.class;
|
||||
}
|
||||
|
||||
/*@Test
|
||||
void readWrite() {
|
||||
// XOR Dataset
|
||||
|
@ -107,7 +115,7 @@ class NeuralNetworkTest {
|
|||
for( int i = 0; i < trainingData.size(); i++ ) {
|
||||
inputs[i][0] = trainingData.get(i).a;
|
||||
inputs[i][1] = trainingData.get(i).b;
|
||||
outputs[i][0] = trainingData.get(i).result;
|
||||
outputs[i][0] = trainingData.get(i).getResult();
|
||||
}
|
||||
|
||||
Timer timer = new Timer();
|
||||
|
@ -139,8 +147,8 @@ class NeuralNetworkTest {
|
|||
"Prediction on data (%.2f, %.2f) was %.4f, expected %.2f (of by %.4f)\n",
|
||||
data.a, data.b,
|
||||
net.getOutput().get(0, 0),
|
||||
data.result,
|
||||
net.getOutput().get(0, 0) - data.result
|
||||
data.getResult(),
|
||||
net.getOutput().get(0, 0) - data.getResult()
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -183,7 +191,6 @@ class NeuralNetworkTest {
|
|||
|
||||
double a;
|
||||
double b;
|
||||
double result;
|
||||
CalcType type;
|
||||
|
||||
TestData( double a, double b ) {
|
||||
|
@ -191,6 +198,8 @@ class NeuralNetworkTest {
|
|||
this.b = b;
|
||||
}
|
||||
|
||||
abstract double getResult();
|
||||
|
||||
}
|
||||
|
||||
private static final class AddData extends TestData {
|
||||
|
@ -199,7 +208,9 @@ class NeuralNetworkTest {
|
|||
|
||||
public AddData( double a, double b ) {
|
||||
super(a, b);
|
||||
result = a + b;
|
||||
}
|
||||
double getResult() {
|
||||
return a+b;
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -210,7 +221,9 @@ class NeuralNetworkTest {
|
|||
|
||||
public SubData( double a, double b ) {
|
||||
super(a, b);
|
||||
result = a - b;
|
||||
}
|
||||
double getResult() {
|
||||
return a-b;
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -221,7 +234,9 @@ class NeuralNetworkTest {
|
|||
|
||||
public MulData( double a, double b ) {
|
||||
super(a, b);
|
||||
result = a * b;
|
||||
}
|
||||
double getResult() {
|
||||
return a*b;
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -235,7 +250,9 @@ class NeuralNetworkTest {
|
|||
if( b == 0.0 ) {
|
||||
b = .1;
|
||||
}
|
||||
result = a / b;
|
||||
}
|
||||
double getResult() {
|
||||
return a/b;
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -246,7 +263,12 @@ class NeuralNetworkTest {
|
|||
|
||||
public ModData( double b, double a ) {
|
||||
super(b, a);
|
||||
result = a % b;
|
||||
if( b == 0.0 ) {
|
||||
b = .1;
|
||||
}
|
||||
}
|
||||
double getResult() {
|
||||
return a%b;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue