diff --git a/build.gradle b/build.gradle
index 9a02a81..746e57e 100644
--- a/build.gradle
+++ b/build.gradle
@@ -28,8 +28,13 @@ dependencies {
runtimeOnly 'com.googlecode.soundlibs:tritonus-share:0.3.7.4'
runtimeOnly 'com.googlecode.soundlibs:mp3spi:1.9.5.4'
+ //compileOnlyApi 'colt:colt:1.2.0'
+ api 'colt:colt:1.2.0'
+ //api 'net.sourceforge.parallelcolt:parallelcolt:0.10.1'
+
testImplementation 'org.junit.jupiter:junit-jupiter-api:5.8.1'
- testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine:5.8.1'
+ testImplementation 'org.junit.jupiter:junit-jupiter-params:5.8.1'
+ testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine:5.8.1'
}
test {
diff --git a/src/main/java/schule/ngb/zm/Constants.java b/src/main/java/schule/ngb/zm/Constants.java
index f29544f..a6baee1 100644
--- a/src/main/java/schule/ngb/zm/Constants.java
+++ b/src/main/java/schule/ngb/zm/Constants.java
@@ -1269,7 +1269,8 @@ public class Constants {
}
/**
- * Erzeugt eine Pseudozufallszahl nach einer Gaussverteilung.
+ * Erzeugt eine Pseudozufallszahl zwischen -1 und 1 nach einer
+ * Normalverteilung mit Mittelwert 0 und Standardabweichung 1.
*
* @return Eine Zufallszahl.
* @see Random#nextGaussian()
diff --git a/src/main/java/schule/ngb/zm/ml/DoubleMatrix.java b/src/main/java/schule/ngb/zm/ml/DoubleMatrix.java
new file mode 100644
index 0000000..2144217
--- /dev/null
+++ b/src/main/java/schule/ngb/zm/ml/DoubleMatrix.java
@@ -0,0 +1,399 @@
+package schule.ngb.zm.ml;
+
+import schule.ngb.zm.Constants;
+
+import java.util.function.DoubleUnaryOperator;
+
+/**
+ * Eine einfache Implementierung der {@link MLMatrix} zur Verwendung in
+ * {@link NeuralNetwork}s.
+ *
+ * Diese Klasse stellt die interne Implementierung der Matrixoperationen dar,
+ * die zur Berechnung der Gewichte in einem {@link NeuronLayer} notwendig sind.
+ *
+ * Die Klasse ist nur minimal optimiert und sollte nur für kleine Netze
+ * verwendet werden. Für größere Netze sollte auf eine der optionalen
+ * Bibliotheken wie
+ * Colt zurückgegriffen werden.
+ */
+public final class DoubleMatrix implements MLMatrix {
+
+ /**
+ * Anzahl Zeilen der Matrix.
+ */
+ private int rows;
+
+ /**
+ * Anzahl Spalten der Matrix.
+ */
+ private int columns;
+
+ /**
+ * Die Koeffizienten der Matrix.
+ *
+ * Um den Overhead bei Speicher und Zugriffszeiten von zweidimensionalen
+ * Arrays zu vermeiden wird ein eindimensionales Array verwendet und die
+ * Indizes mit Spaltenpriorität berechnet. Der Index i des Koeffizienten
+ * {@code r,c} in Zeile {@code r} und Spalte {@code c} wird bestimmt durch
+ *
+ * i = c * rows + r
+ *
+ *
+ * Die Werte einer Spalte liegen also hintereinander im Array. Dies sollte
+ * einen leichten Vorteil bei der {@link #colSums() Spaltensummen} geben.
+ * Generell sollte eine Iteration über die Matrix der Form
+ *
+ * for( int j = 0; j < columns; j++ ) {
+ * for( int i = 0; i < rows; i++ ) {
+ * // ...
+ * }
+ * }
+ *
+ * etwas schneller sein als
+ *
+ * for( int i = 0; i < rows; i++ ) {
+ * for( int j = 0; j < columns; j++ ) {
+ * // ...
+ * }
+ * }
+ *
+ */
+ double[] coefficients;
+
+ public DoubleMatrix( int rows, int cols ) {
+ this.rows = rows;
+ this.columns = cols;
+ coefficients = new double[rows * cols];
+ }
+
+ public DoubleMatrix( double[][] coefficients ) {
+ this.rows = coefficients.length;
+ this.columns = coefficients[0].length;
+ this.coefficients = new double[rows * columns];
+ for( int j = 0; j < columns; j++ ) {
+ for( int i = 0; i < rows; i++ ) {
+ this.coefficients[idx(i, j)] = coefficients[i][j];
+ }
+ }
+ }
+
+ /**
+ * Initialisiert diese Matrix als Kopie der angegebenen Matrix.
+ *
+ * @param other Die zu kopierende Matrix.
+ */
+ public DoubleMatrix( DoubleMatrix other ) {
+ this.rows = other.rows();
+ this.columns = other.columns();
+ this.coefficients = new double[rows * columns];
+ System.arraycopy(
+ other.coefficients, 0,
+ this.coefficients, 0,
+ rows * columns);
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public int columns() {
+ return columns;
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public int rows() {
+ return rows;
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ int idx( int r, int c ) {
+ return c * rows + r;
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public double get( int row, int col ) {
+ try {
+ return coefficients[idx(row, col)];
+ } catch( ArrayIndexOutOfBoundsException ex ) {
+ throw new IllegalArgumentException("No element at row=" + row + ", column=" + col, ex);
+ }
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public MLMatrix set( int row, int col, double value ) {
+ try {
+ coefficients[idx(row, col)] = value;
+ } catch( ArrayIndexOutOfBoundsException ex ) {
+ throw new IllegalArgumentException("No element at row=" + row + ", column=" + col, ex);
+ }
+ return this;
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public MLMatrix initializeRandom() {
+ return initializeRandom(-1.0, 1.0);
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public MLMatrix initializeRandom( double lower, double upper ) {
+ applyInPlace(( d ) -> ((upper - lower) * Constants.random()) + lower);
+ return this;
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public MLMatrix initializeOne() {
+ applyInPlace(( d ) -> 1.0);
+ return this;
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public MLMatrix initializeZero() {
+ applyInPlace(( d ) -> 0.0);
+ return this;
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public MLMatrix duplicate() {
+ return new DoubleMatrix(this);
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public MLMatrix multiplyTransposed( MLMatrix B ) {
+ /*return new DoubleMatrix(IntStream.range(0, rows).parallel().mapToObj(
+ ( i ) -> IntStream.range(0, B.rows()).mapToDouble(
+ ( j ) -> IntStream.range(0, columns).mapToDouble(
+ ( k ) -> get(i, k) * B.get(j, k)
+ ).sum()
+ ).toArray()
+ ).toArray(double[][]::new));*/
+ DoubleMatrix result = new DoubleMatrix(rows, B.rows());
+ for( int i = 0; i < rows; i++ ) {
+ for( int j = 0; j < B.rows(); j++ ) {
+ result.coefficients[result.idx(i, j)] = 0.0;
+ for( int k = 0; k < columns; k++ ) {
+ result.coefficients[result.idx(i, j)] += get(i, k) * B.get(j, k);
+ }
+ }
+ }
+ return result;
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public MLMatrix multiplyAddBias( final MLMatrix B, final MLMatrix C ) {
+ /*return new DoubleMatrix(IntStream.range(0, rows).parallel().mapToObj(
+ ( i ) -> IntStream.range(0, B.columns()).mapToDouble(
+ ( j ) -> IntStream.range(0, columns).mapToDouble(
+ ( k ) -> get(i, k) * B.get(k, j)
+ ).sum() + C.get(0, j)
+ ).toArray()
+ ).toArray(double[][]::new));*/
+ DoubleMatrix result = new DoubleMatrix(rows, B.columns());
+ for( int i = 0; i < rows; i++ ) {
+ for( int j = 0; j < B.columns(); j++ ) {
+ result.coefficients[result.idx(i, j)] = 0.0;
+ for( int k = 0; k < columns; k++ ) {
+ result.coefficients[result.idx(i, j)] += get(i, k) * B.get(k, j);
+ }
+ result.coefficients[result.idx(i, j)] += C.get(0, j);
+ }
+ }
+ return result;
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public MLMatrix transposedMultiplyAndScale( final MLMatrix B, final double scalar ) {
+ /*return new DoubleMatrix(IntStream.range(0, columns).parallel().mapToObj(
+ ( i ) -> IntStream.range(0, B.columns()).mapToDouble(
+ ( j ) -> IntStream.range(0, rows).mapToDouble(
+ ( k ) -> get(k, i) * B.get(k, j) * scalar
+ ).sum()
+ ).toArray()
+ ).toArray(double[][]::new));*/
+ DoubleMatrix result = new DoubleMatrix(columns, B.columns());
+ for( int i = 0; i < columns; i++ ) {
+ for( int j = 0; j < B.columns(); j++ ) {
+ result.coefficients[result.idx(i, j)] = 0.0;
+ for( int k = 0; k < rows; k++ ) {
+ result.coefficients[result.idx(i, j)] += get(k, i) * B.get(k, j);
+ }
+ result.coefficients[result.idx(i, j)] *= scalar;
+ }
+ }
+ return result;
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public MLMatrix add( MLMatrix B ) {
+ /*return new DoubleMatrix(IntStream.range(0, rows).parallel().mapToObj(
+ ( i ) -> IntStream.range(0, columns).mapToDouble(
+ ( j ) -> get(i, j) + B.get(i, j)
+ ).toArray()
+ ).toArray(double[][]::new));*/
+ DoubleMatrix sum = new DoubleMatrix(rows, columns);
+ for( int j = 0; j < columns; j++ ) {
+ for( int i = 0; i < rows; i++ ) {
+ sum.coefficients[idx(i, j)] = coefficients[idx(i, j)] + B.get(i, j);
+ }
+ }
+ return sum;
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public MLMatrix addInPlace( MLMatrix B ) {
+ for( int j = 0; j < columns; j++ ) {
+ for( int i = 0; i < rows; i++ ) {
+ coefficients[idx(i, j)] += B.get(i, j);
+ }
+ }
+ return this;
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public MLMatrix sub( MLMatrix B ) {
+ /*return new DoubleMatrix(IntStream.range(0, rows).parallel().mapToObj(
+ ( i ) -> IntStream.range(0, columns).mapToDouble(
+ ( j ) -> get(i, j) - B.get(i, j)
+ ).toArray()
+ ).toArray(double[][]::new));*/
+ DoubleMatrix diff = new DoubleMatrix(rows, columns);
+ for( int j = 0; j < columns; j++ ) {
+ for( int i = 0; i < rows; i++ ) {
+ diff.coefficients[idx(i, j)] = coefficients[idx(i, j)] - B.get(i, j);
+ }
+ }
+ return diff;
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public MLMatrix colSums() {
+ /*DoubleMatrix colSums = new DoubleMatrix(1, columns);
+ colSums.coefficients = IntStream.range(0, columns).parallel().mapToDouble(
+ ( j ) -> IntStream.range(0, rows).mapToDouble(
+ ( i ) -> get(i, j)
+ ).sum()
+ ).toArray();
+ return colSums;*/
+ DoubleMatrix colSums = new DoubleMatrix(1, columns);
+ for( int j = 0; j < columns; j++ ) {
+ colSums.coefficients[j] = 0.0;
+ for( int i = 0; i < rows; i++ ) {
+ colSums.coefficients[j] += coefficients[idx(i, j)];
+ }
+ }
+ return colSums;
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public MLMatrix scaleInPlace( final double scalar ) {
+ for( int i = 0; i < coefficients.length; i++ ) {
+ coefficients[i] *= scalar;
+ }
+ return this;
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public MLMatrix scaleInPlace( final MLMatrix S ) {
+ for( int j = 0; j < columns; j++ ) {
+ for( int i = 0; i < rows; i++ ) {
+ coefficients[idx(i, j)] *= S.get(i, j);
+ }
+ }
+ return this;
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public MLMatrix apply( DoubleUnaryOperator op ) {
+ DoubleMatrix result = new DoubleMatrix(rows, columns);
+ for( int i = 0; i < coefficients.length; i++ ) {
+ result.coefficients[i] = op.applyAsDouble(coefficients[i]);
+ }
+ return result;
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public MLMatrix applyInPlace( DoubleUnaryOperator op ) {
+ for( int i = 0; i < coefficients.length; i++ ) {
+ coefficients[i] = op.applyAsDouble(coefficients[i]);
+ }
+ return this;
+ }
+
+ @Override
+ public String toString() {
+ StringBuilder sb = new StringBuilder();
+ sb.append(rows);
+ sb.append(" x ");
+ sb.append(columns);
+ sb.append(" Matrix");
+ sb.append('\n');
+ for( int i = 0; i < rows; i++ ) {
+ for( int j = 0; j < columns; j++ ) {
+ sb.append(get(i, j));
+ if( j < columns - 1 )
+ sb.append(' ');
+ }
+ sb.append('\n');
+ }
+ return sb.toString();
+ }
+
+}
diff --git a/src/main/java/schule/ngb/zm/ml/MLMatrix.java b/src/main/java/schule/ngb/zm/ml/MLMatrix.java
new file mode 100644
index 0000000..8de5e76
--- /dev/null
+++ b/src/main/java/schule/ngb/zm/ml/MLMatrix.java
@@ -0,0 +1,316 @@
+package schule.ngb.zm.ml;
+
+import java.util.function.DoubleUnaryOperator;
+
+/**
+ * Interface für Matrizen, die in {@link NeuralNetwork} Klassen verwendet
+ * werden.
+ *
+ * Eine implementierende Klasse muss generell zwei Konstruktoren bereitstellen:
+ *
+ * - {@code MLMatrix(int rows, int columns)} erstellt eine Matrix mit den
+ * angegebenen Dimensionen und setzt alle Koeffizienten auf 0.
+ *
- {@code MLMatrix(double[][] coefficients} erstellt eine Matrix mit der
+ * durch das Array gegebenen Dimensionen und setzt die Werte auf die
+ * jeweiligen Werte des Arrays.
+ *
+ *
+ * Das Interface ist nicht dazu gedacht eine allgemeine Umsetzung für
+ * Matrizen-Algebra abzubilden, sondern soll gezielt die im Neuralen Netzwerk
+ * verwendeten Algorithmen umsetzen. Einerseits würde eine ganz allgemeine
+ * Matrizen-Klasse nicht im Rahmen der Zeichenmaschine liegen und auf der
+ * anderen Seite bietet eine Konzentration auf die verwendeten Algorithmen mehr
+ * Spielraum zur Optimierung.
+ *
+ * Intern wird das Interface von {@link DoubleMatrix} implementiert. Die Klasse
+ * ist eine weitestgehend naive Implementierung der Algorithmen mit kleineren
+ * Optimierungen. Die Verwendung eines generalisierten Interfaces erlaubt aber
+ * zukünftig die optionale Integration spezialisierterer Algebra-Bibliotheken
+ * wie
+ * Colt, um auch große
+ * Netze effizient berechnen zu können.
+ */
+public interface MLMatrix {
+
+ /**
+ * Die Anzahl der Spalten der Matrix.
+ *
+ * @return Spaltenzahl.
+ */
+ int columns();
+
+ /**
+ * Die Anzahl der Zeilen der Matrix.
+ *
+ * @return Zeilenzahl.
+ */
+ int rows();
+
+ /**
+ * Gibt den Wert an der angegebenen Stelle der Matrix zurück.
+ *
+ * @param row Die Spaltennummer zwischen 0 und {@code rows()-1}.
+ * @param col Die Zeilennummer zwischen 0 und {@code columns()-1}
+ * @return Den Koeffizienten in der Zeile {@code row} und der Spalte
+ * {@code col}.
+ * @throws IllegalArgumentException Falls {@code row >= rows()} oder
+ * {@code col >= columns()}.
+ */
+ double get( int row, int col ) throws IllegalArgumentException;
+
+ /**
+ * Setzt den Wert an der angegebenen Stelle der Matrix.
+ *
+ * @param row Die Spaltennummer zwischen 0 und {@code rows()-1}.
+ * @param col Die Zeilennummer zwischen 0 und {@code columns()-1}
+ * @param value Der neue Wert.
+ * @return Diese Matrix selbst (method chaining).
+ * @throws IllegalArgumentException Falls {@code row >= rows()} oder
+ * {@code col >= columns()}.
+ */
+ MLMatrix set( int row, int col, double value ) throws IllegalArgumentException;
+
+ /**
+ * Setzt jeden Wert in der Matrix auf eine Zufallszahl zwischen -1 und 1.
+ *
+ * Nach Möglichkeit sollte der
+ * {@link schule.ngb.zm.Constants#random(int, int) Zufallsgenerator der
+ * Zeichenmaschine} verwendet werden.
+ *
+ * @return Diese Matrix selbst (method chaining).
+ */
+ MLMatrix initializeRandom();
+
+ /**
+ * Setzt jeden Wert in der Matrix auf eine Zufallszahl innerhalb der
+ * angegebenen Grenzen.
+ *
+ * Nach Möglichkeit sollte der
+ * {@link schule.ngb.zm.Constants#random(int, int) Zufallsgenerator der
+ * Zeichenmaschine} verwendet werden.
+ *
+ * @param lower Untere Grenze der Zufallszahlen.
+ * @param upper Obere Grenze der Zufallszahlen.
+ * @return Diese Matrix selbst (method chaining).
+ */
+ MLMatrix initializeRandom( double lower, double upper );
+
+ /**
+ * Setzt alle Werte der Matrix auf 1.
+ *
+ * @return Diese Matrix selbst (method chaining).
+ */
+ MLMatrix initializeOne();
+
+ /**
+ * Setzt alle Werte der Matrix auf 0.
+ *
+ * @return Diese Matrix selbst (method chaining).
+ */
+ MLMatrix initializeZero();
+
+ /**
+ * Erzeugt eine neue Matrix {@code C} mit dem Ergebnis der Matrixoperation
+ *
+ * C = this . B + V'
+ *
+ * wobei {@code this} dieses Matrixobjekt ist und {@code .} für die
+ * Matrixmultiplikation steht. {@vode V'} ist die Matrix {@code V}
+ * {@code rows()}-mal untereinander wiederholt.
+ *
+ * Wenn diese Matrix die Dimension r x c hat, dann muss die Matrix {@code B}
+ * die Dimension c x m haben und {@code V} eine 1 x m Matrix sein. Die
+ * Matrix {@code V'} hat also die Dimension r x m, ebenso wie das Ergebnis
+ * der Operation.
+ *
+ * @param B Eine {@code columns()} x m Matrix mit der Multipliziert wird.
+ * @param V Eine 1 x {@code B.columns()} Matrix mit den Bias-Werten.
+ * @return Eine {@code rows()} x m Matrix.
+ * @throws IllegalArgumentException Falls die Dimensionen der Matrizen nicht
+ * zur Operation passen. Also
+ * {@code this.columns() != B.rows()} oder
+ * {@code B.columns() != V.columns()} oder
+ * {@code V.rows() != 1}.
+ */
+ MLMatrix multiplyAddBias( MLMatrix B, MLMatrix V ) throws IllegalArgumentException;
+
+ /**
+ * Erzeugt eine neue Matrix {@code C} mit dem Ergebnis der Matrixoperation
+ *
+ * C = this . t(B)
+ *
+ * wobei {@code this} dieses Matrixobjekt ist, {@code t(B)} die
+ * Transposition der Matrix {@code B} ist und {@code .} für die
+ * Matrixmultiplikation steht.
+ *
+ * Wenn diese Matrix die Dimension r x c hat, dann muss die Matrix {@code B}
+ * die Dimension m x c haben und das Ergebnis ist eine r x m Matrix.
+ *
+ * @param B Eine m x {@code columns()} Matrix.
+ * @return Eine {@code rows()} x m Matrix.
+ * @throws IllegalArgumentException Falls die Dimensionen der Matrizen nicht
+ * zur Operation passen. Also
+ * {@code this.columns() != B.columns()}.
+ */
+ MLMatrix multiplyTransposed( MLMatrix B ) throws IllegalArgumentException;
+
+ /**
+ * Erzeugt eine neue Matrix {@code C} mit dem Ergebnis der Matrixoperation
+ *
+ * C = t(this) . B * scalar
+ *
+ * wobei {@code this} dieses Matrixobjekt ist, {@code t(this)} die
+ * Transposition dieser Matrix ist und {@code .} für die
+ * Matrixmultiplikation steht. {@code *} bezeichnet die
+ * Skalarmultiplikation, bei der jeder Wert der Matrix mit {@code scalar}
+ * multipliziert wird.
+ *
+ * Wenn diese Matrix die Dimension r x c hat, dann muss die Matrix {@code B}
+ * die Dimension r x m haben und das Ergebnis ist eine c x m Matrix.
+ *
+ * @param B Eine m x {@code columns()} Matrix.
+ * @return Eine {@code rows()} x m Matrix.
+ * @throws IllegalArgumentException Falls die Dimensionen der Matrizen nicht
+ * zur Operation passen. Also
+ * {@code this.rows() != B.rows()}.
+ */
+ MLMatrix transposedMultiplyAndScale( MLMatrix B, double scalar ) throws IllegalArgumentException;
+
+ /**
+ * Erzeugt eine neue Matrix {@code C} mit dem Ergebnis der komponentenweisen
+ * Matrix-Addition
+ *
+ * C = this + B
+ *
+ * wobei {@code this} dieses Matrixobjekt ist. Für ein Element {@code C_ij}
+ * in {@code C} gilt
+ *
+ * C_ij = A_ij + B_ij
+ *
+ *
+ * Die Matrix {@code B} muss dieselbe Dimension wie diese Matrix haben.
+ *
+ * @param B Eine {@code rows()} x {@code columns()} Matrix.
+ * @return Eine {@code rows()} x {@code columns()} Matrix.
+ * @throws IllegalArgumentException Falls die Dimensionen der Matrizen nicht
+ * zur Operation passen. Also
+ * {@code this.rows() != B.rows()} oder
+ * {@code this.columns() != B.columns()}.
+ */
+ MLMatrix add( MLMatrix B ) throws IllegalArgumentException;
+
+ /**
+ * Setzt diese Matrix auf das Ergebnis der komponentenweisen
+ * Matrix-Addition
+ *
+ * A' = A + B
+ *
+ * wobei {@code A} dieses Matrixobjekt ist und {@code A'} diese Matrix nach
+ * der Operation. Für ein Element {@code A'_ij} in {@code A'} gilt
+ *
+ * A'_ij = A_ij + B_ij
+ *
+ *
+ * Die Matrix {@code B} muss dieselbe Dimension wie diese Matrix haben.
+ *
+ * @param B Eine {@code rows()} x {@code columns()} Matrix.
+ * @return Eine {@code rows()} x {@code columns()} Matrix.
+ * @throws IllegalArgumentException Falls die Dimensionen der Matrizen nicht
+ * zur Operation passen. Also
+ * {@code this.rows() != B.rows()} oder
+ * {@code this.columns() != B.columns()}.
+ */
+ MLMatrix addInPlace( MLMatrix B ) throws IllegalArgumentException;
+
+ /**
+ * Erzeugt eine neue Matrix {@code C} mit dem Ergebnis der komponentenweisen
+ * Matrix-Subtraktion
+ *
+ * C = A - B
+ *
+ * wobei {@code A} dieses Matrixobjekt ist. Für ein Element {@code C_ij} in
+ * {@code C} gilt
+ *
+ * C_ij = A_ij - B_ij
+ *
+ *
+ * Die Matrix {@code B} muss dieselbe Dimension wie diese Matrix haben.
+ *
+ * @param B Eine {@code rows()} x {@code columns()} Matrix.
+ * @return Eine {@code rows()} x {@code columns()} Matrix.
+ * @throws IllegalArgumentException Falls die Dimensionen der Matrizen nicht
+ * zur Operation passen. Also
+ * {@code this.rows() != B.rows()} oder
+ * {@code this.columns() != B.columns()}.
+ */
+ MLMatrix sub( MLMatrix B ) throws IllegalArgumentException;
+
+ /**
+ * Multipliziert jeden Wert dieser Matrix mit dem angegebenen Skalar.
+ *
+ * Ist {@code A} dieses Matrixobjekt und {@code A'} diese Matrix nach der
+ * Operation, dann gilt für ein Element {@code A'_ij} in {@code A'}
+ *
+ * A'_ij = A_ij * scalar
+ *
+ *
+ * @param scalar Ein Skalar.
+ * @return Diese Matrix selbst (method chaining)
+ */
+ MLMatrix scaleInPlace( double scalar );
+
+ /**
+ * Multipliziert jeden Wert dieser Matrix mit dem entsprechenden Wert in der
+ * Matrix {@code S}.
+ *
+ * Ist {@code A} dieses Matrixobjekt und {@code A'} diese Matrix nach der
+ * Operation, dann gilt für ein Element {@code A'_ij} in {@code A'}
+ *
+ * A'_ij = A_ij * S_ij
+ *
+ *
+ * @param S Eine {@code rows()} x {@code columns()} Matrix.
+ * @return Diese Matrix selbst (method chaining)
+ * @throws IllegalArgumentException Falls die Dimensionen der Matrizen nicht
+ * zur Operation passen. Also
+ * {@code this.rows() != B.rows()} oder
+ * {@code this.columns() != B.columns()}.
+ */
+ MLMatrix scaleInPlace( MLMatrix S ) throws IllegalArgumentException;
+
+ /**
+ * Berechnet eine neue Matrix mit nur einer Zeile, die die Spaltensummen
+ * dieser Matrix enthalten.
+ *
+ * @return Eine 1 x {@code columns()} Matrix.
+ */
+ MLMatrix colSums();
+
+ /**
+ * Erzeugt eine neue Matrix, deren Werte gleich den Werten dieser Matrix
+ * nach der Anwendung der angegebenen Funktion sind.
+ *
+ * @param op Eine Operation {@code (double) -> double}.
+ * @return Eine {@code rows()} x {@code columns()} Matrix.
+ */
+ MLMatrix apply( DoubleUnaryOperator op );
+
+ /**
+ * Endet die gegebene Funktion auf jeden Wert der Matrix an.
+ *
+ * @param op Eine Operation {@code (double) -> double}.
+ * @return Diese Matrix selbst (method chaining)
+ */
+ MLMatrix applyInPlace( DoubleUnaryOperator op );
+
+ /**
+ * Erzeugt eine neue Matrix mit denselben Dimensionen und Koeffizienten wie
+ * diese Matrix.
+ *
+ * @return Eine Kopie dieser Matrix.
+ */
+ MLMatrix duplicate();
+
+ String toString();
+
+}
diff --git a/src/main/java/schule/ngb/zm/ml/Matrix.java b/src/main/java/schule/ngb/zm/ml/Matrix.java
deleted file mode 100644
index 734fe7b..0000000
--- a/src/main/java/schule/ngb/zm/ml/Matrix.java
+++ /dev/null
@@ -1,82 +0,0 @@
-package schule.ngb.zm.ml;
-
-import schule.ngb.zm.Constants;
-
-import java.util.Arrays;
-
-// TODO: Move Math into Matrix class
-// TODO: Implement support for optional sci libs
-public class Matrix {
-
- private int columns, rows;
-
- double[][] coefficients;
-
- public Matrix( int rows, int cols ) {
- this.rows = rows;
- this.columns = cols;
- coefficients = new double[rows][cols];
- }
-
- public Matrix( double[][] coefficients ) {
- this.coefficients = coefficients;
- this.rows = coefficients.length;
- this.columns = coefficients[0].length;
- }
-
- public int getColumns() {
- return columns;
- }
-
- public int getRows() {
- return rows;
- }
-
- public double[][] getCoefficients() {
- return coefficients;
- }
-
- public double get( int row, int col ) {
- return coefficients[row][col];
- }
-
- public void initializeRandom() {
- coefficients = MLMath.matrixApply(coefficients, (d) -> Constants.randomGaussian());
- }
-
- public void initializeRandom( double lower, double upper ) {
- coefficients = MLMath.matrixApply(coefficients, (d) -> ((upper-lower) * (Constants.randomGaussian()+1) * .5) + lower);
- }
-
- public void initializeIdentity() {
- initializeZero();
- for( int i = 0; i < Math.min(rows, columns); i++ ) {
- this.coefficients[i][i] = 1.0;
- }
- }
-
- public void initializeOne() {
- coefficients = MLMath.matrixApply(coefficients, (d) -> 1.0);
- }
-
- public void initializeZero() {
- coefficients = MLMath.matrixApply(coefficients, (d) -> 0.0);
- }
-
- @Override
- public String toString() {
- //return Arrays.deepToString(coefficients);
- StringBuilder sb = new StringBuilder();
- sb.append('[');
- sb.append('\n');
- for( int i = 0; i < coefficients.length; i++ ) {
- sb.append('\t');
- sb.append(Arrays.toString(coefficients[i]));
- sb.append('\n');
- }
- sb.append(']');
-
- return sb.toString();
- }
-
-}
diff --git a/src/main/java/schule/ngb/zm/ml/MatrixFactory.java b/src/main/java/schule/ngb/zm/ml/MatrixFactory.java
new file mode 100644
index 0000000..889380a
--- /dev/null
+++ b/src/main/java/schule/ngb/zm/ml/MatrixFactory.java
@@ -0,0 +1,246 @@
+package schule.ngb.zm.ml;
+
+import cern.colt.matrix.DoubleFactory2D;
+import schule.ngb.zm.Constants;
+import schule.ngb.zm.util.Log;
+
+import java.util.function.DoubleUnaryOperator;
+
+/**
+ * Zentrale Klasse zur Erstellung neuer Matrizen. Generell sollten neue Matrizen
+ * nicht direkt erstellt werden, sondern durch den Aufruf von
+ * {@link #create(int, int)} oder {@link #create(double[][])}. Die Fabrik
+ * ermittelt automatisch die beste verfügbare Implementierung und initialisiert
+ * eine entsprechende Implementierung von {@link MLMatrix}.
+ *
+ * Derzeit werden die optionale Bibliothek Colt und die interne
+ * Implementierung {@link DoubleMatrix} unterstützt.
+ */
+public class MatrixFactory {
+
+ /**
+ * Erstellt eine neue Matrix mit den angegebenen Dimensionen und
+ * initialisiert alle Werte mit 0.
+ *
+ * @param rows Anzahl der Zeilen.
+ * @param cols Anzahl der Spalten.
+ * @return Eine {@code rows} x {@code cols} Matrix.
+ */
+ public static final MLMatrix create( int rows, int cols ) {
+ try {
+ return getMatrixType().getDeclaredConstructor(int.class, int.class).newInstance(rows, cols);
+ } catch( Exception ex ) {
+ LOG.error(ex, "Could not initialize matrix implementation for class <%s>. Using internal implementation.", matrixType);
+ }
+ return new DoubleMatrix(rows, cols);
+ }
+
+ /**
+ * Erstellt eine neue Matrix mit den Dimensionen des angegebenen Arrays und
+ * initialisiert die Werte mit den entsprechenden Werten des Arrays.
+ *
+ * @param values Die Werte der Matrix.
+ * @return Eine {@code values.length} x {@code values[0].length} Matrix mit
+ * den Werten des Arrays.
+ */
+ public static final MLMatrix create( double[][] values ) {
+ try {
+ return getMatrixType().getDeclaredConstructor(double[][].class).newInstance((Object) values);
+ } catch( Exception ex ) {
+ LOG.error(ex, "Could not initialize matrix implementation for class <%s>. Using internal implementation.", matrixType);
+ }
+ return new DoubleMatrix(values);
+ }
+
+ /**
+ * Die verwendete {@link MLMatrix} Implementierung, aus der Matrizen erzeugt
+ * werden.
+ */
+ static Class extends MLMatrix> matrixType = null;
+
+ /**
+ * Ermittelt die beste verfügbare Implementierung von {@link MLMatrix}.
+ *
+ * @return Die verwendete {@link MLMatrix} Implementierung.
+ */
+ private static final Class extends MLMatrix> getMatrixType() {
+ if( matrixType == null ) {
+ try {
+ Class> clazz = Class.forName("cern.colt.matrix.impl.DenseDoubleMatrix2D", false, MatrixFactory.class.getClassLoader());
+ matrixType = ColtMatrix.class;
+ LOG.info("Colt library found. Using as matrix implementation.");
+ } catch( ClassNotFoundException e ) {
+ LOG.info("Colt library not found. Falling back on internal implementation.");
+ matrixType = DoubleMatrix.class;
+ }
+ }
+ return matrixType;
+ }
+
+ private static final Log LOG = Log.getLogger(MatrixFactory.class);
+
+ /**
+ * Interner Wrapper der DoubleMatrix2D Klasse aus der Colt Bibliothek, um
+ * das {@link MLMatrix} Interface zu implementieren.
+ */
+ static class ColtMatrix implements MLMatrix {
+
+ cern.colt.matrix.DoubleMatrix2D matrix;
+
+ public ColtMatrix( double[][] doubles ) {
+ matrix = new cern.colt.matrix.impl.DenseDoubleMatrix2D(doubles);
+ }
+
+ public ColtMatrix( int rows, int cols ) {
+ matrix = new cern.colt.matrix.impl.DenseDoubleMatrix2D(rows, cols);
+ }
+
+ public ColtMatrix( ColtMatrix matrix ) {
+ this.matrix = matrix.matrix.copy();
+ }
+
+ @Override
+ public int columns() {
+ return matrix.columns();
+ }
+
+ @Override
+ public int rows() {
+ return matrix.rows();
+ }
+
+ @Override
+ public double get( int row, int col ) {
+ return matrix.get(row, col);
+ }
+
+ @Override
+ public MLMatrix set( int row, int col, double value ) {
+ matrix.set(row, col, value);
+ return this;
+ }
+
+ @Override
+ public MLMatrix initializeRandom() {
+ return initializeRandom(-1.0, 1.0);
+ }
+
+ @Override
+ public MLMatrix initializeRandom( double lower, double upper ) {
+ matrix.assign(( d ) -> ((upper - lower) * Constants.random()) + lower);
+ return this;
+ }
+
+ @Override
+ public MLMatrix initializeOne() {
+ this.matrix.assign(1.0);
+ return this;
+ }
+
+ @Override
+ public MLMatrix initializeZero() {
+ this.matrix.assign(0.0);
+ return this;
+ }
+
+ @Override
+ public MLMatrix duplicate() {
+ ColtMatrix newMatrix = new ColtMatrix(matrix.rows(), matrix.columns());
+ newMatrix.matrix.assign(this.matrix);
+ return newMatrix;
+ }
+
+ @Override
+ public MLMatrix multiplyTransposed( MLMatrix B ) {
+ ColtMatrix CB = (ColtMatrix) B;
+ ColtMatrix newMatrix = new ColtMatrix(0, 0);
+ newMatrix.matrix = matrix.zMult(CB.matrix, null, 1.0, 0.0, false, true);
+ return newMatrix;
+ }
+
+ @Override
+ public MLMatrix multiplyAddBias( MLMatrix B, MLMatrix C ) {
+ ColtMatrix CB = (ColtMatrix) B;
+ ColtMatrix newMatrix = new ColtMatrix(0, 0);
+ newMatrix.matrix = DoubleFactory2D.dense.repeat(((ColtMatrix) C).matrix, rows(), 1);
+ matrix.zMult(CB.matrix, newMatrix.matrix, 1.0, 1.0, false, false);
+ return newMatrix;
+ }
+
+ @Override
+ public MLMatrix transposedMultiplyAndScale( final MLMatrix B, final double scalar ) {
+ ColtMatrix CB = (ColtMatrix) B;
+ ColtMatrix newMatrix = new ColtMatrix(0, 0);
+ newMatrix.matrix = matrix.zMult(CB.matrix, null, scalar, 0.0, true, false);
+ return newMatrix;
+ }
+
+ @Override
+ public MLMatrix add( MLMatrix B ) {
+ ColtMatrix CB = (ColtMatrix) B;
+ ColtMatrix newMatrix = new ColtMatrix(this);
+ newMatrix.matrix.assign(CB.matrix, ( d1, d2 ) -> d1 + d2);
+ return newMatrix;
+ }
+
+ @Override
+ public MLMatrix addInPlace( MLMatrix B ) {
+ ColtMatrix CB = (ColtMatrix) B;
+ matrix.assign(CB.matrix, ( d1, d2 ) -> d1 + d2);
+ return this;
+ }
+
+ @Override
+ public MLMatrix sub( MLMatrix B ) {
+ ColtMatrix CB = (ColtMatrix) B;
+ ColtMatrix newMatrix = new ColtMatrix(this);
+ newMatrix.matrix.assign(CB.matrix, ( d1, d2 ) -> d1 - d2);
+ return newMatrix;
+ }
+
+ @Override
+ public MLMatrix colSums() {
+ double[][] sums = new double[1][matrix.columns()];
+ for( int c = 0; c < matrix.columns(); c++ ) {
+ for( int r = 0; r < matrix.rows(); r++ ) {
+ sums[0][c] += matrix.getQuick(r, c);
+ }
+ }
+ return new ColtMatrix(sums);
+ }
+
+ @Override
+ public MLMatrix scaleInPlace( double scalar ) {
+ this.matrix.assign(( d ) -> d * scalar);
+ return this;
+ }
+
+ @Override
+ public MLMatrix scaleInPlace( MLMatrix S ) {
+ this.matrix.forEachNonZero(( r, c, d ) -> d * S.get(r, c));
+ return this;
+ }
+
+ @Override
+ public MLMatrix apply( DoubleUnaryOperator op ) {
+ ColtMatrix newMatrix = new ColtMatrix(matrix.rows(), matrix.columns());
+ newMatrix.matrix.assign(matrix);
+ newMatrix.matrix.assign(( d ) -> op.applyAsDouble(d));
+ return newMatrix;
+ }
+
+ @Override
+ public MLMatrix applyInPlace( DoubleUnaryOperator op ) {
+ this.matrix.assign(( d ) -> op.applyAsDouble(d));
+ return this;
+ }
+
+ @Override
+ public String toString() {
+ return matrix.toString();
+ }
+
+ }
+
+}
diff --git a/src/main/java/schule/ngb/zm/ml/NeuralNetwork.java b/src/main/java/schule/ngb/zm/ml/NeuralNetwork.java
index b6464e5..b1cd4d6 100644
--- a/src/main/java/schule/ngb/zm/ml/NeuralNetwork.java
+++ b/src/main/java/schule/ngb/zm/ml/NeuralNetwork.java
@@ -15,7 +15,7 @@ public class NeuralNetwork {
Writer writer = ResourceStreamProvider.getWriter(source);
PrintWriter out = new PrintWriter(writer)
) {
- for( NeuronLayer layer: network.layers ) {
+ for( NeuronLayer layer : network.layers ) {
out.print(layer.getNeuronCount());
out.print(' ');
out.print(layer.getInputCount());
@@ -23,20 +23,44 @@ public class NeuralNetwork {
for( int i = 0; i < layer.getInputCount(); i++ ) {
for( int j = 0; j < layer.getNeuronCount(); j++ ) {
- out.print(layer.weights.coefficients[i][j]);
+ out.print(layer.weights.get(i, j));
out.print(' ');
}
out.println();
}
for( int j = 0; j < layer.getNeuronCount(); j++ ) {
- out.print(layer.biases[j]);
+ out.print(layer.biases.get(0, j));
out.print(' ');
}
out.println();
}
out.flush();
} catch( IOException ex ) {
- LOG.warn(ex, "");
+ LOG.error(ex, "");
+ }
+ }
+
+ public static void saveToDataFile( String source, NeuralNetwork network ) {
+ try(
+ OutputStream stream = ResourceStreamProvider.getOutputStream(source);
+ DataOutputStream out = new DataOutputStream(stream)
+ ) {
+ for( NeuronLayer layer : network.layers ) {
+ out.writeInt(layer.getNeuronCount());
+ out.writeInt(layer.getInputCount());
+
+ for( int i = 0; i < layer.getInputCount(); i++ ) {
+ for( int j = 0; j < layer.getNeuronCount(); j++ ) {
+ out.writeDouble(layer.weights.get(i, j));
+ }
+ }
+ for( int j = 0; j < layer.getNeuronCount(); j++ ) {
+ out.writeDouble(layer.biases.get(0, j));
+ }
+ }
+ out.flush();
+ } catch( IOException ex ) {
+ LOG.error(ex, "");
}
}
@@ -56,13 +80,13 @@ public class NeuralNetwork {
for( int i = 0; i < inputs; i++ ) {
split = in.readLine().split(" ");
for( int j = 0; j < neurons; j++ ) {
- layer.weights.coefficients[i][j] = Double.parseDouble(split[j]);
+ layer.weights.set(i, j, Double.parseDouble(split[j]));
}
}
// Load Biases
split = in.readLine().split(" ");
for( int j = 0; j < neurons; j++ ) {
- layer.biases[j] = Double.parseDouble(split[j]);
+ layer.biases.set(0, j, Double.parseDouble(split[j]));
}
layers.add(layer);
@@ -70,29 +94,30 @@ public class NeuralNetwork {
return new NeuralNetwork(layers);
} catch( IOException | NoSuchElementException ex ) {
- LOG.warn(ex, "Could not load neural network from source <%s>", source);
+ LOG.error(ex, "Could not load neural network from source <%s>", source);
}
return null;
}
- /*public static NeuralNetwork loadFromFile( String source ) {
+ public static NeuralNetwork loadFromDataFile( String source ) {
try(
InputStream stream = ResourceStreamProvider.getInputStream(source);
- Scanner in = new Scanner(stream)
+ DataInputStream in = new DataInputStream(stream)
) {
List layers = new LinkedList<>();
- while( in.hasNext() ) {
- int neurons = in.nextInt();
- int inputs = in.nextInt();
+ while( in.available() > 0 ) {
+ int neurons = in.readInt();
+ int inputs = in.readInt();
NeuronLayer layer = new NeuronLayer(neurons, inputs);
for( int i = 0; i < inputs; i++ ) {
for( int j = 0; j < neurons; j++ ) {
- layer.weights.coefficients[i][j] = in.nextDouble();
+ layer.weights.set(i, j, in.readDouble());
}
}
+ // Load Biases
for( int j = 0; j < neurons; j++ ) {
- layer.biases[j] = in.nextDouble();
+ layer.biases.set(0, j, in.readDouble());
}
layers.add(layer);
@@ -100,14 +125,14 @@ public class NeuralNetwork {
return new NeuralNetwork(layers);
} catch( IOException | NoSuchElementException ex ) {
- LOG.warn(ex, "Could not load neural network from source <%s>", source);
+ LOG.error(ex, "Could not load neural network from source <%s>", source);
}
return null;
- }*/
+ }
private NeuronLayer[] layers;
- private double[][] output;
+ private MLMatrix output;
private double learningRate = 0.1;
@@ -128,7 +153,7 @@ public class NeuralNetwork {
for( int i = 0; i < layers.size(); i++ ) {
this.layers[i] = layers.get(i);
if( i > 0 ) {
- this.layers[i-1].setNextLayer(this.layers[i]);
+ this.layers[i - 1].setNextLayer(this.layers[i]);
}
}
}
@@ -138,7 +163,7 @@ public class NeuralNetwork {
for( int i = 0; i < layers.length; i++ ) {
this.layers[i] = layers[i];
if( i > 0 ) {
- this.layers[i-1].setNextLayer(this.layers[i]);
+ this.layers[i - 1].setNextLayer(this.layers[i]);
}
}
}
@@ -146,6 +171,7 @@ public class NeuralNetwork {
public int getLayerCount() {
return layers.length;
}
+
public NeuronLayer[] getLayers() {
return layers;
}
@@ -162,17 +188,25 @@ public class NeuralNetwork {
this.learningRate = pLearningRate;
}
- public double[][] getOutput() {
+ public MLMatrix getOutput() {
return output;
}
- public double[][] predict( double[][] inputs ) {
+ public MLMatrix predict( double[][] inputs ) {
//this.output = layers[1].apply(layers[0].apply(inputs));
+ return predict(MatrixFactory.create(inputs));
+ }
+
+ public MLMatrix predict( MLMatrix inputs ) {
this.output = layers[0].apply(inputs);
return this.output;
}
public void learn( double[][] expected ) {
+ learn(MatrixFactory.create(expected));
+ }
+
+ public void learn( MLMatrix expected ) {
layers[layers.length - 1].backprop(expected, learningRate);
}
diff --git a/src/main/java/schule/ngb/zm/ml/NeuronLayer.java b/src/main/java/schule/ngb/zm/ml/NeuronLayer.java
index 9767fad..4ce061d 100644
--- a/src/main/java/schule/ngb/zm/ml/NeuronLayer.java
+++ b/src/main/java/schule/ngb/zm/ml/NeuronLayer.java
@@ -1,50 +1,66 @@
package schule.ngb.zm.ml;
-import java.util.Arrays;
import java.util.function.DoubleUnaryOperator;
import java.util.function.Function;
-public class NeuronLayer implements Function {
+/**
+ * Implementierung einer Neuronenebene in einem Neuonalen Netz.
+ *
+ * Eine Ebene besteht aus einer Anzahl an Neuronen die jeweils eine
+ * Anzahl Eingänge haben. Die Eingänge erhalten als Signal die Ausgabe
+ * der vorherigen Ebene und berechnen die Ausgabe des jeweiligen Neurons.
+ */
+public class NeuronLayer implements Function {
+
+ public static NeuronLayer fromArray( double[][] weights, boolean transpose ) {
+ NeuronLayer layer;
+ if( transpose ) {
+ layer = new NeuronLayer(weights.length, weights[0].length);
+ } else {
+ layer = new NeuronLayer(weights[0].length, weights.length);
+ }
- public static NeuronLayer fromArray( double[][] weights ) {
- NeuronLayer layer = new NeuronLayer(weights[0].length, weights.length);
for( int i = 0; i < weights[0].length; i++ ) {
for( int j = 0; j < weights.length; j++ ) {
- layer.weights.coefficients[i][j] = weights[i][j];
+ if( transpose ) {
+ layer.weights.set(j, i, weights[i][j]);
+ } else {
+ layer.weights.set(i, j, weights[i][j]);
+ }
}
}
+
return layer;
}
- public static NeuronLayer fromArray( double[][] weights, double[] biases ) {
- NeuronLayer layer = new NeuronLayer(weights[0].length, weights.length);
- for( int i = 0; i < weights[0].length; i++ ) {
- for( int j = 0; j < weights.length; j++ ) {
- layer.weights.coefficients[i][j] = weights[i][j];
- }
- }
+ public static NeuronLayer fromArray( double[][] weights, double[] biases, boolean transpose ) {
+ NeuronLayer layer = fromArray(weights, transpose);
+
for( int j = 0; j < biases.length; j++ ) {
- layer.biases[j] = biases[j];
+ layer.biases.set(0, j, biases[j]);
}
+
return layer;
}
-
- Matrix weights;
- double[] biases;
+ MLMatrix weights;
+
+ MLMatrix biases;
NeuronLayer previous, next;
DoubleUnaryOperator activationFunction, activationFunctionDerivative;
- double[][] lastOutput, lastInput;
+ MLMatrix lastOutput, lastInput;
public NeuronLayer( int neurons, int inputs ) {
- weights = new Matrix(inputs, neurons);
- weights.initializeRandom(-1, 1);
+ weights = MatrixFactory
+ .create(inputs, neurons)
+ .initializeRandom();
- biases = new double[neurons];
- Arrays.fill(biases, 0.0); // TODO: Random?
+ biases = MatrixFactory
+ .create(1, neurons)
+ .initializeZero();
activationFunction = MLMath::sigmoid;
activationFunctionDerivative = MLMath::sigmoidDerivative;
@@ -85,45 +101,42 @@ public class NeuronLayer implements Function {
}
}
- public Matrix getWeights() {
+ public MLMatrix getWeights() {
return weights;
}
+ public MLMatrix getBiases() {
+ return biases;
+ }
+
public int getNeuronCount() {
- return weights.coefficients[0].length;
+ return weights.columns();
}
public int getInputCount() {
- return weights.coefficients.length;
+ return weights.rows();
}
- public double[][] getLastOutput() {
+ public MLMatrix getLastOutput() {
return lastOutput;
}
- public void setWeights( double[][] newWeights ) {
- weights.coefficients = MLMath.copyMatrix(newWeights);
- }
-
- public void adjustWeights( double[][] adjustment ) {
- weights.coefficients = MLMath.matrixAdd(weights.coefficients, adjustment);
+ public void setWeights( MLMatrix newWeights ) {
+ weights = newWeights.duplicate();
}
@Override
public String toString() {
- return weights.toString() + "\n" + Arrays.toString(biases);
+ return "weights:\n" + weights.toString() + "\nbiases:\n" + biases.toString();
}
@Override
- public double[][] apply( double[][] inputs ) {
- lastInput = inputs;
- lastOutput = MLMath.matrixApply(
- MLMath.biasAdd(
- MLMath.matrixMultiply(inputs, weights.coefficients),
- biases
- ),
- activationFunction
- );
+ public MLMatrix apply( MLMatrix inputs ) {
+ lastInput = inputs.duplicate();
+ lastOutput = inputs
+ .multiplyAddBias(weights, biases)
+ .applyInPlace(activationFunction);
+
if( next != null ) {
return next.apply(lastOutput);
} else {
@@ -132,36 +145,41 @@ public class NeuronLayer implements Function {
}
@Override
- public Function compose( Function super V, ? extends double[][]> before ) {
+ public Function compose( Function super V, ? extends MLMatrix> before ) {
return ( in ) -> apply(before.apply(in));
}
@Override
- public Function andThen( Function super double[][], ? extends V> after ) {
+ public Function andThen( Function super MLMatrix, ? extends V> after ) {
return ( in ) -> after.apply(apply(in));
}
- public void backprop( double[][] expected, double learningRate ) {
- double[][] error, delta, adjustment;
+ public void backprop( MLMatrix expected, double learningRate ) {
+ MLMatrix error, adjustment;
if( next == null ) {
- error = MLMath.matrixSub(expected, this.lastOutput);
+ error = expected.sub(lastOutput);
} else {
- error = MLMath.matrixMultiply(expected, MLMath.matrixTranspose(next.weights.coefficients));
+ error = expected.multiplyTransposed(next.weights);
}
- delta = MLMath.matrixScale(error, MLMath.matrixApply(this.lastOutput, this.activationFunctionDerivative));
+ error.scaleInPlace(
+ lastOutput.apply(this.activationFunctionDerivative)
+ );
// Hier schon leraningRate anwenden?
// See https://towardsdatascience.com/understanding-and-implementing-neural-networks-in-java-from-scratch-61421bb6352c
//delta = MLMath.matrixApply(delta, ( x ) -> learningRate * x);
if( previous != null ) {
- previous.backprop(delta, learningRate);
+ previous.backprop(error, learningRate);
}
- biases = MLMath.biasAdjust(biases, MLMath.matrixApply(delta, ( x ) -> learningRate * x));
+ biases.addInPlace(
+ error.colSums().scaleInPlace(
+ -learningRate / (double) error.rows()
+ )
+ );
- adjustment = MLMath.matrixMultiply(MLMath.matrixTranspose(lastInput), delta);
- adjustment = MLMath.matrixApply(adjustment, ( x ) -> learningRate * x);
- this.adjustWeights(adjustment);
+ adjustment = lastInput.transposedMultiplyAndScale(error, learningRate);
+ weights.addInPlace(adjustment);
}
}
diff --git a/src/test/java/schule/ngb/zm/ml/MLMatrixTest.java b/src/test/java/schule/ngb/zm/ml/MLMatrixTest.java
new file mode 100644
index 0000000..1e0a1da
--- /dev/null
+++ b/src/test/java/schule/ngb/zm/ml/MLMatrixTest.java
@@ -0,0 +1,426 @@
+package schule.ngb.zm.ml;
+
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.TestInfo;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.ValueSource;
+import schule.ngb.zm.util.Timer;
+
+import static org.junit.jupiter.api.Assertions.*;
+
+class MLMatrixTest {
+
+ private TestInfo info;
+
+ @BeforeEach
+ void saveTestInfo( TestInfo info ) {
+ this.info = info;
+ }
+
+ @ParameterizedTest
+ @ValueSource( classes = {DoubleMatrix.class, MatrixFactory.ColtMatrix.class} )
+ void get( Class extends MLMatrix> mType ) {
+ MatrixFactory.matrixType = mType;
+
+ MLMatrix M = MatrixFactory.create(new double[][]{
+ {1, 2, 3},
+ {4, 5, 6}
+ });
+
+ assertEquals(mType, M.getClass());
+
+ assertEquals(1.0, M.get(0,0));
+ assertEquals(4.0, M.get(1,0));
+ assertEquals(6.0, M.get(1,2));
+ }
+
+ @ParameterizedTest
+ @ValueSource( classes = {DoubleMatrix.class, MatrixFactory.ColtMatrix.class} )
+ void initializeOne( Class extends MLMatrix> mType ) {
+ MatrixFactory.matrixType = mType;
+
+ MLMatrix m = MatrixFactory.create(4, 4);
+ m.initializeOne();
+
+ assertEquals(mType, m.getClass());
+
+ for( int i = 0; i < m.rows(); i++ ) {
+ for( int j = 0; j < m.columns(); j++ ) {
+ assertEquals(1.0, m.get(i, j));
+ }
+ }
+ }
+
+ @ParameterizedTest
+ @ValueSource( classes = {DoubleMatrix.class, MatrixFactory.ColtMatrix.class} )
+ void initializeZero( Class extends MLMatrix> mType ) {
+ MatrixFactory.matrixType = mType;
+
+ MLMatrix m = MatrixFactory.create(4, 4);
+ m.initializeZero();
+
+ assertEquals(mType, m.getClass());
+
+ for( int i = 0; i < m.rows(); i++ ) {
+ for( int j = 0; j < m.columns(); j++ ) {
+ assertEquals(0.0, m.get(i, j));
+ }
+ }
+ }
+
+ @ParameterizedTest
+ @ValueSource( classes = {DoubleMatrix.class, MatrixFactory.ColtMatrix.class} )
+ void initializeRandom( Class extends MLMatrix> mType ) {
+ MatrixFactory.matrixType = mType;
+
+ MLMatrix m = MatrixFactory.create(4, 4);
+ m.initializeRandom();
+
+ assertEquals(mType, m.getClass());
+
+ for( int i = 0; i < m.rows(); i++ ) {
+ for( int j = 0; j < m.columns(); j++ ) {
+ double d = m.get(i, j);
+ assertTrue(-1.0 <= d && d < 1.0);
+ }
+ }
+ }
+
+ @ParameterizedTest
+ @ValueSource( classes = {DoubleMatrix.class, MatrixFactory.ColtMatrix.class} )
+ void multiplyTransposed( Class extends MLMatrix> mType ) {
+ MatrixFactory.matrixType = mType;
+
+ MLMatrix A = MatrixFactory.create(new double[][]{
+ {1.0, 2.0, 3.0, 4.0},
+ {1.0, 2.0, 3.0, 4.0},
+ {1.0, 2.0, 3.0, 4.0}
+ });
+ MLMatrix B = MatrixFactory.create(new double[][]{
+ {1, 3, 5, 7},
+ {2, 4, 6, 8}
+ });
+
+ MLMatrix C = A.multiplyTransposed(B);
+
+ assertEquals(mType, A.getClass());
+ assertEquals(mType, B.getClass());
+ assertEquals(mType, C.getClass());
+
+ assertEquals(3, C.rows());
+ assertEquals(2, C.columns());
+ for( int i = 0; i < C.rows(); i++ ) {
+ assertEquals(50.0, C.get(i, 0));
+ assertEquals(60.0, C.get(i, 1));
+ }
+ }
+
+ @ParameterizedTest
+ @ValueSource( classes = {DoubleMatrix.class, MatrixFactory.ColtMatrix.class} )
+ void multiplyAddBias( Class extends MLMatrix> mType ) {
+ MatrixFactory.matrixType = mType;
+
+ MLMatrix A = MatrixFactory.create(new double[][]{
+ {1.0, 2.0, 3.0, 4.0},
+ {1.0, 2.0, 3.0, 4.0},
+ {1.0, 2.0, 3.0, 4.0}
+ });
+ MLMatrix B = MatrixFactory.create(new double[][]{
+ {1.0, 2.0},
+ {3.0, 4.0},
+ {5.0, 6.0},
+ {7.0, 8.0}
+ });
+ MLMatrix V = MatrixFactory.create(new double[][]{
+ {1000.0, 2000.0}
+ });
+
+ MLMatrix C = A.multiplyAddBias(B, V);
+
+ assertEquals(mType, A.getClass());
+ assertEquals(mType, B.getClass());
+ assertEquals(mType, C.getClass());
+ assertEquals(mType, V.getClass());
+
+ assertEquals(3, C.rows());
+ assertEquals(2, C.columns());
+ for( int i = 0; i < C.rows(); i++ ) {
+ assertEquals(1050.0, C.get(i, 0));
+ assertEquals(2060.0, C.get(i, 1));
+ }
+ }
+
+ @ParameterizedTest
+ @ValueSource( classes = {DoubleMatrix.class, MatrixFactory.ColtMatrix.class} )
+ void transposedMultiplyAndScale( Class extends MLMatrix> mType ) {
+ MatrixFactory.matrixType = mType;
+
+ MLMatrix A = MatrixFactory.create(new double[][]{
+ {1, 1, 1},
+ {2, 2, 2},
+ {3, 3, 3},
+ {4, 4, 4}
+ });
+ MLMatrix B = MatrixFactory.create(new double[][]{
+ {1.0, 2.0},
+ {3.0, 4.0},
+ {5.0, 6.0},
+ {7.0, 8.0}
+ });
+
+ MLMatrix C = A.transposedMultiplyAndScale(B, 2.0);
+
+ assertEquals(mType, A.getClass());
+ assertEquals(mType, B.getClass());
+ assertEquals(mType, C.getClass());
+
+ assertEquals(3, C.rows());
+ assertEquals(2, C.columns());
+ for( int i = 0; i < C.rows(); i++ ) {
+ assertEquals(100.0, C.get(i, 0));
+ assertEquals(120.0, C.get(i, 1));
+ }
+ }
+
+ @ParameterizedTest
+ @ValueSource( classes = {DoubleMatrix.class, MatrixFactory.ColtMatrix.class} )
+ void apply( Class extends MLMatrix> mType ) {
+ MatrixFactory.matrixType = mType;
+
+ MLMatrix M = MatrixFactory.create(new double[][]{
+ {1, 1, 1},
+ {2, 2, 2},
+ {3, 3, 3},
+ {4, 4, 4}
+ });
+
+ MLMatrix R = M.apply(( d ) -> d * d);
+
+ assertEquals(mType, M.getClass());
+ assertEquals(mType, R.getClass());
+ assertNotSame(M, R);
+
+ for( int i = 0; i < M.rows(); i++ ) {
+ for( int j = 0; j < M.columns(); j++ ) {
+ assertEquals(
+ (i + 1) * (i + 1), R.get(i, j),
+ msg("(%d,%d)", "apply", i, j)
+ );
+ }
+ }
+
+ MLMatrix M2 = M.applyInPlace(( d ) -> d * d * d);
+ assertSame(M, M2);
+ for( int i = 0; i < M.rows(); i++ ) {
+ for( int j = 0; j < M.columns(); j++ ) {
+ assertEquals(
+ (i + 1) * (i + 1) * (i + 1), M.get(i, j),
+ msg("(%d,%d)", "applyInPlace", i, j)
+ );
+ }
+ }
+ }
+
+ @ParameterizedTest
+ @ValueSource( classes = {DoubleMatrix.class, MatrixFactory.ColtMatrix.class} )
+ void add( Class extends MLMatrix> mType ) {
+ MatrixFactory.matrixType = mType;
+
+ MLMatrix M = MatrixFactory.create(new double[][]{
+ {1, 1, 1},
+ {2, 2, 2},
+ {3, 3, 3},
+ {4, 4, 4}
+ });
+
+ MLMatrix R = M.add(M);
+
+ assertEquals(mType, M.getClass());
+ assertEquals(mType, R.getClass());
+ assertNotSame(M, R);
+
+ for( int i = 0; i < M.rows(); i++ ) {
+ for( int j = 0; j < M.columns(); j++ ) {
+ assertEquals(
+ (i + 1) + (i + 1), R.get(i, j),
+ msg("(%d,%d)", "add", i, j)
+ );
+ }
+ }
+
+ MLMatrix M2 = M.addInPlace(R);
+ assertSame(M, M2);
+ for( int i = 0; i < M.rows(); i++ ) {
+ for( int j = 0; j < M.columns(); j++ ) {
+ assertEquals(
+ (i + 1) + (i + 1) + (i + 1), M.get(i, j),
+ msg("(%d,%d)", "addInPlace", i, j)
+ );
+ }
+ }
+ }
+
+ @ParameterizedTest
+ @ValueSource( classes = {DoubleMatrix.class, MatrixFactory.ColtMatrix.class} )
+ void sub( Class extends MLMatrix> mType ) {
+ MatrixFactory.matrixType = mType;
+
+ MLMatrix M = MatrixFactory.create(new double[][]{
+ {1, 1, 1},
+ {2, 2, 2},
+ {3, 3, 3},
+ {4, 4, 4}
+ });
+
+ MLMatrix R = M.sub(M);
+
+ assertEquals(mType, M.getClass());
+ assertEquals(mType, R.getClass());
+ assertNotSame(M, R);
+
+ for( int i = 0; i < M.rows(); i++ ) {
+ for( int j = 0; j < M.columns(); j++ ) {
+ assertEquals(
+ 0.0, R.get(i, j),
+ msg("(%d,%d)", "sub", i, j)
+ );
+ }
+ }
+ }
+
+ @ParameterizedTest
+ @ValueSource( classes = {DoubleMatrix.class, MatrixFactory.ColtMatrix.class} )
+ void colSums( Class extends MLMatrix> mType ) {
+ MatrixFactory.matrixType = mType;
+
+ MLMatrix M = MatrixFactory.create(new double[][]{
+ {1, 2, 3},
+ {1, 2, 3},
+ {1, 2, 3},
+ {1, 2, 3}
+ });
+
+ MLMatrix R = M.colSums();
+
+ assertEquals(mType, M.getClass());
+ assertEquals(mType, R.getClass());
+ assertNotSame(M, R);
+
+ assertEquals(1, R.rows());
+ assertEquals(3, R.columns());
+ for( int j = 0; j < M.columns(); j++ ) {
+ assertEquals(
+ (j+1)*4, R.get(0, j),
+ msg("(%d,%d)", "colSums", 0, j)
+ );
+ }
+ }
+
+
+ @ParameterizedTest
+ @ValueSource( classes = {DoubleMatrix.class, MatrixFactory.ColtMatrix.class} )
+ void duplicate( Class extends MLMatrix> mType ) {
+ MatrixFactory.matrixType = mType;
+
+ MLMatrix M = MatrixFactory.create(new double[][]{
+ {1, 2, 3},
+ {1, 2, 3},
+ {1, 2, 3},
+ {1, 2, 3}
+ });
+
+ MLMatrix R = M.duplicate();
+
+ assertEquals(mType, M.getClass());
+ assertEquals(mType, R.getClass());
+ assertNotSame(M, R);
+
+ for( int i = 0; i < M.rows(); i++ ) {
+ for( int j = 0; j < M.columns(); j++ ) {
+ assertEquals(
+ M.get(i, j), R.get(i, j),
+ msg("(%d,%d)", "duplicate", i, j)
+ );
+ }
+ }
+ }
+
+ @ParameterizedTest
+ @ValueSource( classes = {DoubleMatrix.class, MatrixFactory.ColtMatrix.class} )
+ void scale( Class extends MLMatrix> mType ) {
+ MatrixFactory.matrixType = mType;
+
+ MLMatrix M = MatrixFactory.create(new double[][]{
+ {1, 1, 1},
+ {2, 2, 2},
+ {3, 3, 3},
+ {4, 4, 4}
+ });
+
+ MLMatrix M2 = M.scaleInPlace(2.0);
+
+ assertEquals(mType, M.getClass());
+ assertEquals(mType, M2.getClass());
+ assertSame(M, M2);
+
+ for( int i = 0; i < M.rows(); i++ ) {
+ for( int j = 0; j < M.columns(); j++ ) {
+ assertEquals(
+ (i+1)*2.0, M2.get(i, j),
+ msg("(%d,%d)", "scaleInPlace", i, j)
+ );
+ }
+ }
+
+ MLMatrix M3 = M.scaleInPlace(M);
+ assertSame(M, M3);
+ for( int i = 0; i < M.rows(); i++ ) {
+ for( int j = 0; j < M.columns(); j++ ) {
+ assertEquals(
+ ((i+1)*2.0)*((i+1)*2.0), M.get(i, j),
+ msg("(%d,%d)", "addInPlace", i, j)
+ );
+ }
+ }
+ }
+
+ private String msg( String msg, String methodName, Object... args ) {
+ String testName = this.info.getTestMethod().get().getName();
+ String className = MatrixFactory.matrixType.getSimpleName();
+ return String.format("[" + testName + "(" + className + ") " + methodName + "()] " + msg, args);
+ }
+
+ //@ParameterizedTest
+ //@ValueSource( classes = {MatrixFactory.ColtMatrix.class, DoubleMatrix.class} )
+ void speed( Class extends MLMatrix> mType ) {
+ MatrixFactory.matrixType = mType;
+
+ int N = 10;
+ int rows = 1000;
+ int cols = 1000;
+
+ Timer timer = new Timer();
+
+ MLMatrix M = MatrixFactory.create(rows, cols);
+ timer.start();
+ for( int i = 0; i < N; i++ ) {
+ M.initializeRandom();
+ }
+ timer.stop();
+ System.err.println(msg("%d iterations: %d ms", "initializeRandom", N, timer.getMillis()));
+
+ timer.reset();
+
+ MLMatrix B = MatrixFactory.create(rows*2, M.columns());
+ B.initializeRandom();
+
+ timer.start();
+ for( int i = 0; i < N; i++ ) {
+ M.multiplyTransposed(B);
+ }
+ timer.stop();
+ System.err.println(msg("%d iterations: %d ms", "multiplyTransposed", N, timer.getMillis()));
+ }
+
+}
diff --git a/src/test/java/schule/ngb/zm/ml/MatrixTest.java b/src/test/java/schule/ngb/zm/ml/MatrixTest.java
deleted file mode 100644
index 6277e62..0000000
--- a/src/test/java/schule/ngb/zm/ml/MatrixTest.java
+++ /dev/null
@@ -1,57 +0,0 @@
-package schule.ngb.zm.ml;
-
-import org.junit.jupiter.api.Test;
-
-import java.util.Arrays;
-
-import static org.junit.jupiter.api.Assertions.*;
-
-class MatrixTest {
-
- @Test
- void initializeIdentity() {
- Matrix m = new Matrix(4, 4);
- m.initializeIdentity();
-
- assertArrayEquals(new double[]{1.0, 0.0, 0.0, 0.0}, m.coefficients[0]);
- assertArrayEquals(new double[]{0.0, 1.0, 0.0, 0.0}, m.coefficients[1]);
- assertArrayEquals(new double[]{0.0, 0.0, 1.0, 0.0}, m.coefficients[2]);
- assertArrayEquals(new double[]{0.0, 0.0, 0.0, 1.0}, m.coefficients[3]);
- }
-
- @Test
- void initializeOne() {
- Matrix m = new Matrix(4, 4);
- m.initializeOne();
-
- double[] ones = new double[]{1.0, 1.0, 1.0, 1.0};
- assertArrayEquals(ones, m.coefficients[0]);
- assertArrayEquals(ones, m.coefficients[1]);
- assertArrayEquals(ones, m.coefficients[2]);
- assertArrayEquals(ones, m.coefficients[3]);
- }
-
- @Test
- void initializeZero() {
- Matrix m = new Matrix(4, 4);
- m.initializeZero();
-
- double[] zeros = new double[]{0.0, 0.0, 0.0, 0.0};
- assertArrayEquals(zeros, m.coefficients[0]);
- assertArrayEquals(zeros, m.coefficients[1]);
- assertArrayEquals(zeros, m.coefficients[2]);
- assertArrayEquals(zeros, m.coefficients[3]);
- }
-
- @Test
- void initializeRandom() {
- Matrix m = new Matrix(4, 4);
- m.initializeRandom(-1, 1);
-
- assertTrue(Arrays.stream(m.coefficients[0]).allMatch((d) -> -1.0 <= d && d < 1.0));
- assertTrue(Arrays.stream(m.coefficients[1]).allMatch((d) -> -1.0 <= d && d < 1.0));
- assertTrue(Arrays.stream(m.coefficients[2]).allMatch((d) -> -1.0 <= d && d < 1.0));
- assertTrue(Arrays.stream(m.coefficients[3]).allMatch((d) -> -1.0 <= d && d < 1.0));
- }
-
-}
diff --git a/src/test/java/schule/ngb/zm/ml/NeuralNetworkTest.java b/src/test/java/schule/ngb/zm/ml/NeuralNetworkTest.java
index f81fe4d..3dca84e 100644
--- a/src/test/java/schule/ngb/zm/ml/NeuralNetworkTest.java
+++ b/src/test/java/schule/ngb/zm/ml/NeuralNetworkTest.java
@@ -2,15 +2,14 @@ package schule.ngb.zm.ml;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
+import schule.ngb.zm.Constants;
import schule.ngb.zm.util.Log;
+import schule.ngb.zm.util.Timer;
-import java.io.File;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
-import static org.junit.jupiter.api.Assertions.*;
-
class NeuralNetworkTest {
@BeforeAll
@@ -18,7 +17,14 @@ class NeuralNetworkTest {
Log.enableGlobalDebugging();
}
- @Test
+ @BeforeAll
+ static void setupMatrixLibrary() {
+ Constants.setSeed(1001);
+ //MatrixFactory.matrixType = MatrixFactory.ColtMatrix.class;
+ MatrixFactory.matrixType = DoubleMatrix.class;
+ }
+
+ /*@Test
void readWrite() {
// XOR Dataset
NeuralNetwork net = new NeuralNetwork(2, 4, 1);
@@ -53,7 +59,7 @@ class NeuralNetworkTest {
}
assertArrayEquals(net.predict(inputs), net2.predict(inputs));
- }
+ }*/
@Test
void learnXor() {
@@ -78,14 +84,14 @@ class NeuralNetworkTest {
}
// calculate predictions
- double[][] predictions = net.predict(inputs);
+ MLMatrix predictions = net.predict(inputs);
for( int i = 0; i < 4; i++ ) {
- int parsed_pred = predictions[i][0] < 0.5 ? 0 : 1;
+ int parsed_pred = predictions.get(i, 0) < 0.5 ? 0 : 1;
System.out.printf(
"{%.0f, %.0f} = %.4f (%d) -> %s\n",
inputs[i][0], inputs[i][1],
- predictions[i][0],
+ predictions.get(i, 0),
parsed_pred,
parsed_pred == outputs[i][0] ? "correct" : "miss"
);
@@ -109,12 +115,16 @@ class NeuralNetworkTest {
for( int i = 0; i < trainingData.size(); i++ ) {
inputs[i][0] = trainingData.get(i).a;
inputs[i][1] = trainingData.get(i).b;
- outputs[i][0] = trainingData.get(i).result;
+ outputs[i][0] = trainingData.get(i).getResult();
}
+ Timer timer = new Timer();
+
System.out.println("Training the neural net to learn "+OPERATION+"...");
+ timer.start();
net.train(inputs, outputs, TRAINING_CYCLES);
- System.out.println(" finished training");
+ timer.stop();
+ System.out.println(" finished training (" + timer.getMillis() + "ms)");
for( int i = 1; i <= net.getLayerCount(); i++ ) {
System.out.println("Layer " +i + " weights");
@@ -136,19 +146,18 @@ class NeuralNetworkTest {
System.out.printf(
"Prediction on data (%.2f, %.2f) was %.4f, expected %.2f (of by %.4f)\n",
data.a, data.b,
- net.getOutput()[0][0],
- data.result,
- net.getOutput()[0][0] - data.result
+ net.getOutput().get(0, 0),
+ data.getResult(),
+ net.getOutput().get(0, 0) - data.getResult()
);
}
private List createTrainingSet( int trainingSetSize, CalcType operation ) {
- Random random = new Random();
List tuples = new ArrayList<>();
for( int i = 0; i < trainingSetSize; i++ ) {
- double s1 = random.nextDouble() * 0.5;
- double s2 = random.nextDouble() * 0.5;
+ double s1 = Constants.random() * 0.5;
+ double s2 = Constants.random() * 0.5;
switch( operation ) {
case ADD:
@@ -181,7 +190,6 @@ class NeuralNetworkTest {
double a;
double b;
- double result;
CalcType type;
TestData( double a, double b ) {
@@ -189,6 +197,8 @@ class NeuralNetworkTest {
this.b = b;
}
+ abstract double getResult();
+
}
private static final class AddData extends TestData {
@@ -197,7 +207,9 @@ class NeuralNetworkTest {
public AddData( double a, double b ) {
super(a, b);
- result = a + b;
+ }
+ double getResult() {
+ return a+b;
}
}
@@ -208,7 +220,9 @@ class NeuralNetworkTest {
public SubData( double a, double b ) {
super(a, b);
- result = a - b;
+ }
+ double getResult() {
+ return a-b;
}
}
@@ -219,7 +233,9 @@ class NeuralNetworkTest {
public MulData( double a, double b ) {
super(a, b);
- result = a * b;
+ }
+ double getResult() {
+ return a*b;
}
}
@@ -233,7 +249,9 @@ class NeuralNetworkTest {
if( b == 0.0 ) {
b = .1;
}
- result = a / b;
+ }
+ double getResult() {
+ return a/b;
}
}
@@ -244,7 +262,12 @@ class NeuralNetworkTest {
public ModData( double b, double a ) {
super(b, a);
- result = a % b;
+ if( b == 0.0 ) {
+ b = .1;
+ }
+ }
+ double getResult() {
+ return a%b;
}
}