From 16477463d4556737dae6e122f23ab3366d5f710a Mon Sep 17 00:00:00 2001 From: "J. Neugebauer" Date: Wed, 20 Jul 2022 17:09:09 +0200 Subject: [PATCH] java doc und refactorings --- .../java/schule/ngb/zm/ml/DoubleMatrix.java | 146 ++++++++- src/main/java/schule/ngb/zm/ml/MLMatrix.java | 284 ++++++++++++++---- .../java/schule/ngb/zm/ml/MatrixFactory.java | 51 +++- .../java/schule/ngb/zm/ml/NeuralNetwork.java | 62 ++-- .../java/schule/ngb/zm/ml/NeuronLayer.java | 41 ++- .../java/schule/ngb/zm/ml/MLMatrixTest.java | 34 +++ 6 files changed, 510 insertions(+), 108 deletions(-) diff --git a/src/main/java/schule/ngb/zm/ml/DoubleMatrix.java b/src/main/java/schule/ngb/zm/ml/DoubleMatrix.java index 4a4008c..2144217 100644 --- a/src/main/java/schule/ngb/zm/ml/DoubleMatrix.java +++ b/src/main/java/schule/ngb/zm/ml/DoubleMatrix.java @@ -4,12 +4,60 @@ import schule.ngb.zm.Constants; import java.util.function.DoubleUnaryOperator; -// TODO: Move Math into Matrix class -// TODO: Implement support for optional sci libs +/** + * Eine einfache Implementierung der {@link MLMatrix} zur Verwendung in + * {@link NeuralNetwork}s. + *

+ * Diese Klasse stellt die interne Implementierung der Matrixoperationen dar, + * die zur Berechnung der Gewichte in einem {@link NeuronLayer} notwendig sind. + *

+ * Die Klasse ist nur minimal optimiert und sollte nur für kleine Netze + * verwendet werden. Für größere Netze sollte auf eine der optionalen + * Bibliotheken wie + * Colt zurückgegriffen werden. + */ public final class DoubleMatrix implements MLMatrix { - private int columns, rows; + /** + * Anzahl Zeilen der Matrix. + */ + private int rows; + /** + * Anzahl Spalten der Matrix. + */ + private int columns; + + /** + * Die Koeffizienten der Matrix. + *

+ * Um den Overhead bei Speicher und Zugriffszeiten von zweidimensionalen + * Arrays zu vermeiden wird ein eindimensionales Array verwendet und die + * Indizes mit Spaltenpriorität berechnet. Der Index i des Koeffizienten + * {@code r,c} in Zeile {@code r} und Spalte {@code c} wird bestimmt durch + *

+	 * i = c * rows + r
+	 * 
+ *

+ * Die Werte einer Spalte liegen also hintereinander im Array. Dies sollte + * einen leichten Vorteil bei der {@link #colSums() Spaltensummen} geben. + * Generell sollte eine Iteration über die Matrix der Form + *


+	 * for( int j = 0; j < columns; j++ ) {
+	 *     for( int i = 0; i < rows; i++ ) {
+	 *         // ...
+	 *     }
+	 * }
+	 * 
+ * etwas schneller sein als + *

+	 * for( int i = 0; i < rows; i++ ) {
+	 *     for( int j = 0; j < columns; j++ ) {
+	 *         // ...
+	 *     }
+	 * }
+	 * 
+ */ double[] coefficients; public DoubleMatrix( int rows, int cols ) { @@ -29,6 +77,11 @@ public final class DoubleMatrix implements MLMatrix { } } + /** + * Initialisiert diese Matrix als Kopie der angegebenen Matrix. + * + * @param other Die zu kopierende Matrix. + */ public DoubleMatrix( DoubleMatrix other ) { this.rows = other.rows(); this.columns = other.columns(); @@ -39,55 +92,100 @@ public final class DoubleMatrix implements MLMatrix { rows * columns); } + /** + * {@inheritDoc} + */ + @Override public int columns() { return columns; } + /** + * {@inheritDoc} + */ + @Override public int rows() { return rows; } - public double[][] coefficients() { - return new double[rows][columns]; - } - + /** + * {@inheritDoc} + */ int idx( int r, int c ) { return c * rows + r; } + /** + * {@inheritDoc} + */ + @Override public double get( int row, int col ) { - return coefficients[idx(row, col)]; + try { + return coefficients[idx(row, col)]; + } catch( ArrayIndexOutOfBoundsException ex ) { + throw new IllegalArgumentException("No element at row=" + row + ", column=" + col, ex); + } } + /** + * {@inheritDoc} + */ + @Override public MLMatrix set( int row, int col, double value ) { - coefficients[idx(row, col)] = value; + try { + coefficients[idx(row, col)] = value; + } catch( ArrayIndexOutOfBoundsException ex ) { + throw new IllegalArgumentException("No element at row=" + row + ", column=" + col, ex); + } return this; } + /** + * {@inheritDoc} + */ + @Override public MLMatrix initializeRandom() { return initializeRandom(-1.0, 1.0); } + /** + * {@inheritDoc} + */ + @Override public MLMatrix initializeRandom( double lower, double upper ) { applyInPlace(( d ) -> ((upper - lower) * Constants.random()) + lower); return this; } + /** + * {@inheritDoc} + */ + @Override public MLMatrix initializeOne() { applyInPlace(( d ) -> 1.0); return this; } + /** + * {@inheritDoc} + */ + @Override public MLMatrix initializeZero() { applyInPlace(( d ) -> 0.0); return this; } + /** + * {@inheritDoc} + */ @Override public MLMatrix duplicate() { return new DoubleMatrix(this); } + /** + * {@inheritDoc} + */ @Override public MLMatrix multiplyTransposed( MLMatrix B ) { /*return new DoubleMatrix(IntStream.range(0, rows).parallel().mapToObj( @@ -109,6 +207,9 @@ public final class DoubleMatrix implements MLMatrix { return result; } + /** + * {@inheritDoc} + */ @Override public MLMatrix multiplyAddBias( final MLMatrix B, final MLMatrix C ) { /*return new DoubleMatrix(IntStream.range(0, rows).parallel().mapToObj( @@ -131,6 +232,9 @@ public final class DoubleMatrix implements MLMatrix { return result; } + /** + * {@inheritDoc} + */ @Override public MLMatrix transposedMultiplyAndScale( final MLMatrix B, final double scalar ) { /*return new DoubleMatrix(IntStream.range(0, columns).parallel().mapToObj( @@ -153,6 +257,9 @@ public final class DoubleMatrix implements MLMatrix { return result; } + /** + * {@inheritDoc} + */ @Override public MLMatrix add( MLMatrix B ) { /*return new DoubleMatrix(IntStream.range(0, rows).parallel().mapToObj( @@ -169,6 +276,9 @@ public final class DoubleMatrix implements MLMatrix { return sum; } + /** + * {@inheritDoc} + */ @Override public MLMatrix addInPlace( MLMatrix B ) { for( int j = 0; j < columns; j++ ) { @@ -179,6 +289,9 @@ public final class DoubleMatrix implements MLMatrix { return this; } + /** + * {@inheritDoc} + */ @Override public MLMatrix sub( MLMatrix B ) { /*return new DoubleMatrix(IntStream.range(0, rows).parallel().mapToObj( @@ -195,6 +308,9 @@ public final class DoubleMatrix implements MLMatrix { return diff; } + /** + * {@inheritDoc} + */ @Override public MLMatrix colSums() { /*DoubleMatrix colSums = new DoubleMatrix(1, columns); @@ -214,6 +330,9 @@ public final class DoubleMatrix implements MLMatrix { return colSums; } + /** + * {@inheritDoc} + */ @Override public MLMatrix scaleInPlace( final double scalar ) { for( int i = 0; i < coefficients.length; i++ ) { @@ -222,6 +341,9 @@ public final class DoubleMatrix implements MLMatrix { return this; } + /** + * {@inheritDoc} + */ @Override public MLMatrix scaleInPlace( final MLMatrix S ) { for( int j = 0; j < columns; j++ ) { @@ -232,6 +354,9 @@ public final class DoubleMatrix implements MLMatrix { return this; } + /** + * {@inheritDoc} + */ @Override public MLMatrix apply( DoubleUnaryOperator op ) { DoubleMatrix result = new DoubleMatrix(rows, columns); @@ -241,6 +366,9 @@ public final class DoubleMatrix implements MLMatrix { return result; } + /** + * {@inheritDoc} + */ @Override public MLMatrix applyInPlace( DoubleUnaryOperator op ) { for( int i = 0; i < coefficients.length; i++ ) { diff --git a/src/main/java/schule/ngb/zm/ml/MLMatrix.java b/src/main/java/schule/ngb/zm/ml/MLMatrix.java index bab792c..8de5e76 100644 --- a/src/main/java/schule/ngb/zm/ml/MLMatrix.java +++ b/src/main/java/schule/ngb/zm/ml/MLMatrix.java @@ -2,142 +2,312 @@ package schule.ngb.zm.ml; import java.util.function.DoubleUnaryOperator; +/** + * Interface für Matrizen, die in {@link NeuralNetwork} Klassen verwendet + * werden. + *

+ * Eine implementierende Klasse muss generell zwei Konstruktoren bereitstellen: + *

    + *
  1. {@code MLMatrix(int rows, int columns)} erstellt eine Matrix mit den + * angegebenen Dimensionen und setzt alle Koeffizienten auf 0. + *
  2. {@code MLMatrix(double[][] coefficients} erstellt eine Matrix mit der + * durch das Array gegebenen Dimensionen und setzt die Werte auf die + * jeweiligen Werte des Arrays. + *
+ *

+ * Das Interface ist nicht dazu gedacht eine allgemeine Umsetzung für + * Matrizen-Algebra abzubilden, sondern soll gezielt die im Neuralen Netzwerk + * verwendeten Algorithmen umsetzen. Einerseits würde eine ganz allgemeine + * Matrizen-Klasse nicht im Rahmen der Zeichenmaschine liegen und auf der + * anderen Seite bietet eine Konzentration auf die verwendeten Algorithmen mehr + * Spielraum zur Optimierung. + *

+ * Intern wird das Interface von {@link DoubleMatrix} implementiert. Die Klasse + * ist eine weitestgehend naive Implementierung der Algorithmen mit kleineren + * Optimierungen. Die Verwendung eines generalisierten Interfaces erlaubt aber + * zukünftig die optionale Integration spezialisierterer Algebra-Bibliotheken + * wie + * Colt, um auch große + * Netze effizient berechnen zu können. + */ public interface MLMatrix { + /** + * Die Anzahl der Spalten der Matrix. + * + * @return Spaltenzahl. + */ int columns(); + /** + * Die Anzahl der Zeilen der Matrix. + * + * @return Zeilenzahl. + */ int rows(); - double[][] coefficients(); + /** + * Gibt den Wert an der angegebenen Stelle der Matrix zurück. + * + * @param row Die Spaltennummer zwischen 0 und {@code rows()-1}. + * @param col Die Zeilennummer zwischen 0 und {@code columns()-1} + * @return Den Koeffizienten in der Zeile {@code row} und der Spalte + * {@code col}. + * @throws IllegalArgumentException Falls {@code row >= rows()} oder + * {@code col >= columns()}. + */ + double get( int row, int col ) throws IllegalArgumentException; - double get( int row, int col ); - - MLMatrix set( int row, int col, double value ); + /** + * Setzt den Wert an der angegebenen Stelle der Matrix. + * + * @param row Die Spaltennummer zwischen 0 und {@code rows()-1}. + * @param col Die Zeilennummer zwischen 0 und {@code columns()-1} + * @param value Der neue Wert. + * @return Diese Matrix selbst (method chaining). + * @throws IllegalArgumentException Falls {@code row >= rows()} oder + * {@code col >= columns()}. + */ + MLMatrix set( int row, int col, double value ) throws IllegalArgumentException; + /** + * Setzt jeden Wert in der Matrix auf eine Zufallszahl zwischen -1 und 1. + *

+ * Nach Möglichkeit sollte der + * {@link schule.ngb.zm.Constants#random(int, int) Zufallsgenerator der + * Zeichenmaschine} verwendet werden. + * + * @return Diese Matrix selbst (method chaining). + */ MLMatrix initializeRandom(); + /** + * Setzt jeden Wert in der Matrix auf eine Zufallszahl innerhalb der + * angegebenen Grenzen. + *

+ * Nach Möglichkeit sollte der + * {@link schule.ngb.zm.Constants#random(int, int) Zufallsgenerator der + * Zeichenmaschine} verwendet werden. + * + * @param lower Untere Grenze der Zufallszahlen. + * @param upper Obere Grenze der Zufallszahlen. + * @return Diese Matrix selbst (method chaining). + */ MLMatrix initializeRandom( double lower, double upper ); + /** + * Setzt alle Werte der Matrix auf 1. + * + * @return Diese Matrix selbst (method chaining). + */ MLMatrix initializeOne(); + /** + * Setzt alle Werte der Matrix auf 0. + * + * @return Diese Matrix selbst (method chaining). + */ MLMatrix initializeZero(); - //MLMatrix transpose(); - - //MLMatrix multiply( MLMatrix B ); + /** + * Erzeugt eine neue Matrix {@code C} mit dem Ergebnis der Matrixoperation + *

+	 * C = this . B + V'
+	 * 
+ * wobei {@code this} dieses Matrixobjekt ist und {@code .} für die + * Matrixmultiplikation steht. {@vode V'} ist die Matrix {@code V} + * {@code rows()}-mal untereinander wiederholt. + *

+ * Wenn diese Matrix die Dimension r x c hat, dann muss die Matrix {@code B} + * die Dimension c x m haben und {@code V} eine 1 x m Matrix sein. Die + * Matrix {@code V'} hat also die Dimension r x m, ebenso wie das Ergebnis + * der Operation. + * + * @param B Eine {@code columns()} x m Matrix mit der Multipliziert wird. + * @param V Eine 1 x {@code B.columns()} Matrix mit den Bias-Werten. + * @return Eine {@code rows()} x m Matrix. + * @throws IllegalArgumentException Falls die Dimensionen der Matrizen nicht + * zur Operation passen. Also + * {@code this.columns() != B.rows()} oder + * {@code B.columns() != V.columns()} oder + * {@code V.rows() != 1}. + */ + MLMatrix multiplyAddBias( MLMatrix B, MLMatrix V ) throws IllegalArgumentException; /** - * Erzeugt eine neue Matrix C mit dem Ergebnis der Matrixoperation + * Erzeugt eine neue Matrix {@code C} mit dem Ergebnis der Matrixoperation *

-	 * C = A.B + V
+	 * C = this . t(B)
 	 * 
- * wobei A dieses Matrixobjekt ist und {@code .} für die + * wobei {@code this} dieses Matrixobjekt ist, {@code t(B)} die + * Transposition der Matrix {@code B} ist und {@code .} für die * Matrixmultiplikation steht. + *

+ * Wenn diese Matrix die Dimension r x c hat, dann muss die Matrix {@code B} + * die Dimension m x c haben und das Ergebnis ist eine r x m Matrix. * - * @param B - * @param V - * @return + * @param B Eine m x {@code columns()} Matrix. + * @return Eine {@code rows()} x m Matrix. + * @throws IllegalArgumentException Falls die Dimensionen der Matrizen nicht + * zur Operation passen. Also + * {@code this.columns() != B.columns()}. */ - MLMatrix multiplyAddBias( MLMatrix B, MLMatrix V ); + MLMatrix multiplyTransposed( MLMatrix B ) throws IllegalArgumentException; /** - * Erzeugt eine neue Matrix C mit dem Ergebnis der Matrixoperation + * Erzeugt eine neue Matrix {@code C} mit dem Ergebnis der Matrixoperation *

-	 * C = A.t(B)
+	 * C = t(this) . B * scalar
 	 * 
- * wobei A dieses Matrixobjekt ist und {@code t(B)} für die - * Transposition der Matrix B> steht. + * wobei {@code this} dieses Matrixobjekt ist, {@code t(this)} die + * Transposition dieser Matrix ist und {@code .} für die + * Matrixmultiplikation steht. {@code *} bezeichnet die + * Skalarmultiplikation, bei der jeder Wert der Matrix mit {@code scalar} + * multipliziert wird. + *

+ * Wenn diese Matrix die Dimension r x c hat, dann muss die Matrix {@code B} + * die Dimension r x m haben und das Ergebnis ist eine c x m Matrix. * - * @param B - * @return + * @param B Eine m x {@code columns()} Matrix. + * @return Eine {@code rows()} x m Matrix. + * @throws IllegalArgumentException Falls die Dimensionen der Matrizen nicht + * zur Operation passen. Also + * {@code this.rows() != B.rows()}. */ - MLMatrix multiplyTransposed( MLMatrix B ); - - MLMatrix transposedMultiplyAndScale( MLMatrix B, double scalar ); + MLMatrix transposedMultiplyAndScale( MLMatrix B, double scalar ) throws IllegalArgumentException; /** - * Erzeugt eine neue Matrix C mit dem Ergebnis der - * komponentenweisen Matrix-Addition + * Erzeugt eine neue Matrix {@code C} mit dem Ergebnis der komponentenweisen + * Matrix-Addition *

-	 * C = A+B
+	 * C = this + B
 	 * 
- * wobei A dieses Matrixobjekt ist. Für ein Element - * C_ij in C gilt + * wobei {@code this} dieses Matrixobjekt ist. Für ein Element {@code C_ij} + * in {@code C} gilt *
 	 * C_ij = A_ij + B_ij
 	 * 
+ *

+ * Die Matrix {@code B} muss dieselbe Dimension wie diese Matrix haben. * - * @param B Die zweite Matrix. - * @return Ein neues Matrixobjekt mit dem Ergebnis. + * @param B Eine {@code rows()} x {@code columns()} Matrix. + * @return Eine {@code rows()} x {@code columns()} Matrix. + * @throws IllegalArgumentException Falls die Dimensionen der Matrizen nicht + * zur Operation passen. Also + * {@code this.rows() != B.rows()} oder + * {@code this.columns() != B.columns()}. */ - MLMatrix add( MLMatrix B ); + MLMatrix add( MLMatrix B ) throws IllegalArgumentException; /** - * Setzt dies Matrix auf das Ergebnis der - * komponentenweisen Matrix-Addition + * Setzt diese Matrix auf das Ergebnis der komponentenweisen + * Matrix-Addition *

-	 * A = A+B
+	 * A' = A + B
 	 * 
- * wobei A dieses Matrixobjekt ist. Für ein Element - * A_ij in A gilt + * wobei {@code A} dieses Matrixobjekt ist und {@code A'} diese Matrix nach + * der Operation. Für ein Element {@code A'_ij} in {@code A'} gilt *
-	 * A_ij = A_ij + B_ij
+	 * A'_ij = A_ij + B_ij
 	 * 
+ *

+ * Die Matrix {@code B} muss dieselbe Dimension wie diese Matrix haben. * - * @param B Die zweite Matrix. - * @return Diese Matrix selbst (method chaining). + * @param B Eine {@code rows()} x {@code columns()} Matrix. + * @return Eine {@code rows()} x {@code columns()} Matrix. + * @throws IllegalArgumentException Falls die Dimensionen der Matrizen nicht + * zur Operation passen. Also + * {@code this.rows() != B.rows()} oder + * {@code this.columns() != B.columns()}. */ - MLMatrix addInPlace( MLMatrix B ); + MLMatrix addInPlace( MLMatrix B ) throws IllegalArgumentException; /** - * Erzeugt eine neue Matrix C mit dem Ergebnis der - * komponentenweisen Matrix-Subtraktion + * Erzeugt eine neue Matrix {@code C} mit dem Ergebnis der komponentenweisen + * Matrix-Subtraktion *

-	 * C = A-B
+	 * C = A - B
 	 * 
- * wobei A dieses Matrixobjekt ist. Für ein Element - * C_ij in C gilt + * wobei {@code A} dieses Matrixobjekt ist. Für ein Element {@code C_ij} in + * {@code C} gilt *
 	 * C_ij = A_ij - B_ij
 	 * 
+ *

+ * Die Matrix {@code B} muss dieselbe Dimension wie diese Matrix haben. * - * @param B - * @return + * @param B Eine {@code rows()} x {@code columns()} Matrix. + * @return Eine {@code rows()} x {@code columns()} Matrix. + * @throws IllegalArgumentException Falls die Dimensionen der Matrizen nicht + * zur Operation passen. Also + * {@code this.rows() != B.rows()} oder + * {@code this.columns() != B.columns()}. */ - MLMatrix sub( MLMatrix B ); + MLMatrix sub( MLMatrix B ) throws IllegalArgumentException; + /** + * Multipliziert jeden Wert dieser Matrix mit dem angegebenen Skalar. + *

+ * Ist {@code A} dieses Matrixobjekt und {@code A'} diese Matrix nach der + * Operation, dann gilt für ein Element {@code A'_ij} in {@code A'} + *

+	 * A'_ij = A_ij * scalar
+	 * 
+ * + * @param scalar Ein Skalar. + * @return Diese Matrix selbst (method chaining) + */ MLMatrix scaleInPlace( double scalar ); - MLMatrix scaleInPlace( MLMatrix S ); + /** + * Multipliziert jeden Wert dieser Matrix mit dem entsprechenden Wert in der + * Matrix {@code S}. + *

+ * Ist {@code A} dieses Matrixobjekt und {@code A'} diese Matrix nach der + * Operation, dann gilt für ein Element {@code A'_ij} in {@code A'} + *

+	 * A'_ij = A_ij * S_ij
+	 * 
+ * + * @param S Eine {@code rows()} x {@code columns()} Matrix. + * @return Diese Matrix selbst (method chaining) + * @throws IllegalArgumentException Falls die Dimensionen der Matrizen nicht + * zur Operation passen. Also + * {@code this.rows() != B.rows()} oder + * {@code this.columns() != B.columns()}. + */ + MLMatrix scaleInPlace( MLMatrix S ) throws IllegalArgumentException; /** * Berechnet eine neue Matrix mit nur einer Zeile, die die Spaltensummen * dieser Matrix enthalten. - * @return + * + * @return Eine 1 x {@code columns()} Matrix. */ MLMatrix colSums(); /** - * Endet die gegebene Funktion auf jeden Wert der Matrix an. + * Erzeugt eine neue Matrix, deren Werte gleich den Werten dieser Matrix + * nach der Anwendung der angegebenen Funktion sind. * - * @param op - * @return + * @param op Eine Operation {@code (double) -> double}. + * @return Eine {@code rows()} x {@code columns()} Matrix. */ MLMatrix apply( DoubleUnaryOperator op ); /** * Endet die gegebene Funktion auf jeden Wert der Matrix an. * - * @param op - * @return + * @param op Eine Operation {@code (double) -> double}. + * @return Diese Matrix selbst (method chaining) */ MLMatrix applyInPlace( DoubleUnaryOperator op ); /** - * Erzeugt eine neue Matrix mit denselben Dimenstionen und Koeffizienten wie + * Erzeugt eine neue Matrix mit denselben Dimensionen und Koeffizienten wie * diese Matrix. * - * @return + * @return Eine Kopie dieser Matrix. */ MLMatrix duplicate(); diff --git a/src/main/java/schule/ngb/zm/ml/MatrixFactory.java b/src/main/java/schule/ngb/zm/ml/MatrixFactory.java index 6a4d6d4..889380a 100644 --- a/src/main/java/schule/ngb/zm/ml/MatrixFactory.java +++ b/src/main/java/schule/ngb/zm/ml/MatrixFactory.java @@ -6,14 +6,27 @@ import schule.ngb.zm.util.Log; import java.util.function.DoubleUnaryOperator; +/** + * Zentrale Klasse zur Erstellung neuer Matrizen. Generell sollten neue Matrizen + * nicht direkt erstellt werden, sondern durch den Aufruf von + * {@link #create(int, int)} oder {@link #create(double[][])}. Die Fabrik + * ermittelt automatisch die beste verfügbare Implementierung und initialisiert + * eine entsprechende Implementierung von {@link MLMatrix}. + *

+ * Derzeit werden die optionale Bibliothek Colt und die interne + * Implementierung {@link DoubleMatrix} unterstützt. + */ public class MatrixFactory { - public static void main( String[] args ) { - System.out.println( - MatrixFactory.create(new double[][]{{1.0, 0.0}, {0.0, 1.0}}).toString() - ); - } - + /** + * Erstellt eine neue Matrix mit den angegebenen Dimensionen und + * initialisiert alle Werte mit 0. + * + * @param rows Anzahl der Zeilen. + * @param cols Anzahl der Spalten. + * @return Eine {@code rows} x {@code cols} Matrix. + */ public static final MLMatrix create( int rows, int cols ) { try { return getMatrixType().getDeclaredConstructor(int.class, int.class).newInstance(rows, cols); @@ -23,6 +36,14 @@ public class MatrixFactory { return new DoubleMatrix(rows, cols); } + /** + * Erstellt eine neue Matrix mit den Dimensionen des angegebenen Arrays und + * initialisiert die Werte mit den entsprechenden Werten des Arrays. + * + * @param values Die Werte der Matrix. + * @return Eine {@code values.length} x {@code values[0].length} Matrix mit + * den Werten des Arrays. + */ public static final MLMatrix create( double[][] values ) { try { return getMatrixType().getDeclaredConstructor(double[][].class).newInstance((Object) values); @@ -32,8 +53,17 @@ public class MatrixFactory { return new DoubleMatrix(values); } + /** + * Die verwendete {@link MLMatrix} Implementierung, aus der Matrizen erzeugt + * werden. + */ static Class matrixType = null; + /** + * Ermittelt die beste verfügbare Implementierung von {@link MLMatrix}. + * + * @return Die verwendete {@link MLMatrix} Implementierung. + */ private static final Class getMatrixType() { if( matrixType == null ) { try { @@ -50,6 +80,10 @@ public class MatrixFactory { private static final Log LOG = Log.getLogger(MatrixFactory.class); + /** + * Interner Wrapper der DoubleMatrix2D Klasse aus der Colt Bibliothek, um + * das {@link MLMatrix} Interface zu implementieren. + */ static class ColtMatrix implements MLMatrix { cern.colt.matrix.DoubleMatrix2D matrix; @@ -87,11 +121,6 @@ public class MatrixFactory { return this; } - @Override - public double[][] coefficients() { - return this.matrix.toArray(); - } - @Override public MLMatrix initializeRandom() { return initializeRandom(-1.0, 1.0); diff --git a/src/main/java/schule/ngb/zm/ml/NeuralNetwork.java b/src/main/java/schule/ngb/zm/ml/NeuralNetwork.java index 2eae11b..b1cd4d6 100644 --- a/src/main/java/schule/ngb/zm/ml/NeuralNetwork.java +++ b/src/main/java/schule/ngb/zm/ml/NeuralNetwork.java @@ -15,7 +15,7 @@ public class NeuralNetwork { Writer writer = ResourceStreamProvider.getWriter(source); PrintWriter out = new PrintWriter(writer) ) { - for( NeuronLayer layer: network.layers ) { + for( NeuronLayer layer : network.layers ) { out.print(layer.getNeuronCount()); out.print(' '); out.print(layer.getInputCount()); @@ -23,20 +23,44 @@ public class NeuralNetwork { for( int i = 0; i < layer.getInputCount(); i++ ) { for( int j = 0; j < layer.getNeuronCount(); j++ ) { - //out.print(layer.weights.coefficients[i][j]); + out.print(layer.weights.get(i, j)); out.print(' '); } out.println(); } for( int j = 0; j < layer.getNeuronCount(); j++ ) { - //out.print(layer.biases[j]); + out.print(layer.biases.get(0, j)); out.print(' '); } out.println(); } out.flush(); } catch( IOException ex ) { - LOG.warn(ex, ""); + LOG.error(ex, ""); + } + } + + public static void saveToDataFile( String source, NeuralNetwork network ) { + try( + OutputStream stream = ResourceStreamProvider.getOutputStream(source); + DataOutputStream out = new DataOutputStream(stream) + ) { + for( NeuronLayer layer : network.layers ) { + out.writeInt(layer.getNeuronCount()); + out.writeInt(layer.getInputCount()); + + for( int i = 0; i < layer.getInputCount(); i++ ) { + for( int j = 0; j < layer.getNeuronCount(); j++ ) { + out.writeDouble(layer.weights.get(i, j)); + } + } + for( int j = 0; j < layer.getNeuronCount(); j++ ) { + out.writeDouble(layer.biases.get(0, j)); + } + } + out.flush(); + } catch( IOException ex ) { + LOG.error(ex, ""); } } @@ -56,13 +80,13 @@ public class NeuralNetwork { for( int i = 0; i < inputs; i++ ) { split = in.readLine().split(" "); for( int j = 0; j < neurons; j++ ) { - //layer.weights.coefficients[i][j] = Double.parseDouble(split[j]); + layer.weights.set(i, j, Double.parseDouble(split[j])); } } // Load Biases split = in.readLine().split(" "); for( int j = 0; j < neurons; j++ ) { - //layer.biases[j] = Double.parseDouble(split[j]); + layer.biases.set(0, j, Double.parseDouble(split[j])); } layers.add(layer); @@ -70,29 +94,30 @@ public class NeuralNetwork { return new NeuralNetwork(layers); } catch( IOException | NoSuchElementException ex ) { - LOG.warn(ex, "Could not load neural network from source <%s>", source); + LOG.error(ex, "Could not load neural network from source <%s>", source); } return null; } - /*public static NeuralNetwork loadFromFile( String source ) { + public static NeuralNetwork loadFromDataFile( String source ) { try( InputStream stream = ResourceStreamProvider.getInputStream(source); - Scanner in = new Scanner(stream) + DataInputStream in = new DataInputStream(stream) ) { List layers = new LinkedList<>(); - while( in.hasNext() ) { - int neurons = in.nextInt(); - int inputs = in.nextInt(); + while( in.available() > 0 ) { + int neurons = in.readInt(); + int inputs = in.readInt(); NeuronLayer layer = new NeuronLayer(neurons, inputs); for( int i = 0; i < inputs; i++ ) { for( int j = 0; j < neurons; j++ ) { - layer.weights.coefficients[i][j] = in.nextDouble(); + layer.weights.set(i, j, in.readDouble()); } } + // Load Biases for( int j = 0; j < neurons; j++ ) { - layer.biases[j] = in.nextDouble(); + layer.biases.set(0, j, in.readDouble()); } layers.add(layer); @@ -100,10 +125,10 @@ public class NeuralNetwork { return new NeuralNetwork(layers); } catch( IOException | NoSuchElementException ex ) { - LOG.warn(ex, "Could not load neural network from source <%s>", source); + LOG.error(ex, "Could not load neural network from source <%s>", source); } return null; - }*/ + } private NeuronLayer[] layers; @@ -128,7 +153,7 @@ public class NeuralNetwork { for( int i = 0; i < layers.size(); i++ ) { this.layers[i] = layers.get(i); if( i > 0 ) { - this.layers[i-1].setNextLayer(this.layers[i]); + this.layers[i - 1].setNextLayer(this.layers[i]); } } } @@ -138,7 +163,7 @@ public class NeuralNetwork { for( int i = 0; i < layers.length; i++ ) { this.layers[i] = layers[i]; if( i > 0 ) { - this.layers[i-1].setNextLayer(this.layers[i]); + this.layers[i - 1].setNextLayer(this.layers[i]); } } } @@ -146,6 +171,7 @@ public class NeuralNetwork { public int getLayerCount() { return layers.length; } + public NeuronLayer[] getLayers() { return layers; } diff --git a/src/main/java/schule/ngb/zm/ml/NeuronLayer.java b/src/main/java/schule/ngb/zm/ml/NeuronLayer.java index 26a76eb..4ce061d 100644 --- a/src/main/java/schule/ngb/zm/ml/NeuronLayer.java +++ b/src/main/java/schule/ngb/zm/ml/NeuronLayer.java @@ -3,30 +3,45 @@ package schule.ngb.zm.ml; import java.util.function.DoubleUnaryOperator; import java.util.function.Function; +/** + * Implementierung einer Neuronenebene in einem Neuonalen Netz. + *

+ * Eine Ebene besteht aus einer Anzahl an Neuronen die jeweils eine + * Anzahl Eingänge haben. Die Eingänge erhalten als Signal die Ausgabe + * der vorherigen Ebene und berechnen die Ausgabe des jeweiligen Neurons. + */ public class NeuronLayer implements Function { - /*public static NeuronLayer fromArray( double[][] weights ) { - NeuronLayer layer = new NeuronLayer(weights[0].length, weights.length); + public static NeuronLayer fromArray( double[][] weights, boolean transpose ) { + NeuronLayer layer; + if( transpose ) { + layer = new NeuronLayer(weights.length, weights[0].length); + } else { + layer = new NeuronLayer(weights[0].length, weights.length); + } + for( int i = 0; i < weights[0].length; i++ ) { for( int j = 0; j < weights.length; j++ ) { - layer.weights.coefficients[i][j] = weights[i][j]; + if( transpose ) { + layer.weights.set(j, i, weights[i][j]); + } else { + layer.weights.set(i, j, weights[i][j]); + } } } + return layer; } - public static NeuronLayer fromArray( double[][] weights, double[] biases ) { - NeuronLayer layer = new NeuronLayer(weights[0].length, weights.length); - for( int i = 0; i < weights[0].length; i++ ) { - for( int j = 0; j < weights.length; j++ ) { - layer.weights.coefficients[i][j] = weights[i][j]; - } - } + public static NeuronLayer fromArray( double[][] weights, double[] biases, boolean transpose ) { + NeuronLayer layer = fromArray(weights, transpose); + for( int j = 0; j < biases.length; j++ ) { - layer.biases[j] = biases[j]; + layer.biases.set(0, j, biases[j]); } + return layer; - }*/ + } MLMatrix weights; @@ -112,7 +127,7 @@ public class NeuronLayer implements Function { @Override public String toString() { - return weights.toString() + "\n" + biases.toString(); + return "weights:\n" + weights.toString() + "\nbiases:\n" + biases.toString(); } @Override diff --git a/src/test/java/schule/ngb/zm/ml/MLMatrixTest.java b/src/test/java/schule/ngb/zm/ml/MLMatrixTest.java index 2c153f6..1e0a1da 100644 --- a/src/test/java/schule/ngb/zm/ml/MLMatrixTest.java +++ b/src/test/java/schule/ngb/zm/ml/MLMatrixTest.java @@ -1,9 +1,11 @@ package schule.ngb.zm.ml; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInfo; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; +import schule.ngb.zm.util.Timer; import static org.junit.jupiter.api.Assertions.*; @@ -389,4 +391,36 @@ class MLMatrixTest { return String.format("[" + testName + "(" + className + ") " + methodName + "()] " + msg, args); } + //@ParameterizedTest + //@ValueSource( classes = {MatrixFactory.ColtMatrix.class, DoubleMatrix.class} ) + void speed( Class mType ) { + MatrixFactory.matrixType = mType; + + int N = 10; + int rows = 1000; + int cols = 1000; + + Timer timer = new Timer(); + + MLMatrix M = MatrixFactory.create(rows, cols); + timer.start(); + for( int i = 0; i < N; i++ ) { + M.initializeRandom(); + } + timer.stop(); + System.err.println(msg("%d iterations: %d ms", "initializeRandom", N, timer.getMillis())); + + timer.reset(); + + MLMatrix B = MatrixFactory.create(rows*2, M.columns()); + B.initializeRandom(); + + timer.start(); + for( int i = 0; i < N; i++ ) { + M.multiplyTransposed(B); + } + timer.stop(); + System.err.println(msg("%d iterations: %d ms", "multiplyTransposed", N, timer.getMillis())); + } + }