mirror of
https://github.com/jneug/zeichenmaschine.git
synced 2026-04-14 06:33:34 +02:00
Merge branch 'optional-ml'
This commit is contained in:
@@ -28,8 +28,13 @@ dependencies {
|
|||||||
runtimeOnly 'com.googlecode.soundlibs:tritonus-share:0.3.7.4'
|
runtimeOnly 'com.googlecode.soundlibs:tritonus-share:0.3.7.4'
|
||||||
runtimeOnly 'com.googlecode.soundlibs:mp3spi:1.9.5.4'
|
runtimeOnly 'com.googlecode.soundlibs:mp3spi:1.9.5.4'
|
||||||
|
|
||||||
|
//compileOnlyApi 'colt:colt:1.2.0'
|
||||||
|
api 'colt:colt:1.2.0'
|
||||||
|
//api 'net.sourceforge.parallelcolt:parallelcolt:0.10.1'
|
||||||
|
|
||||||
testImplementation 'org.junit.jupiter:junit-jupiter-api:5.8.1'
|
testImplementation 'org.junit.jupiter:junit-jupiter-api:5.8.1'
|
||||||
testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine:5.8.1'
|
testImplementation 'org.junit.jupiter:junit-jupiter-params:5.8.1'
|
||||||
|
testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine:5.8.1'
|
||||||
}
|
}
|
||||||
|
|
||||||
test {
|
test {
|
||||||
|
|||||||
@@ -1269,7 +1269,8 @@ public class Constants {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Erzeugt eine Pseudozufallszahl nach einer Gaussverteilung.
|
* Erzeugt eine Pseudozufallszahl zwischen -1 und 1 nach einer
|
||||||
|
* Normalverteilung mit Mittelwert 0 und Standardabweichung 1.
|
||||||
*
|
*
|
||||||
* @return Eine Zufallszahl.
|
* @return Eine Zufallszahl.
|
||||||
* @see Random#nextGaussian()
|
* @see Random#nextGaussian()
|
||||||
|
|||||||
399
src/main/java/schule/ngb/zm/ml/DoubleMatrix.java
Normal file
399
src/main/java/schule/ngb/zm/ml/DoubleMatrix.java
Normal file
@@ -0,0 +1,399 @@
|
|||||||
|
package schule.ngb.zm.ml;
|
||||||
|
|
||||||
|
import schule.ngb.zm.Constants;
|
||||||
|
|
||||||
|
import java.util.function.DoubleUnaryOperator;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Eine einfache Implementierung der {@link MLMatrix} zur Verwendung in
|
||||||
|
* {@link NeuralNetwork}s.
|
||||||
|
* <p>
|
||||||
|
* Diese Klasse stellt die interne Implementierung der Matrixoperationen dar,
|
||||||
|
* die zur Berechnung der Gewichte in einem {@link NeuronLayer} notwendig sind.
|
||||||
|
* <p>
|
||||||
|
* Die Klasse ist nur minimal optimiert und sollte nur für kleine Netze
|
||||||
|
* verwendet werden. Für größere Netze sollte auf eine der optionalen
|
||||||
|
* Bibliotheken wie
|
||||||
|
* <a href="">Colt</a> zurückgegriffen werden.
|
||||||
|
*/
|
||||||
|
public final class DoubleMatrix implements MLMatrix {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Anzahl Zeilen der Matrix.
|
||||||
|
*/
|
||||||
|
private int rows;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Anzahl Spalten der Matrix.
|
||||||
|
*/
|
||||||
|
private int columns;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Die Koeffizienten der Matrix.
|
||||||
|
* <p>
|
||||||
|
* Um den Overhead bei Speicher und Zugriffszeiten von zweidimensionalen
|
||||||
|
* Arrays zu vermeiden wird ein eindimensionales Array verwendet und die
|
||||||
|
* Indizes mit Spaltenpriorität berechnet. Der Index i des Koeffizienten
|
||||||
|
* {@code r,c} in Zeile {@code r} und Spalte {@code c} wird bestimmt durch
|
||||||
|
* <pre>
|
||||||
|
* i = c * rows + r
|
||||||
|
* </pre>
|
||||||
|
* <p>
|
||||||
|
* Die Werte einer Spalte liegen also hintereinander im Array. Dies sollte
|
||||||
|
* einen leichten Vorteil bei der {@link #colSums() Spaltensummen} geben.
|
||||||
|
* Generell sollte eine Iteration über die Matrix der Form
|
||||||
|
* <pre><code>
|
||||||
|
* for( int j = 0; j < columns; j++ ) {
|
||||||
|
* for( int i = 0; i < rows; i++ ) {
|
||||||
|
* // ...
|
||||||
|
* }
|
||||||
|
* }
|
||||||
|
* </code></pre>
|
||||||
|
* etwas schneller sein als
|
||||||
|
* <pre><code>
|
||||||
|
* for( int i = 0; i < rows; i++ ) {
|
||||||
|
* for( int j = 0; j < columns; j++ ) {
|
||||||
|
* // ...
|
||||||
|
* }
|
||||||
|
* }
|
||||||
|
* </code></pre>
|
||||||
|
*/
|
||||||
|
double[] coefficients;
|
||||||
|
|
||||||
|
public DoubleMatrix( int rows, int cols ) {
|
||||||
|
this.rows = rows;
|
||||||
|
this.columns = cols;
|
||||||
|
coefficients = new double[rows * cols];
|
||||||
|
}
|
||||||
|
|
||||||
|
public DoubleMatrix( double[][] coefficients ) {
|
||||||
|
this.rows = coefficients.length;
|
||||||
|
this.columns = coefficients[0].length;
|
||||||
|
this.coefficients = new double[rows * columns];
|
||||||
|
for( int j = 0; j < columns; j++ ) {
|
||||||
|
for( int i = 0; i < rows; i++ ) {
|
||||||
|
this.coefficients[idx(i, j)] = coefficients[i][j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Initialisiert diese Matrix als Kopie der angegebenen Matrix.
|
||||||
|
*
|
||||||
|
* @param other Die zu kopierende Matrix.
|
||||||
|
*/
|
||||||
|
public DoubleMatrix( DoubleMatrix other ) {
|
||||||
|
this.rows = other.rows();
|
||||||
|
this.columns = other.columns();
|
||||||
|
this.coefficients = new double[rows * columns];
|
||||||
|
System.arraycopy(
|
||||||
|
other.coefficients, 0,
|
||||||
|
this.coefficients, 0,
|
||||||
|
rows * columns);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@inheritDoc}
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public int columns() {
|
||||||
|
return columns;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@inheritDoc}
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public int rows() {
|
||||||
|
return rows;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@inheritDoc}
|
||||||
|
*/
|
||||||
|
int idx( int r, int c ) {
|
||||||
|
return c * rows + r;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@inheritDoc}
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public double get( int row, int col ) {
|
||||||
|
try {
|
||||||
|
return coefficients[idx(row, col)];
|
||||||
|
} catch( ArrayIndexOutOfBoundsException ex ) {
|
||||||
|
throw new IllegalArgumentException("No element at row=" + row + ", column=" + col, ex);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@inheritDoc}
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public MLMatrix set( int row, int col, double value ) {
|
||||||
|
try {
|
||||||
|
coefficients[idx(row, col)] = value;
|
||||||
|
} catch( ArrayIndexOutOfBoundsException ex ) {
|
||||||
|
throw new IllegalArgumentException("No element at row=" + row + ", column=" + col, ex);
|
||||||
|
}
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@inheritDoc}
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public MLMatrix initializeRandom() {
|
||||||
|
return initializeRandom(-1.0, 1.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@inheritDoc}
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public MLMatrix initializeRandom( double lower, double upper ) {
|
||||||
|
applyInPlace(( d ) -> ((upper - lower) * Constants.random()) + lower);
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@inheritDoc}
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public MLMatrix initializeOne() {
|
||||||
|
applyInPlace(( d ) -> 1.0);
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@inheritDoc}
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public MLMatrix initializeZero() {
|
||||||
|
applyInPlace(( d ) -> 0.0);
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@inheritDoc}
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public MLMatrix duplicate() {
|
||||||
|
return new DoubleMatrix(this);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@inheritDoc}
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public MLMatrix multiplyTransposed( MLMatrix B ) {
|
||||||
|
/*return new DoubleMatrix(IntStream.range(0, rows).parallel().mapToObj(
|
||||||
|
( i ) -> IntStream.range(0, B.rows()).mapToDouble(
|
||||||
|
( j ) -> IntStream.range(0, columns).mapToDouble(
|
||||||
|
( k ) -> get(i, k) * B.get(j, k)
|
||||||
|
).sum()
|
||||||
|
).toArray()
|
||||||
|
).toArray(double[][]::new));*/
|
||||||
|
DoubleMatrix result = new DoubleMatrix(rows, B.rows());
|
||||||
|
for( int i = 0; i < rows; i++ ) {
|
||||||
|
for( int j = 0; j < B.rows(); j++ ) {
|
||||||
|
result.coefficients[result.idx(i, j)] = 0.0;
|
||||||
|
for( int k = 0; k < columns; k++ ) {
|
||||||
|
result.coefficients[result.idx(i, j)] += get(i, k) * B.get(j, k);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@inheritDoc}
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public MLMatrix multiplyAddBias( final MLMatrix B, final MLMatrix C ) {
|
||||||
|
/*return new DoubleMatrix(IntStream.range(0, rows).parallel().mapToObj(
|
||||||
|
( i ) -> IntStream.range(0, B.columns()).mapToDouble(
|
||||||
|
( j ) -> IntStream.range(0, columns).mapToDouble(
|
||||||
|
( k ) -> get(i, k) * B.get(k, j)
|
||||||
|
).sum() + C.get(0, j)
|
||||||
|
).toArray()
|
||||||
|
).toArray(double[][]::new));*/
|
||||||
|
DoubleMatrix result = new DoubleMatrix(rows, B.columns());
|
||||||
|
for( int i = 0; i < rows; i++ ) {
|
||||||
|
for( int j = 0; j < B.columns(); j++ ) {
|
||||||
|
result.coefficients[result.idx(i, j)] = 0.0;
|
||||||
|
for( int k = 0; k < columns; k++ ) {
|
||||||
|
result.coefficients[result.idx(i, j)] += get(i, k) * B.get(k, j);
|
||||||
|
}
|
||||||
|
result.coefficients[result.idx(i, j)] += C.get(0, j);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@inheritDoc}
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public MLMatrix transposedMultiplyAndScale( final MLMatrix B, final double scalar ) {
|
||||||
|
/*return new DoubleMatrix(IntStream.range(0, columns).parallel().mapToObj(
|
||||||
|
( i ) -> IntStream.range(0, B.columns()).mapToDouble(
|
||||||
|
( j ) -> IntStream.range(0, rows).mapToDouble(
|
||||||
|
( k ) -> get(k, i) * B.get(k, j) * scalar
|
||||||
|
).sum()
|
||||||
|
).toArray()
|
||||||
|
).toArray(double[][]::new));*/
|
||||||
|
DoubleMatrix result = new DoubleMatrix(columns, B.columns());
|
||||||
|
for( int i = 0; i < columns; i++ ) {
|
||||||
|
for( int j = 0; j < B.columns(); j++ ) {
|
||||||
|
result.coefficients[result.idx(i, j)] = 0.0;
|
||||||
|
for( int k = 0; k < rows; k++ ) {
|
||||||
|
result.coefficients[result.idx(i, j)] += get(k, i) * B.get(k, j);
|
||||||
|
}
|
||||||
|
result.coefficients[result.idx(i, j)] *= scalar;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@inheritDoc}
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public MLMatrix add( MLMatrix B ) {
|
||||||
|
/*return new DoubleMatrix(IntStream.range(0, rows).parallel().mapToObj(
|
||||||
|
( i ) -> IntStream.range(0, columns).mapToDouble(
|
||||||
|
( j ) -> get(i, j) + B.get(i, j)
|
||||||
|
).toArray()
|
||||||
|
).toArray(double[][]::new));*/
|
||||||
|
DoubleMatrix sum = new DoubleMatrix(rows, columns);
|
||||||
|
for( int j = 0; j < columns; j++ ) {
|
||||||
|
for( int i = 0; i < rows; i++ ) {
|
||||||
|
sum.coefficients[idx(i, j)] = coefficients[idx(i, j)] + B.get(i, j);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return sum;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@inheritDoc}
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public MLMatrix addInPlace( MLMatrix B ) {
|
||||||
|
for( int j = 0; j < columns; j++ ) {
|
||||||
|
for( int i = 0; i < rows; i++ ) {
|
||||||
|
coefficients[idx(i, j)] += B.get(i, j);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@inheritDoc}
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public MLMatrix sub( MLMatrix B ) {
|
||||||
|
/*return new DoubleMatrix(IntStream.range(0, rows).parallel().mapToObj(
|
||||||
|
( i ) -> IntStream.range(0, columns).mapToDouble(
|
||||||
|
( j ) -> get(i, j) - B.get(i, j)
|
||||||
|
).toArray()
|
||||||
|
).toArray(double[][]::new));*/
|
||||||
|
DoubleMatrix diff = new DoubleMatrix(rows, columns);
|
||||||
|
for( int j = 0; j < columns; j++ ) {
|
||||||
|
for( int i = 0; i < rows; i++ ) {
|
||||||
|
diff.coefficients[idx(i, j)] = coefficients[idx(i, j)] - B.get(i, j);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return diff;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@inheritDoc}
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public MLMatrix colSums() {
|
||||||
|
/*DoubleMatrix colSums = new DoubleMatrix(1, columns);
|
||||||
|
colSums.coefficients = IntStream.range(0, columns).parallel().mapToDouble(
|
||||||
|
( j ) -> IntStream.range(0, rows).mapToDouble(
|
||||||
|
( i ) -> get(i, j)
|
||||||
|
).sum()
|
||||||
|
).toArray();
|
||||||
|
return colSums;*/
|
||||||
|
DoubleMatrix colSums = new DoubleMatrix(1, columns);
|
||||||
|
for( int j = 0; j < columns; j++ ) {
|
||||||
|
colSums.coefficients[j] = 0.0;
|
||||||
|
for( int i = 0; i < rows; i++ ) {
|
||||||
|
colSums.coefficients[j] += coefficients[idx(i, j)];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return colSums;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@inheritDoc}
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public MLMatrix scaleInPlace( final double scalar ) {
|
||||||
|
for( int i = 0; i < coefficients.length; i++ ) {
|
||||||
|
coefficients[i] *= scalar;
|
||||||
|
}
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@inheritDoc}
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public MLMatrix scaleInPlace( final MLMatrix S ) {
|
||||||
|
for( int j = 0; j < columns; j++ ) {
|
||||||
|
for( int i = 0; i < rows; i++ ) {
|
||||||
|
coefficients[idx(i, j)] *= S.get(i, j);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@inheritDoc}
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public MLMatrix apply( DoubleUnaryOperator op ) {
|
||||||
|
DoubleMatrix result = new DoubleMatrix(rows, columns);
|
||||||
|
for( int i = 0; i < coefficients.length; i++ ) {
|
||||||
|
result.coefficients[i] = op.applyAsDouble(coefficients[i]);
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@inheritDoc}
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public MLMatrix applyInPlace( DoubleUnaryOperator op ) {
|
||||||
|
for( int i = 0; i < coefficients.length; i++ ) {
|
||||||
|
coefficients[i] = op.applyAsDouble(coefficients[i]);
|
||||||
|
}
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
StringBuilder sb = new StringBuilder();
|
||||||
|
sb.append(rows);
|
||||||
|
sb.append(" x ");
|
||||||
|
sb.append(columns);
|
||||||
|
sb.append(" Matrix");
|
||||||
|
sb.append('\n');
|
||||||
|
for( int i = 0; i < rows; i++ ) {
|
||||||
|
for( int j = 0; j < columns; j++ ) {
|
||||||
|
sb.append(get(i, j));
|
||||||
|
if( j < columns - 1 )
|
||||||
|
sb.append(' ');
|
||||||
|
}
|
||||||
|
sb.append('\n');
|
||||||
|
}
|
||||||
|
return sb.toString();
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
316
src/main/java/schule/ngb/zm/ml/MLMatrix.java
Normal file
316
src/main/java/schule/ngb/zm/ml/MLMatrix.java
Normal file
@@ -0,0 +1,316 @@
|
|||||||
|
package schule.ngb.zm.ml;
|
||||||
|
|
||||||
|
import java.util.function.DoubleUnaryOperator;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Interface für Matrizen, die in {@link NeuralNetwork} Klassen verwendet
|
||||||
|
* werden.
|
||||||
|
* <p>
|
||||||
|
* Eine implementierende Klasse muss generell zwei Konstruktoren bereitstellen:
|
||||||
|
* <ol>
|
||||||
|
* <li> {@code MLMatrix(int rows, int columns)} erstellt eine Matrix mit den
|
||||||
|
* angegebenen Dimensionen und setzt alle Koeffizienten auf 0.
|
||||||
|
* <li> {@code MLMatrix(double[][] coefficients} erstellt eine Matrix mit der
|
||||||
|
* durch das Array gegebenen Dimensionen und setzt die Werte auf die
|
||||||
|
* jeweiligen Werte des Arrays.
|
||||||
|
* </ol>
|
||||||
|
* <p>
|
||||||
|
* Das Interface ist nicht dazu gedacht eine allgemeine Umsetzung für
|
||||||
|
* Matrizen-Algebra abzubilden, sondern soll gezielt die im Neuralen Netzwerk
|
||||||
|
* verwendeten Algorithmen umsetzen. Einerseits würde eine ganz allgemeine
|
||||||
|
* Matrizen-Klasse nicht im Rahmen der Zeichenmaschine liegen und auf der
|
||||||
|
* anderen Seite bietet eine Konzentration auf die verwendeten Algorithmen mehr
|
||||||
|
* Spielraum zur Optimierung.
|
||||||
|
* <p>
|
||||||
|
* Intern wird das Interface von {@link DoubleMatrix} implementiert. Die Klasse
|
||||||
|
* ist eine weitestgehend naive Implementierung der Algorithmen mit kleineren
|
||||||
|
* Optimierungen. Die Verwendung eines generalisierten Interfaces erlaubt aber
|
||||||
|
* zukünftig die optionale Integration spezialisierterer Algebra-Bibliotheken
|
||||||
|
* wie
|
||||||
|
* <a href="https://dst.lbl.gov/ACSSoftware/colt/">Colt</a>, um auch große
|
||||||
|
* Netze effizient berechnen zu können.
|
||||||
|
*/
|
||||||
|
public interface MLMatrix {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Die Anzahl der Spalten der Matrix.
|
||||||
|
*
|
||||||
|
* @return Spaltenzahl.
|
||||||
|
*/
|
||||||
|
int columns();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Die Anzahl der Zeilen der Matrix.
|
||||||
|
*
|
||||||
|
* @return Zeilenzahl.
|
||||||
|
*/
|
||||||
|
int rows();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Gibt den Wert an der angegebenen Stelle der Matrix zurück.
|
||||||
|
*
|
||||||
|
* @param row Die Spaltennummer zwischen 0 und {@code rows()-1}.
|
||||||
|
* @param col Die Zeilennummer zwischen 0 und {@code columns()-1}
|
||||||
|
* @return Den Koeffizienten in der Zeile {@code row} und der Spalte
|
||||||
|
* {@code col}.
|
||||||
|
* @throws IllegalArgumentException Falls {@code row >= rows()} oder
|
||||||
|
* {@code col >= columns()}.
|
||||||
|
*/
|
||||||
|
double get( int row, int col ) throws IllegalArgumentException;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Setzt den Wert an der angegebenen Stelle der Matrix.
|
||||||
|
*
|
||||||
|
* @param row Die Spaltennummer zwischen 0 und {@code rows()-1}.
|
||||||
|
* @param col Die Zeilennummer zwischen 0 und {@code columns()-1}
|
||||||
|
* @param value Der neue Wert.
|
||||||
|
* @return Diese Matrix selbst (method chaining).
|
||||||
|
* @throws IllegalArgumentException Falls {@code row >= rows()} oder
|
||||||
|
* {@code col >= columns()}.
|
||||||
|
*/
|
||||||
|
MLMatrix set( int row, int col, double value ) throws IllegalArgumentException;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Setzt jeden Wert in der Matrix auf eine Zufallszahl zwischen -1 und 1.
|
||||||
|
* <p>
|
||||||
|
* Nach Möglichkeit sollte der
|
||||||
|
* {@link schule.ngb.zm.Constants#random(int, int) Zufallsgenerator der
|
||||||
|
* Zeichenmaschine} verwendet werden.
|
||||||
|
*
|
||||||
|
* @return Diese Matrix selbst (method chaining).
|
||||||
|
*/
|
||||||
|
MLMatrix initializeRandom();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Setzt jeden Wert in der Matrix auf eine Zufallszahl innerhalb der
|
||||||
|
* angegebenen Grenzen.
|
||||||
|
* <p>
|
||||||
|
* Nach Möglichkeit sollte der
|
||||||
|
* {@link schule.ngb.zm.Constants#random(int, int) Zufallsgenerator der
|
||||||
|
* Zeichenmaschine} verwendet werden.
|
||||||
|
*
|
||||||
|
* @param lower Untere Grenze der Zufallszahlen.
|
||||||
|
* @param upper Obere Grenze der Zufallszahlen.
|
||||||
|
* @return Diese Matrix selbst (method chaining).
|
||||||
|
*/
|
||||||
|
MLMatrix initializeRandom( double lower, double upper );
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Setzt alle Werte der Matrix auf 1.
|
||||||
|
*
|
||||||
|
* @return Diese Matrix selbst (method chaining).
|
||||||
|
*/
|
||||||
|
MLMatrix initializeOne();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Setzt alle Werte der Matrix auf 0.
|
||||||
|
*
|
||||||
|
* @return Diese Matrix selbst (method chaining).
|
||||||
|
*/
|
||||||
|
MLMatrix initializeZero();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Erzeugt eine neue Matrix {@code C} mit dem Ergebnis der Matrixoperation
|
||||||
|
* <pre>
|
||||||
|
* C = this . B + V'
|
||||||
|
* </pre>
|
||||||
|
* wobei {@code this} dieses Matrixobjekt ist und {@code .} für die
|
||||||
|
* Matrixmultiplikation steht. {@vode V'} ist die Matrix {@code V}
|
||||||
|
* {@code rows()}-mal untereinander wiederholt.
|
||||||
|
* <p>
|
||||||
|
* Wenn diese Matrix die Dimension r x c hat, dann muss die Matrix {@code B}
|
||||||
|
* die Dimension c x m haben und {@code V} eine 1 x m Matrix sein. Die
|
||||||
|
* Matrix {@code V'} hat also die Dimension r x m, ebenso wie das Ergebnis
|
||||||
|
* der Operation.
|
||||||
|
*
|
||||||
|
* @param B Eine {@code columns()} x m Matrix mit der Multipliziert wird.
|
||||||
|
* @param V Eine 1 x {@code B.columns()} Matrix mit den Bias-Werten.
|
||||||
|
* @return Eine {@code rows()} x m Matrix.
|
||||||
|
* @throws IllegalArgumentException Falls die Dimensionen der Matrizen nicht
|
||||||
|
* zur Operation passen. Also
|
||||||
|
* {@code this.columns() != B.rows()} oder
|
||||||
|
* {@code B.columns() != V.columns()} oder
|
||||||
|
* {@code V.rows() != 1}.
|
||||||
|
*/
|
||||||
|
MLMatrix multiplyAddBias( MLMatrix B, MLMatrix V ) throws IllegalArgumentException;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Erzeugt eine neue Matrix {@code C} mit dem Ergebnis der Matrixoperation
|
||||||
|
* <pre>
|
||||||
|
* C = this . t(B)
|
||||||
|
* </pre>
|
||||||
|
* wobei {@code this} dieses Matrixobjekt ist, {@code t(B)} die
|
||||||
|
* Transposition der Matrix {@code B} ist und {@code .} für die
|
||||||
|
* Matrixmultiplikation steht.
|
||||||
|
* <p>
|
||||||
|
* Wenn diese Matrix die Dimension r x c hat, dann muss die Matrix {@code B}
|
||||||
|
* die Dimension m x c haben und das Ergebnis ist eine r x m Matrix.
|
||||||
|
*
|
||||||
|
* @param B Eine m x {@code columns()} Matrix.
|
||||||
|
* @return Eine {@code rows()} x m Matrix.
|
||||||
|
* @throws IllegalArgumentException Falls die Dimensionen der Matrizen nicht
|
||||||
|
* zur Operation passen. Also
|
||||||
|
* {@code this.columns() != B.columns()}.
|
||||||
|
*/
|
||||||
|
MLMatrix multiplyTransposed( MLMatrix B ) throws IllegalArgumentException;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Erzeugt eine neue Matrix {@code C} mit dem Ergebnis der Matrixoperation
|
||||||
|
* <pre>
|
||||||
|
* C = t(this) . B * scalar
|
||||||
|
* </pre>
|
||||||
|
* wobei {@code this} dieses Matrixobjekt ist, {@code t(this)} die
|
||||||
|
* Transposition dieser Matrix ist und {@code .} für die
|
||||||
|
* Matrixmultiplikation steht. {@code *} bezeichnet die
|
||||||
|
* Skalarmultiplikation, bei der jeder Wert der Matrix mit {@code scalar}
|
||||||
|
* multipliziert wird.
|
||||||
|
* <p>
|
||||||
|
* Wenn diese Matrix die Dimension r x c hat, dann muss die Matrix {@code B}
|
||||||
|
* die Dimension r x m haben und das Ergebnis ist eine c x m Matrix.
|
||||||
|
*
|
||||||
|
* @param B Eine m x {@code columns()} Matrix.
|
||||||
|
* @return Eine {@code rows()} x m Matrix.
|
||||||
|
* @throws IllegalArgumentException Falls die Dimensionen der Matrizen nicht
|
||||||
|
* zur Operation passen. Also
|
||||||
|
* {@code this.rows() != B.rows()}.
|
||||||
|
*/
|
||||||
|
MLMatrix transposedMultiplyAndScale( MLMatrix B, double scalar ) throws IllegalArgumentException;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Erzeugt eine neue Matrix {@code C} mit dem Ergebnis der komponentenweisen
|
||||||
|
* Matrix-Addition
|
||||||
|
* <pre>
|
||||||
|
* C = this + B
|
||||||
|
* </pre>
|
||||||
|
* wobei {@code this} dieses Matrixobjekt ist. Für ein Element {@code C_ij}
|
||||||
|
* in {@code C} gilt
|
||||||
|
* <pre>
|
||||||
|
* C_ij = A_ij + B_ij
|
||||||
|
* </pre>
|
||||||
|
* <p>
|
||||||
|
* Die Matrix {@code B} muss dieselbe Dimension wie diese Matrix haben.
|
||||||
|
*
|
||||||
|
* @param B Eine {@code rows()} x {@code columns()} Matrix.
|
||||||
|
* @return Eine {@code rows()} x {@code columns()} Matrix.
|
||||||
|
* @throws IllegalArgumentException Falls die Dimensionen der Matrizen nicht
|
||||||
|
* zur Operation passen. Also
|
||||||
|
* {@code this.rows() != B.rows()} oder
|
||||||
|
* {@code this.columns() != B.columns()}.
|
||||||
|
*/
|
||||||
|
MLMatrix add( MLMatrix B ) throws IllegalArgumentException;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Setzt diese Matrix auf das Ergebnis der komponentenweisen
|
||||||
|
* Matrix-Addition
|
||||||
|
* <pre>
|
||||||
|
* A' = A + B
|
||||||
|
* </pre>
|
||||||
|
* wobei {@code A} dieses Matrixobjekt ist und {@code A'} diese Matrix nach
|
||||||
|
* der Operation. Für ein Element {@code A'_ij} in {@code A'} gilt
|
||||||
|
* <pre>
|
||||||
|
* A'_ij = A_ij + B_ij
|
||||||
|
* </pre>
|
||||||
|
* <p>
|
||||||
|
* Die Matrix {@code B} muss dieselbe Dimension wie diese Matrix haben.
|
||||||
|
*
|
||||||
|
* @param B Eine {@code rows()} x {@code columns()} Matrix.
|
||||||
|
* @return Eine {@code rows()} x {@code columns()} Matrix.
|
||||||
|
* @throws IllegalArgumentException Falls die Dimensionen der Matrizen nicht
|
||||||
|
* zur Operation passen. Also
|
||||||
|
* {@code this.rows() != B.rows()} oder
|
||||||
|
* {@code this.columns() != B.columns()}.
|
||||||
|
*/
|
||||||
|
MLMatrix addInPlace( MLMatrix B ) throws IllegalArgumentException;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Erzeugt eine neue Matrix {@code C} mit dem Ergebnis der komponentenweisen
|
||||||
|
* Matrix-Subtraktion
|
||||||
|
* <pre>
|
||||||
|
* C = A - B
|
||||||
|
* </pre>
|
||||||
|
* wobei {@code A} dieses Matrixobjekt ist. Für ein Element {@code C_ij} in
|
||||||
|
* {@code C} gilt
|
||||||
|
* <pre>
|
||||||
|
* C_ij = A_ij - B_ij
|
||||||
|
* </pre>
|
||||||
|
* <p>
|
||||||
|
* Die Matrix {@code B} muss dieselbe Dimension wie diese Matrix haben.
|
||||||
|
*
|
||||||
|
* @param B Eine {@code rows()} x {@code columns()} Matrix.
|
||||||
|
* @return Eine {@code rows()} x {@code columns()} Matrix.
|
||||||
|
* @throws IllegalArgumentException Falls die Dimensionen der Matrizen nicht
|
||||||
|
* zur Operation passen. Also
|
||||||
|
* {@code this.rows() != B.rows()} oder
|
||||||
|
* {@code this.columns() != B.columns()}.
|
||||||
|
*/
|
||||||
|
MLMatrix sub( MLMatrix B ) throws IllegalArgumentException;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Multipliziert jeden Wert dieser Matrix mit dem angegebenen Skalar.
|
||||||
|
* <p>
|
||||||
|
* Ist {@code A} dieses Matrixobjekt und {@code A'} diese Matrix nach der
|
||||||
|
* Operation, dann gilt für ein Element {@code A'_ij} in {@code A'}
|
||||||
|
* <pre>
|
||||||
|
* A'_ij = A_ij * scalar
|
||||||
|
* </pre>
|
||||||
|
*
|
||||||
|
* @param scalar Ein Skalar.
|
||||||
|
* @return Diese Matrix selbst (method chaining)
|
||||||
|
*/
|
||||||
|
MLMatrix scaleInPlace( double scalar );
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Multipliziert jeden Wert dieser Matrix mit dem entsprechenden Wert in der
|
||||||
|
* Matrix {@code S}.
|
||||||
|
* <p>
|
||||||
|
* Ist {@code A} dieses Matrixobjekt und {@code A'} diese Matrix nach der
|
||||||
|
* Operation, dann gilt für ein Element {@code A'_ij} in {@code A'}
|
||||||
|
* <pre>
|
||||||
|
* A'_ij = A_ij * S_ij
|
||||||
|
* </pre>
|
||||||
|
*
|
||||||
|
* @param S Eine {@code rows()} x {@code columns()} Matrix.
|
||||||
|
* @return Diese Matrix selbst (method chaining)
|
||||||
|
* @throws IllegalArgumentException Falls die Dimensionen der Matrizen nicht
|
||||||
|
* zur Operation passen. Also
|
||||||
|
* {@code this.rows() != B.rows()} oder
|
||||||
|
* {@code this.columns() != B.columns()}.
|
||||||
|
*/
|
||||||
|
MLMatrix scaleInPlace( MLMatrix S ) throws IllegalArgumentException;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Berechnet eine neue Matrix mit nur einer Zeile, die die Spaltensummen
|
||||||
|
* dieser Matrix enthalten.
|
||||||
|
*
|
||||||
|
* @return Eine 1 x {@code columns()} Matrix.
|
||||||
|
*/
|
||||||
|
MLMatrix colSums();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Erzeugt eine neue Matrix, deren Werte gleich den Werten dieser Matrix
|
||||||
|
* nach der Anwendung der angegebenen Funktion sind.
|
||||||
|
*
|
||||||
|
* @param op Eine Operation {@code (double) -> double}.
|
||||||
|
* @return Eine {@code rows()} x {@code columns()} Matrix.
|
||||||
|
*/
|
||||||
|
MLMatrix apply( DoubleUnaryOperator op );
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Endet die gegebene Funktion auf jeden Wert der Matrix an.
|
||||||
|
*
|
||||||
|
* @param op Eine Operation {@code (double) -> double}.
|
||||||
|
* @return Diese Matrix selbst (method chaining)
|
||||||
|
*/
|
||||||
|
MLMatrix applyInPlace( DoubleUnaryOperator op );
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Erzeugt eine neue Matrix mit denselben Dimensionen und Koeffizienten wie
|
||||||
|
* diese Matrix.
|
||||||
|
*
|
||||||
|
* @return Eine Kopie dieser Matrix.
|
||||||
|
*/
|
||||||
|
MLMatrix duplicate();
|
||||||
|
|
||||||
|
String toString();
|
||||||
|
|
||||||
|
}
|
||||||
@@ -1,82 +0,0 @@
|
|||||||
package schule.ngb.zm.ml;
|
|
||||||
|
|
||||||
import schule.ngb.zm.Constants;
|
|
||||||
|
|
||||||
import java.util.Arrays;
|
|
||||||
|
|
||||||
// TODO: Move Math into Matrix class
|
|
||||||
// TODO: Implement support for optional sci libs
|
|
||||||
public class Matrix {
|
|
||||||
|
|
||||||
private int columns, rows;
|
|
||||||
|
|
||||||
double[][] coefficients;
|
|
||||||
|
|
||||||
public Matrix( int rows, int cols ) {
|
|
||||||
this.rows = rows;
|
|
||||||
this.columns = cols;
|
|
||||||
coefficients = new double[rows][cols];
|
|
||||||
}
|
|
||||||
|
|
||||||
public Matrix( double[][] coefficients ) {
|
|
||||||
this.coefficients = coefficients;
|
|
||||||
this.rows = coefficients.length;
|
|
||||||
this.columns = coefficients[0].length;
|
|
||||||
}
|
|
||||||
|
|
||||||
public int getColumns() {
|
|
||||||
return columns;
|
|
||||||
}
|
|
||||||
|
|
||||||
public int getRows() {
|
|
||||||
return rows;
|
|
||||||
}
|
|
||||||
|
|
||||||
public double[][] getCoefficients() {
|
|
||||||
return coefficients;
|
|
||||||
}
|
|
||||||
|
|
||||||
public double get( int row, int col ) {
|
|
||||||
return coefficients[row][col];
|
|
||||||
}
|
|
||||||
|
|
||||||
public void initializeRandom() {
|
|
||||||
coefficients = MLMath.matrixApply(coefficients, (d) -> Constants.randomGaussian());
|
|
||||||
}
|
|
||||||
|
|
||||||
public void initializeRandom( double lower, double upper ) {
|
|
||||||
coefficients = MLMath.matrixApply(coefficients, (d) -> ((upper-lower) * (Constants.randomGaussian()+1) * .5) + lower);
|
|
||||||
}
|
|
||||||
|
|
||||||
public void initializeIdentity() {
|
|
||||||
initializeZero();
|
|
||||||
for( int i = 0; i < Math.min(rows, columns); i++ ) {
|
|
||||||
this.coefficients[i][i] = 1.0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public void initializeOne() {
|
|
||||||
coefficients = MLMath.matrixApply(coefficients, (d) -> 1.0);
|
|
||||||
}
|
|
||||||
|
|
||||||
public void initializeZero() {
|
|
||||||
coefficients = MLMath.matrixApply(coefficients, (d) -> 0.0);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String toString() {
|
|
||||||
//return Arrays.deepToString(coefficients);
|
|
||||||
StringBuilder sb = new StringBuilder();
|
|
||||||
sb.append('[');
|
|
||||||
sb.append('\n');
|
|
||||||
for( int i = 0; i < coefficients.length; i++ ) {
|
|
||||||
sb.append('\t');
|
|
||||||
sb.append(Arrays.toString(coefficients[i]));
|
|
||||||
sb.append('\n');
|
|
||||||
}
|
|
||||||
sb.append(']');
|
|
||||||
|
|
||||||
return sb.toString();
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
246
src/main/java/schule/ngb/zm/ml/MatrixFactory.java
Normal file
246
src/main/java/schule/ngb/zm/ml/MatrixFactory.java
Normal file
@@ -0,0 +1,246 @@
|
|||||||
|
package schule.ngb.zm.ml;
|
||||||
|
|
||||||
|
import cern.colt.matrix.DoubleFactory2D;
|
||||||
|
import schule.ngb.zm.Constants;
|
||||||
|
import schule.ngb.zm.util.Log;
|
||||||
|
|
||||||
|
import java.util.function.DoubleUnaryOperator;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Zentrale Klasse zur Erstellung neuer Matrizen. Generell sollten neue Matrizen
|
||||||
|
* nicht direkt erstellt werden, sondern durch den Aufruf von
|
||||||
|
* {@link #create(int, int)} oder {@link #create(double[][])}. Die Fabrik
|
||||||
|
* ermittelt automatisch die beste verfügbare Implementierung und initialisiert
|
||||||
|
* eine entsprechende Implementierung von {@link MLMatrix}.
|
||||||
|
* <p>
|
||||||
|
* Derzeit werden die optionale Bibliothek <a
|
||||||
|
* href="https://dst.lbl.gov/ACSSoftware/colt/">Colt</a> und die interne
|
||||||
|
* Implementierung {@link DoubleMatrix} unterstützt.
|
||||||
|
*/
|
||||||
|
public class MatrixFactory {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Erstellt eine neue Matrix mit den angegebenen Dimensionen und
|
||||||
|
* initialisiert alle Werte mit 0.
|
||||||
|
*
|
||||||
|
* @param rows Anzahl der Zeilen.
|
||||||
|
* @param cols Anzahl der Spalten.
|
||||||
|
* @return Eine {@code rows} x {@code cols} Matrix.
|
||||||
|
*/
|
||||||
|
public static final MLMatrix create( int rows, int cols ) {
|
||||||
|
try {
|
||||||
|
return getMatrixType().getDeclaredConstructor(int.class, int.class).newInstance(rows, cols);
|
||||||
|
} catch( Exception ex ) {
|
||||||
|
LOG.error(ex, "Could not initialize matrix implementation for class <%s>. Using internal implementation.", matrixType);
|
||||||
|
}
|
||||||
|
return new DoubleMatrix(rows, cols);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Erstellt eine neue Matrix mit den Dimensionen des angegebenen Arrays und
|
||||||
|
* initialisiert die Werte mit den entsprechenden Werten des Arrays.
|
||||||
|
*
|
||||||
|
* @param values Die Werte der Matrix.
|
||||||
|
* @return Eine {@code values.length} x {@code values[0].length} Matrix mit
|
||||||
|
* den Werten des Arrays.
|
||||||
|
*/
|
||||||
|
public static final MLMatrix create( double[][] values ) {
|
||||||
|
try {
|
||||||
|
return getMatrixType().getDeclaredConstructor(double[][].class).newInstance((Object) values);
|
||||||
|
} catch( Exception ex ) {
|
||||||
|
LOG.error(ex, "Could not initialize matrix implementation for class <%s>. Using internal implementation.", matrixType);
|
||||||
|
}
|
||||||
|
return new DoubleMatrix(values);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Die verwendete {@link MLMatrix} Implementierung, aus der Matrizen erzeugt
|
||||||
|
* werden.
|
||||||
|
*/
|
||||||
|
static Class<? extends MLMatrix> matrixType = null;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Ermittelt die beste verfügbare Implementierung von {@link MLMatrix}.
|
||||||
|
*
|
||||||
|
* @return Die verwendete {@link MLMatrix} Implementierung.
|
||||||
|
*/
|
||||||
|
private static final Class<? extends MLMatrix> getMatrixType() {
|
||||||
|
if( matrixType == null ) {
|
||||||
|
try {
|
||||||
|
Class<?> clazz = Class.forName("cern.colt.matrix.impl.DenseDoubleMatrix2D", false, MatrixFactory.class.getClassLoader());
|
||||||
|
matrixType = ColtMatrix.class;
|
||||||
|
LOG.info("Colt library found. Using <cern.colt.matrix.impl.DenseDoubleMatrix2D> as matrix implementation.");
|
||||||
|
} catch( ClassNotFoundException e ) {
|
||||||
|
LOG.info("Colt library not found. Falling back on internal implementation.");
|
||||||
|
matrixType = DoubleMatrix.class;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return matrixType;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static final Log LOG = Log.getLogger(MatrixFactory.class);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Interner Wrapper der DoubleMatrix2D Klasse aus der Colt Bibliothek, um
|
||||||
|
* das {@link MLMatrix} Interface zu implementieren.
|
||||||
|
*/
|
||||||
|
static class ColtMatrix implements MLMatrix {
|
||||||
|
|
||||||
|
cern.colt.matrix.DoubleMatrix2D matrix;
|
||||||
|
|
||||||
|
public ColtMatrix( double[][] doubles ) {
|
||||||
|
matrix = new cern.colt.matrix.impl.DenseDoubleMatrix2D(doubles);
|
||||||
|
}
|
||||||
|
|
||||||
|
public ColtMatrix( int rows, int cols ) {
|
||||||
|
matrix = new cern.colt.matrix.impl.DenseDoubleMatrix2D(rows, cols);
|
||||||
|
}
|
||||||
|
|
||||||
|
public ColtMatrix( ColtMatrix matrix ) {
|
||||||
|
this.matrix = matrix.matrix.copy();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int columns() {
|
||||||
|
return matrix.columns();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int rows() {
|
||||||
|
return matrix.rows();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double get( int row, int col ) {
|
||||||
|
return matrix.get(row, col);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MLMatrix set( int row, int col, double value ) {
|
||||||
|
matrix.set(row, col, value);
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MLMatrix initializeRandom() {
|
||||||
|
return initializeRandom(-1.0, 1.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MLMatrix initializeRandom( double lower, double upper ) {
|
||||||
|
matrix.assign(( d ) -> ((upper - lower) * Constants.random()) + lower);
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MLMatrix initializeOne() {
|
||||||
|
this.matrix.assign(1.0);
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MLMatrix initializeZero() {
|
||||||
|
this.matrix.assign(0.0);
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MLMatrix duplicate() {
|
||||||
|
ColtMatrix newMatrix = new ColtMatrix(matrix.rows(), matrix.columns());
|
||||||
|
newMatrix.matrix.assign(this.matrix);
|
||||||
|
return newMatrix;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MLMatrix multiplyTransposed( MLMatrix B ) {
|
||||||
|
ColtMatrix CB = (ColtMatrix) B;
|
||||||
|
ColtMatrix newMatrix = new ColtMatrix(0, 0);
|
||||||
|
newMatrix.matrix = matrix.zMult(CB.matrix, null, 1.0, 0.0, false, true);
|
||||||
|
return newMatrix;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MLMatrix multiplyAddBias( MLMatrix B, MLMatrix C ) {
|
||||||
|
ColtMatrix CB = (ColtMatrix) B;
|
||||||
|
ColtMatrix newMatrix = new ColtMatrix(0, 0);
|
||||||
|
newMatrix.matrix = DoubleFactory2D.dense.repeat(((ColtMatrix) C).matrix, rows(), 1);
|
||||||
|
matrix.zMult(CB.matrix, newMatrix.matrix, 1.0, 1.0, false, false);
|
||||||
|
return newMatrix;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MLMatrix transposedMultiplyAndScale( final MLMatrix B, final double scalar ) {
|
||||||
|
ColtMatrix CB = (ColtMatrix) B;
|
||||||
|
ColtMatrix newMatrix = new ColtMatrix(0, 0);
|
||||||
|
newMatrix.matrix = matrix.zMult(CB.matrix, null, scalar, 0.0, true, false);
|
||||||
|
return newMatrix;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MLMatrix add( MLMatrix B ) {
|
||||||
|
ColtMatrix CB = (ColtMatrix) B;
|
||||||
|
ColtMatrix newMatrix = new ColtMatrix(this);
|
||||||
|
newMatrix.matrix.assign(CB.matrix, ( d1, d2 ) -> d1 + d2);
|
||||||
|
return newMatrix;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MLMatrix addInPlace( MLMatrix B ) {
|
||||||
|
ColtMatrix CB = (ColtMatrix) B;
|
||||||
|
matrix.assign(CB.matrix, ( d1, d2 ) -> d1 + d2);
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MLMatrix sub( MLMatrix B ) {
|
||||||
|
ColtMatrix CB = (ColtMatrix) B;
|
||||||
|
ColtMatrix newMatrix = new ColtMatrix(this);
|
||||||
|
newMatrix.matrix.assign(CB.matrix, ( d1, d2 ) -> d1 - d2);
|
||||||
|
return newMatrix;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MLMatrix colSums() {
|
||||||
|
double[][] sums = new double[1][matrix.columns()];
|
||||||
|
for( int c = 0; c < matrix.columns(); c++ ) {
|
||||||
|
for( int r = 0; r < matrix.rows(); r++ ) {
|
||||||
|
sums[0][c] += matrix.getQuick(r, c);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return new ColtMatrix(sums);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MLMatrix scaleInPlace( double scalar ) {
|
||||||
|
this.matrix.assign(( d ) -> d * scalar);
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MLMatrix scaleInPlace( MLMatrix S ) {
|
||||||
|
this.matrix.forEachNonZero(( r, c, d ) -> d * S.get(r, c));
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MLMatrix apply( DoubleUnaryOperator op ) {
|
||||||
|
ColtMatrix newMatrix = new ColtMatrix(matrix.rows(), matrix.columns());
|
||||||
|
newMatrix.matrix.assign(matrix);
|
||||||
|
newMatrix.matrix.assign(( d ) -> op.applyAsDouble(d));
|
||||||
|
return newMatrix;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MLMatrix applyInPlace( DoubleUnaryOperator op ) {
|
||||||
|
this.matrix.assign(( d ) -> op.applyAsDouble(d));
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return matrix.toString();
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -15,7 +15,7 @@ public class NeuralNetwork {
|
|||||||
Writer writer = ResourceStreamProvider.getWriter(source);
|
Writer writer = ResourceStreamProvider.getWriter(source);
|
||||||
PrintWriter out = new PrintWriter(writer)
|
PrintWriter out = new PrintWriter(writer)
|
||||||
) {
|
) {
|
||||||
for( NeuronLayer layer: network.layers ) {
|
for( NeuronLayer layer : network.layers ) {
|
||||||
out.print(layer.getNeuronCount());
|
out.print(layer.getNeuronCount());
|
||||||
out.print(' ');
|
out.print(' ');
|
||||||
out.print(layer.getInputCount());
|
out.print(layer.getInputCount());
|
||||||
@@ -23,20 +23,44 @@ public class NeuralNetwork {
|
|||||||
|
|
||||||
for( int i = 0; i < layer.getInputCount(); i++ ) {
|
for( int i = 0; i < layer.getInputCount(); i++ ) {
|
||||||
for( int j = 0; j < layer.getNeuronCount(); j++ ) {
|
for( int j = 0; j < layer.getNeuronCount(); j++ ) {
|
||||||
out.print(layer.weights.coefficients[i][j]);
|
out.print(layer.weights.get(i, j));
|
||||||
out.print(' ');
|
out.print(' ');
|
||||||
}
|
}
|
||||||
out.println();
|
out.println();
|
||||||
}
|
}
|
||||||
for( int j = 0; j < layer.getNeuronCount(); j++ ) {
|
for( int j = 0; j < layer.getNeuronCount(); j++ ) {
|
||||||
out.print(layer.biases[j]);
|
out.print(layer.biases.get(0, j));
|
||||||
out.print(' ');
|
out.print(' ');
|
||||||
}
|
}
|
||||||
out.println();
|
out.println();
|
||||||
}
|
}
|
||||||
out.flush();
|
out.flush();
|
||||||
} catch( IOException ex ) {
|
} catch( IOException ex ) {
|
||||||
LOG.warn(ex, "");
|
LOG.error(ex, "");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void saveToDataFile( String source, NeuralNetwork network ) {
|
||||||
|
try(
|
||||||
|
OutputStream stream = ResourceStreamProvider.getOutputStream(source);
|
||||||
|
DataOutputStream out = new DataOutputStream(stream)
|
||||||
|
) {
|
||||||
|
for( NeuronLayer layer : network.layers ) {
|
||||||
|
out.writeInt(layer.getNeuronCount());
|
||||||
|
out.writeInt(layer.getInputCount());
|
||||||
|
|
||||||
|
for( int i = 0; i < layer.getInputCount(); i++ ) {
|
||||||
|
for( int j = 0; j < layer.getNeuronCount(); j++ ) {
|
||||||
|
out.writeDouble(layer.weights.get(i, j));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for( int j = 0; j < layer.getNeuronCount(); j++ ) {
|
||||||
|
out.writeDouble(layer.biases.get(0, j));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
out.flush();
|
||||||
|
} catch( IOException ex ) {
|
||||||
|
LOG.error(ex, "");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -56,13 +80,13 @@ public class NeuralNetwork {
|
|||||||
for( int i = 0; i < inputs; i++ ) {
|
for( int i = 0; i < inputs; i++ ) {
|
||||||
split = in.readLine().split(" ");
|
split = in.readLine().split(" ");
|
||||||
for( int j = 0; j < neurons; j++ ) {
|
for( int j = 0; j < neurons; j++ ) {
|
||||||
layer.weights.coefficients[i][j] = Double.parseDouble(split[j]);
|
layer.weights.set(i, j, Double.parseDouble(split[j]));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Load Biases
|
// Load Biases
|
||||||
split = in.readLine().split(" ");
|
split = in.readLine().split(" ");
|
||||||
for( int j = 0; j < neurons; j++ ) {
|
for( int j = 0; j < neurons; j++ ) {
|
||||||
layer.biases[j] = Double.parseDouble(split[j]);
|
layer.biases.set(0, j, Double.parseDouble(split[j]));
|
||||||
}
|
}
|
||||||
|
|
||||||
layers.add(layer);
|
layers.add(layer);
|
||||||
@@ -70,29 +94,30 @@ public class NeuralNetwork {
|
|||||||
|
|
||||||
return new NeuralNetwork(layers);
|
return new NeuralNetwork(layers);
|
||||||
} catch( IOException | NoSuchElementException ex ) {
|
} catch( IOException | NoSuchElementException ex ) {
|
||||||
LOG.warn(ex, "Could not load neural network from source <%s>", source);
|
LOG.error(ex, "Could not load neural network from source <%s>", source);
|
||||||
}
|
}
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*public static NeuralNetwork loadFromFile( String source ) {
|
public static NeuralNetwork loadFromDataFile( String source ) {
|
||||||
try(
|
try(
|
||||||
InputStream stream = ResourceStreamProvider.getInputStream(source);
|
InputStream stream = ResourceStreamProvider.getInputStream(source);
|
||||||
Scanner in = new Scanner(stream)
|
DataInputStream in = new DataInputStream(stream)
|
||||||
) {
|
) {
|
||||||
List<NeuronLayer> layers = new LinkedList<>();
|
List<NeuronLayer> layers = new LinkedList<>();
|
||||||
while( in.hasNext() ) {
|
while( in.available() > 0 ) {
|
||||||
int neurons = in.nextInt();
|
int neurons = in.readInt();
|
||||||
int inputs = in.nextInt();
|
int inputs = in.readInt();
|
||||||
|
|
||||||
NeuronLayer layer = new NeuronLayer(neurons, inputs);
|
NeuronLayer layer = new NeuronLayer(neurons, inputs);
|
||||||
for( int i = 0; i < inputs; i++ ) {
|
for( int i = 0; i < inputs; i++ ) {
|
||||||
for( int j = 0; j < neurons; j++ ) {
|
for( int j = 0; j < neurons; j++ ) {
|
||||||
layer.weights.coefficients[i][j] = in.nextDouble();
|
layer.weights.set(i, j, in.readDouble());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// Load Biases
|
||||||
for( int j = 0; j < neurons; j++ ) {
|
for( int j = 0; j < neurons; j++ ) {
|
||||||
layer.biases[j] = in.nextDouble();
|
layer.biases.set(0, j, in.readDouble());
|
||||||
}
|
}
|
||||||
|
|
||||||
layers.add(layer);
|
layers.add(layer);
|
||||||
@@ -100,14 +125,14 @@ public class NeuralNetwork {
|
|||||||
|
|
||||||
return new NeuralNetwork(layers);
|
return new NeuralNetwork(layers);
|
||||||
} catch( IOException | NoSuchElementException ex ) {
|
} catch( IOException | NoSuchElementException ex ) {
|
||||||
LOG.warn(ex, "Could not load neural network from source <%s>", source);
|
LOG.error(ex, "Could not load neural network from source <%s>", source);
|
||||||
}
|
}
|
||||||
return null;
|
return null;
|
||||||
}*/
|
}
|
||||||
|
|
||||||
private NeuronLayer[] layers;
|
private NeuronLayer[] layers;
|
||||||
|
|
||||||
private double[][] output;
|
private MLMatrix output;
|
||||||
|
|
||||||
private double learningRate = 0.1;
|
private double learningRate = 0.1;
|
||||||
|
|
||||||
@@ -128,7 +153,7 @@ public class NeuralNetwork {
|
|||||||
for( int i = 0; i < layers.size(); i++ ) {
|
for( int i = 0; i < layers.size(); i++ ) {
|
||||||
this.layers[i] = layers.get(i);
|
this.layers[i] = layers.get(i);
|
||||||
if( i > 0 ) {
|
if( i > 0 ) {
|
||||||
this.layers[i-1].setNextLayer(this.layers[i]);
|
this.layers[i - 1].setNextLayer(this.layers[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -138,7 +163,7 @@ public class NeuralNetwork {
|
|||||||
for( int i = 0; i < layers.length; i++ ) {
|
for( int i = 0; i < layers.length; i++ ) {
|
||||||
this.layers[i] = layers[i];
|
this.layers[i] = layers[i];
|
||||||
if( i > 0 ) {
|
if( i > 0 ) {
|
||||||
this.layers[i-1].setNextLayer(this.layers[i]);
|
this.layers[i - 1].setNextLayer(this.layers[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -146,6 +171,7 @@ public class NeuralNetwork {
|
|||||||
public int getLayerCount() {
|
public int getLayerCount() {
|
||||||
return layers.length;
|
return layers.length;
|
||||||
}
|
}
|
||||||
|
|
||||||
public NeuronLayer[] getLayers() {
|
public NeuronLayer[] getLayers() {
|
||||||
return layers;
|
return layers;
|
||||||
}
|
}
|
||||||
@@ -162,17 +188,25 @@ public class NeuralNetwork {
|
|||||||
this.learningRate = pLearningRate;
|
this.learningRate = pLearningRate;
|
||||||
}
|
}
|
||||||
|
|
||||||
public double[][] getOutput() {
|
public MLMatrix getOutput() {
|
||||||
return output;
|
return output;
|
||||||
}
|
}
|
||||||
|
|
||||||
public double[][] predict( double[][] inputs ) {
|
public MLMatrix predict( double[][] inputs ) {
|
||||||
//this.output = layers[1].apply(layers[0].apply(inputs));
|
//this.output = layers[1].apply(layers[0].apply(inputs));
|
||||||
|
return predict(MatrixFactory.create(inputs));
|
||||||
|
}
|
||||||
|
|
||||||
|
public MLMatrix predict( MLMatrix inputs ) {
|
||||||
this.output = layers[0].apply(inputs);
|
this.output = layers[0].apply(inputs);
|
||||||
return this.output;
|
return this.output;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void learn( double[][] expected ) {
|
public void learn( double[][] expected ) {
|
||||||
|
learn(MatrixFactory.create(expected));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void learn( MLMatrix expected ) {
|
||||||
layers[layers.length - 1].backprop(expected, learningRate);
|
layers[layers.length - 1].backprop(expected, learningRate);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,50 +1,66 @@
|
|||||||
package schule.ngb.zm.ml;
|
package schule.ngb.zm.ml;
|
||||||
|
|
||||||
import java.util.Arrays;
|
|
||||||
import java.util.function.DoubleUnaryOperator;
|
import java.util.function.DoubleUnaryOperator;
|
||||||
import java.util.function.Function;
|
import java.util.function.Function;
|
||||||
|
|
||||||
public class NeuronLayer implements Function<double[][], double[][]> {
|
/**
|
||||||
|
* Implementierung einer Neuronenebene in einem Neuonalen Netz.
|
||||||
|
* <p>
|
||||||
|
* Eine Ebene besteht aus einer Anzahl an <em>Neuronen</em> die jeweils eine
|
||||||
|
* Anzahl <em>Eingänge</em> haben. Die Eingänge erhalten als Signal die Ausgabe
|
||||||
|
* der vorherigen Ebene und berechnen die Ausgabe des jeweiligen Neurons.
|
||||||
|
*/
|
||||||
|
public class NeuronLayer implements Function<MLMatrix, MLMatrix> {
|
||||||
|
|
||||||
|
public static NeuronLayer fromArray( double[][] weights, boolean transpose ) {
|
||||||
|
NeuronLayer layer;
|
||||||
|
if( transpose ) {
|
||||||
|
layer = new NeuronLayer(weights.length, weights[0].length);
|
||||||
|
} else {
|
||||||
|
layer = new NeuronLayer(weights[0].length, weights.length);
|
||||||
|
}
|
||||||
|
|
||||||
public static NeuronLayer fromArray( double[][] weights ) {
|
|
||||||
NeuronLayer layer = new NeuronLayer(weights[0].length, weights.length);
|
|
||||||
for( int i = 0; i < weights[0].length; i++ ) {
|
for( int i = 0; i < weights[0].length; i++ ) {
|
||||||
for( int j = 0; j < weights.length; j++ ) {
|
for( int j = 0; j < weights.length; j++ ) {
|
||||||
layer.weights.coefficients[i][j] = weights[i][j];
|
if( transpose ) {
|
||||||
|
layer.weights.set(j, i, weights[i][j]);
|
||||||
|
} else {
|
||||||
|
layer.weights.set(i, j, weights[i][j]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return layer;
|
return layer;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static NeuronLayer fromArray( double[][] weights, double[] biases ) {
|
public static NeuronLayer fromArray( double[][] weights, double[] biases, boolean transpose ) {
|
||||||
NeuronLayer layer = new NeuronLayer(weights[0].length, weights.length);
|
NeuronLayer layer = fromArray(weights, transpose);
|
||||||
for( int i = 0; i < weights[0].length; i++ ) {
|
|
||||||
for( int j = 0; j < weights.length; j++ ) {
|
|
||||||
layer.weights.coefficients[i][j] = weights[i][j];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for( int j = 0; j < biases.length; j++ ) {
|
for( int j = 0; j < biases.length; j++ ) {
|
||||||
layer.biases[j] = biases[j];
|
layer.biases.set(0, j, biases[j]);
|
||||||
}
|
}
|
||||||
|
|
||||||
return layer;
|
return layer;
|
||||||
}
|
}
|
||||||
|
|
||||||
Matrix weights;
|
|
||||||
|
|
||||||
double[] biases;
|
MLMatrix weights;
|
||||||
|
|
||||||
|
MLMatrix biases;
|
||||||
|
|
||||||
NeuronLayer previous, next;
|
NeuronLayer previous, next;
|
||||||
|
|
||||||
DoubleUnaryOperator activationFunction, activationFunctionDerivative;
|
DoubleUnaryOperator activationFunction, activationFunctionDerivative;
|
||||||
|
|
||||||
double[][] lastOutput, lastInput;
|
MLMatrix lastOutput, lastInput;
|
||||||
|
|
||||||
public NeuronLayer( int neurons, int inputs ) {
|
public NeuronLayer( int neurons, int inputs ) {
|
||||||
weights = new Matrix(inputs, neurons);
|
weights = MatrixFactory
|
||||||
weights.initializeRandom(-1, 1);
|
.create(inputs, neurons)
|
||||||
|
.initializeRandom();
|
||||||
|
|
||||||
biases = new double[neurons];
|
biases = MatrixFactory
|
||||||
Arrays.fill(biases, 0.0); // TODO: Random?
|
.create(1, neurons)
|
||||||
|
.initializeZero();
|
||||||
|
|
||||||
activationFunction = MLMath::sigmoid;
|
activationFunction = MLMath::sigmoid;
|
||||||
activationFunctionDerivative = MLMath::sigmoidDerivative;
|
activationFunctionDerivative = MLMath::sigmoidDerivative;
|
||||||
@@ -85,45 +101,42 @@ public class NeuronLayer implements Function<double[][], double[][]> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public Matrix getWeights() {
|
public MLMatrix getWeights() {
|
||||||
return weights;
|
return weights;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public MLMatrix getBiases() {
|
||||||
|
return biases;
|
||||||
|
}
|
||||||
|
|
||||||
public int getNeuronCount() {
|
public int getNeuronCount() {
|
||||||
return weights.coefficients[0].length;
|
return weights.columns();
|
||||||
}
|
}
|
||||||
|
|
||||||
public int getInputCount() {
|
public int getInputCount() {
|
||||||
return weights.coefficients.length;
|
return weights.rows();
|
||||||
}
|
}
|
||||||
|
|
||||||
public double[][] getLastOutput() {
|
public MLMatrix getLastOutput() {
|
||||||
return lastOutput;
|
return lastOutput;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void setWeights( double[][] newWeights ) {
|
public void setWeights( MLMatrix newWeights ) {
|
||||||
weights.coefficients = MLMath.copyMatrix(newWeights);
|
weights = newWeights.duplicate();
|
||||||
}
|
|
||||||
|
|
||||||
public void adjustWeights( double[][] adjustment ) {
|
|
||||||
weights.coefficients = MLMath.matrixAdd(weights.coefficients, adjustment);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String toString() {
|
public String toString() {
|
||||||
return weights.toString() + "\n" + Arrays.toString(biases);
|
return "weights:\n" + weights.toString() + "\nbiases:\n" + biases.toString();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public double[][] apply( double[][] inputs ) {
|
public MLMatrix apply( MLMatrix inputs ) {
|
||||||
lastInput = inputs;
|
lastInput = inputs.duplicate();
|
||||||
lastOutput = MLMath.matrixApply(
|
lastOutput = inputs
|
||||||
MLMath.biasAdd(
|
.multiplyAddBias(weights, biases)
|
||||||
MLMath.matrixMultiply(inputs, weights.coefficients),
|
.applyInPlace(activationFunction);
|
||||||
biases
|
|
||||||
),
|
|
||||||
activationFunction
|
|
||||||
);
|
|
||||||
if( next != null ) {
|
if( next != null ) {
|
||||||
return next.apply(lastOutput);
|
return next.apply(lastOutput);
|
||||||
} else {
|
} else {
|
||||||
@@ -132,36 +145,41 @@ public class NeuronLayer implements Function<double[][], double[][]> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public <V> Function<V, double[][]> compose( Function<? super V, ? extends double[][]> before ) {
|
public <V> Function<V, MLMatrix> compose( Function<? super V, ? extends MLMatrix> before ) {
|
||||||
return ( in ) -> apply(before.apply(in));
|
return ( in ) -> apply(before.apply(in));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public <V> Function<double[][], V> andThen( Function<? super double[][], ? extends V> after ) {
|
public <V> Function<MLMatrix, V> andThen( Function<? super MLMatrix, ? extends V> after ) {
|
||||||
return ( in ) -> after.apply(apply(in));
|
return ( in ) -> after.apply(apply(in));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void backprop( double[][] expected, double learningRate ) {
|
public void backprop( MLMatrix expected, double learningRate ) {
|
||||||
double[][] error, delta, adjustment;
|
MLMatrix error, adjustment;
|
||||||
if( next == null ) {
|
if( next == null ) {
|
||||||
error = MLMath.matrixSub(expected, this.lastOutput);
|
error = expected.sub(lastOutput);
|
||||||
} else {
|
} else {
|
||||||
error = MLMath.matrixMultiply(expected, MLMath.matrixTranspose(next.weights.coefficients));
|
error = expected.multiplyTransposed(next.weights);
|
||||||
}
|
}
|
||||||
|
|
||||||
delta = MLMath.matrixScale(error, MLMath.matrixApply(this.lastOutput, this.activationFunctionDerivative));
|
error.scaleInPlace(
|
||||||
|
lastOutput.apply(this.activationFunctionDerivative)
|
||||||
|
);
|
||||||
// Hier schon leraningRate anwenden?
|
// Hier schon leraningRate anwenden?
|
||||||
// See https://towardsdatascience.com/understanding-and-implementing-neural-networks-in-java-from-scratch-61421bb6352c
|
// See https://towardsdatascience.com/understanding-and-implementing-neural-networks-in-java-from-scratch-61421bb6352c
|
||||||
//delta = MLMath.matrixApply(delta, ( x ) -> learningRate * x);
|
//delta = MLMath.matrixApply(delta, ( x ) -> learningRate * x);
|
||||||
if( previous != null ) {
|
if( previous != null ) {
|
||||||
previous.backprop(delta, learningRate);
|
previous.backprop(error, learningRate);
|
||||||
}
|
}
|
||||||
|
|
||||||
biases = MLMath.biasAdjust(biases, MLMath.matrixApply(delta, ( x ) -> learningRate * x));
|
biases.addInPlace(
|
||||||
|
error.colSums().scaleInPlace(
|
||||||
|
-learningRate / (double) error.rows()
|
||||||
|
)
|
||||||
|
);
|
||||||
|
|
||||||
adjustment = MLMath.matrixMultiply(MLMath.matrixTranspose(lastInput), delta);
|
adjustment = lastInput.transposedMultiplyAndScale(error, learningRate);
|
||||||
adjustment = MLMath.matrixApply(adjustment, ( x ) -> learningRate * x);
|
weights.addInPlace(adjustment);
|
||||||
this.adjustWeights(adjustment);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
426
src/test/java/schule/ngb/zm/ml/MLMatrixTest.java
Normal file
426
src/test/java/schule/ngb/zm/ml/MLMatrixTest.java
Normal file
@@ -0,0 +1,426 @@
|
|||||||
|
package schule.ngb.zm.ml;
|
||||||
|
|
||||||
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.junit.jupiter.api.TestInfo;
|
||||||
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
|
import org.junit.jupiter.params.provider.ValueSource;
|
||||||
|
import schule.ngb.zm.util.Timer;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
|
class MLMatrixTest {
|
||||||
|
|
||||||
|
private TestInfo info;
|
||||||
|
|
||||||
|
@BeforeEach
|
||||||
|
void saveTestInfo( TestInfo info ) {
|
||||||
|
this.info = info;
|
||||||
|
}
|
||||||
|
|
||||||
|
@ParameterizedTest
|
||||||
|
@ValueSource( classes = {DoubleMatrix.class, MatrixFactory.ColtMatrix.class} )
|
||||||
|
void get( Class<? extends MLMatrix> mType ) {
|
||||||
|
MatrixFactory.matrixType = mType;
|
||||||
|
|
||||||
|
MLMatrix M = MatrixFactory.create(new double[][]{
|
||||||
|
{1, 2, 3},
|
||||||
|
{4, 5, 6}
|
||||||
|
});
|
||||||
|
|
||||||
|
assertEquals(mType, M.getClass());
|
||||||
|
|
||||||
|
assertEquals(1.0, M.get(0,0));
|
||||||
|
assertEquals(4.0, M.get(1,0));
|
||||||
|
assertEquals(6.0, M.get(1,2));
|
||||||
|
}
|
||||||
|
|
||||||
|
@ParameterizedTest
|
||||||
|
@ValueSource( classes = {DoubleMatrix.class, MatrixFactory.ColtMatrix.class} )
|
||||||
|
void initializeOne( Class<? extends MLMatrix> mType ) {
|
||||||
|
MatrixFactory.matrixType = mType;
|
||||||
|
|
||||||
|
MLMatrix m = MatrixFactory.create(4, 4);
|
||||||
|
m.initializeOne();
|
||||||
|
|
||||||
|
assertEquals(mType, m.getClass());
|
||||||
|
|
||||||
|
for( int i = 0; i < m.rows(); i++ ) {
|
||||||
|
for( int j = 0; j < m.columns(); j++ ) {
|
||||||
|
assertEquals(1.0, m.get(i, j));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@ParameterizedTest
|
||||||
|
@ValueSource( classes = {DoubleMatrix.class, MatrixFactory.ColtMatrix.class} )
|
||||||
|
void initializeZero( Class<? extends MLMatrix> mType ) {
|
||||||
|
MatrixFactory.matrixType = mType;
|
||||||
|
|
||||||
|
MLMatrix m = MatrixFactory.create(4, 4);
|
||||||
|
m.initializeZero();
|
||||||
|
|
||||||
|
assertEquals(mType, m.getClass());
|
||||||
|
|
||||||
|
for( int i = 0; i < m.rows(); i++ ) {
|
||||||
|
for( int j = 0; j < m.columns(); j++ ) {
|
||||||
|
assertEquals(0.0, m.get(i, j));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@ParameterizedTest
|
||||||
|
@ValueSource( classes = {DoubleMatrix.class, MatrixFactory.ColtMatrix.class} )
|
||||||
|
void initializeRandom( Class<? extends MLMatrix> mType ) {
|
||||||
|
MatrixFactory.matrixType = mType;
|
||||||
|
|
||||||
|
MLMatrix m = MatrixFactory.create(4, 4);
|
||||||
|
m.initializeRandom();
|
||||||
|
|
||||||
|
assertEquals(mType, m.getClass());
|
||||||
|
|
||||||
|
for( int i = 0; i < m.rows(); i++ ) {
|
||||||
|
for( int j = 0; j < m.columns(); j++ ) {
|
||||||
|
double d = m.get(i, j);
|
||||||
|
assertTrue(-1.0 <= d && d < 1.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@ParameterizedTest
|
||||||
|
@ValueSource( classes = {DoubleMatrix.class, MatrixFactory.ColtMatrix.class} )
|
||||||
|
void multiplyTransposed( Class<? extends MLMatrix> mType ) {
|
||||||
|
MatrixFactory.matrixType = mType;
|
||||||
|
|
||||||
|
MLMatrix A = MatrixFactory.create(new double[][]{
|
||||||
|
{1.0, 2.0, 3.0, 4.0},
|
||||||
|
{1.0, 2.0, 3.0, 4.0},
|
||||||
|
{1.0, 2.0, 3.0, 4.0}
|
||||||
|
});
|
||||||
|
MLMatrix B = MatrixFactory.create(new double[][]{
|
||||||
|
{1, 3, 5, 7},
|
||||||
|
{2, 4, 6, 8}
|
||||||
|
});
|
||||||
|
|
||||||
|
MLMatrix C = A.multiplyTransposed(B);
|
||||||
|
|
||||||
|
assertEquals(mType, A.getClass());
|
||||||
|
assertEquals(mType, B.getClass());
|
||||||
|
assertEquals(mType, C.getClass());
|
||||||
|
|
||||||
|
assertEquals(3, C.rows());
|
||||||
|
assertEquals(2, C.columns());
|
||||||
|
for( int i = 0; i < C.rows(); i++ ) {
|
||||||
|
assertEquals(50.0, C.get(i, 0));
|
||||||
|
assertEquals(60.0, C.get(i, 1));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@ParameterizedTest
|
||||||
|
@ValueSource( classes = {DoubleMatrix.class, MatrixFactory.ColtMatrix.class} )
|
||||||
|
void multiplyAddBias( Class<? extends MLMatrix> mType ) {
|
||||||
|
MatrixFactory.matrixType = mType;
|
||||||
|
|
||||||
|
MLMatrix A = MatrixFactory.create(new double[][]{
|
||||||
|
{1.0, 2.0, 3.0, 4.0},
|
||||||
|
{1.0, 2.0, 3.0, 4.0},
|
||||||
|
{1.0, 2.0, 3.0, 4.0}
|
||||||
|
});
|
||||||
|
MLMatrix B = MatrixFactory.create(new double[][]{
|
||||||
|
{1.0, 2.0},
|
||||||
|
{3.0, 4.0},
|
||||||
|
{5.0, 6.0},
|
||||||
|
{7.0, 8.0}
|
||||||
|
});
|
||||||
|
MLMatrix V = MatrixFactory.create(new double[][]{
|
||||||
|
{1000.0, 2000.0}
|
||||||
|
});
|
||||||
|
|
||||||
|
MLMatrix C = A.multiplyAddBias(B, V);
|
||||||
|
|
||||||
|
assertEquals(mType, A.getClass());
|
||||||
|
assertEquals(mType, B.getClass());
|
||||||
|
assertEquals(mType, C.getClass());
|
||||||
|
assertEquals(mType, V.getClass());
|
||||||
|
|
||||||
|
assertEquals(3, C.rows());
|
||||||
|
assertEquals(2, C.columns());
|
||||||
|
for( int i = 0; i < C.rows(); i++ ) {
|
||||||
|
assertEquals(1050.0, C.get(i, 0));
|
||||||
|
assertEquals(2060.0, C.get(i, 1));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@ParameterizedTest
|
||||||
|
@ValueSource( classes = {DoubleMatrix.class, MatrixFactory.ColtMatrix.class} )
|
||||||
|
void transposedMultiplyAndScale( Class<? extends MLMatrix> mType ) {
|
||||||
|
MatrixFactory.matrixType = mType;
|
||||||
|
|
||||||
|
MLMatrix A = MatrixFactory.create(new double[][]{
|
||||||
|
{1, 1, 1},
|
||||||
|
{2, 2, 2},
|
||||||
|
{3, 3, 3},
|
||||||
|
{4, 4, 4}
|
||||||
|
});
|
||||||
|
MLMatrix B = MatrixFactory.create(new double[][]{
|
||||||
|
{1.0, 2.0},
|
||||||
|
{3.0, 4.0},
|
||||||
|
{5.0, 6.0},
|
||||||
|
{7.0, 8.0}
|
||||||
|
});
|
||||||
|
|
||||||
|
MLMatrix C = A.transposedMultiplyAndScale(B, 2.0);
|
||||||
|
|
||||||
|
assertEquals(mType, A.getClass());
|
||||||
|
assertEquals(mType, B.getClass());
|
||||||
|
assertEquals(mType, C.getClass());
|
||||||
|
|
||||||
|
assertEquals(3, C.rows());
|
||||||
|
assertEquals(2, C.columns());
|
||||||
|
for( int i = 0; i < C.rows(); i++ ) {
|
||||||
|
assertEquals(100.0, C.get(i, 0));
|
||||||
|
assertEquals(120.0, C.get(i, 1));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@ParameterizedTest
|
||||||
|
@ValueSource( classes = {DoubleMatrix.class, MatrixFactory.ColtMatrix.class} )
|
||||||
|
void apply( Class<? extends MLMatrix> mType ) {
|
||||||
|
MatrixFactory.matrixType = mType;
|
||||||
|
|
||||||
|
MLMatrix M = MatrixFactory.create(new double[][]{
|
||||||
|
{1, 1, 1},
|
||||||
|
{2, 2, 2},
|
||||||
|
{3, 3, 3},
|
||||||
|
{4, 4, 4}
|
||||||
|
});
|
||||||
|
|
||||||
|
MLMatrix R = M.apply(( d ) -> d * d);
|
||||||
|
|
||||||
|
assertEquals(mType, M.getClass());
|
||||||
|
assertEquals(mType, R.getClass());
|
||||||
|
assertNotSame(M, R);
|
||||||
|
|
||||||
|
for( int i = 0; i < M.rows(); i++ ) {
|
||||||
|
for( int j = 0; j < M.columns(); j++ ) {
|
||||||
|
assertEquals(
|
||||||
|
(i + 1) * (i + 1), R.get(i, j),
|
||||||
|
msg("(%d,%d)", "apply", i, j)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
MLMatrix M2 = M.applyInPlace(( d ) -> d * d * d);
|
||||||
|
assertSame(M, M2);
|
||||||
|
for( int i = 0; i < M.rows(); i++ ) {
|
||||||
|
for( int j = 0; j < M.columns(); j++ ) {
|
||||||
|
assertEquals(
|
||||||
|
(i + 1) * (i + 1) * (i + 1), M.get(i, j),
|
||||||
|
msg("(%d,%d)", "applyInPlace", i, j)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@ParameterizedTest
|
||||||
|
@ValueSource( classes = {DoubleMatrix.class, MatrixFactory.ColtMatrix.class} )
|
||||||
|
void add( Class<? extends MLMatrix> mType ) {
|
||||||
|
MatrixFactory.matrixType = mType;
|
||||||
|
|
||||||
|
MLMatrix M = MatrixFactory.create(new double[][]{
|
||||||
|
{1, 1, 1},
|
||||||
|
{2, 2, 2},
|
||||||
|
{3, 3, 3},
|
||||||
|
{4, 4, 4}
|
||||||
|
});
|
||||||
|
|
||||||
|
MLMatrix R = M.add(M);
|
||||||
|
|
||||||
|
assertEquals(mType, M.getClass());
|
||||||
|
assertEquals(mType, R.getClass());
|
||||||
|
assertNotSame(M, R);
|
||||||
|
|
||||||
|
for( int i = 0; i < M.rows(); i++ ) {
|
||||||
|
for( int j = 0; j < M.columns(); j++ ) {
|
||||||
|
assertEquals(
|
||||||
|
(i + 1) + (i + 1), R.get(i, j),
|
||||||
|
msg("(%d,%d)", "add", i, j)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
MLMatrix M2 = M.addInPlace(R);
|
||||||
|
assertSame(M, M2);
|
||||||
|
for( int i = 0; i < M.rows(); i++ ) {
|
||||||
|
for( int j = 0; j < M.columns(); j++ ) {
|
||||||
|
assertEquals(
|
||||||
|
(i + 1) + (i + 1) + (i + 1), M.get(i, j),
|
||||||
|
msg("(%d,%d)", "addInPlace", i, j)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@ParameterizedTest
|
||||||
|
@ValueSource( classes = {DoubleMatrix.class, MatrixFactory.ColtMatrix.class} )
|
||||||
|
void sub( Class<? extends MLMatrix> mType ) {
|
||||||
|
MatrixFactory.matrixType = mType;
|
||||||
|
|
||||||
|
MLMatrix M = MatrixFactory.create(new double[][]{
|
||||||
|
{1, 1, 1},
|
||||||
|
{2, 2, 2},
|
||||||
|
{3, 3, 3},
|
||||||
|
{4, 4, 4}
|
||||||
|
});
|
||||||
|
|
||||||
|
MLMatrix R = M.sub(M);
|
||||||
|
|
||||||
|
assertEquals(mType, M.getClass());
|
||||||
|
assertEquals(mType, R.getClass());
|
||||||
|
assertNotSame(M, R);
|
||||||
|
|
||||||
|
for( int i = 0; i < M.rows(); i++ ) {
|
||||||
|
for( int j = 0; j < M.columns(); j++ ) {
|
||||||
|
assertEquals(
|
||||||
|
0.0, R.get(i, j),
|
||||||
|
msg("(%d,%d)", "sub", i, j)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@ParameterizedTest
|
||||||
|
@ValueSource( classes = {DoubleMatrix.class, MatrixFactory.ColtMatrix.class} )
|
||||||
|
void colSums( Class<? extends MLMatrix> mType ) {
|
||||||
|
MatrixFactory.matrixType = mType;
|
||||||
|
|
||||||
|
MLMatrix M = MatrixFactory.create(new double[][]{
|
||||||
|
{1, 2, 3},
|
||||||
|
{1, 2, 3},
|
||||||
|
{1, 2, 3},
|
||||||
|
{1, 2, 3}
|
||||||
|
});
|
||||||
|
|
||||||
|
MLMatrix R = M.colSums();
|
||||||
|
|
||||||
|
assertEquals(mType, M.getClass());
|
||||||
|
assertEquals(mType, R.getClass());
|
||||||
|
assertNotSame(M, R);
|
||||||
|
|
||||||
|
assertEquals(1, R.rows());
|
||||||
|
assertEquals(3, R.columns());
|
||||||
|
for( int j = 0; j < M.columns(); j++ ) {
|
||||||
|
assertEquals(
|
||||||
|
(j+1)*4, R.get(0, j),
|
||||||
|
msg("(%d,%d)", "colSums", 0, j)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ParameterizedTest
|
||||||
|
@ValueSource( classes = {DoubleMatrix.class, MatrixFactory.ColtMatrix.class} )
|
||||||
|
void duplicate( Class<? extends MLMatrix> mType ) {
|
||||||
|
MatrixFactory.matrixType = mType;
|
||||||
|
|
||||||
|
MLMatrix M = MatrixFactory.create(new double[][]{
|
||||||
|
{1, 2, 3},
|
||||||
|
{1, 2, 3},
|
||||||
|
{1, 2, 3},
|
||||||
|
{1, 2, 3}
|
||||||
|
});
|
||||||
|
|
||||||
|
MLMatrix R = M.duplicate();
|
||||||
|
|
||||||
|
assertEquals(mType, M.getClass());
|
||||||
|
assertEquals(mType, R.getClass());
|
||||||
|
assertNotSame(M, R);
|
||||||
|
|
||||||
|
for( int i = 0; i < M.rows(); i++ ) {
|
||||||
|
for( int j = 0; j < M.columns(); j++ ) {
|
||||||
|
assertEquals(
|
||||||
|
M.get(i, j), R.get(i, j),
|
||||||
|
msg("(%d,%d)", "duplicate", i, j)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@ParameterizedTest
|
||||||
|
@ValueSource( classes = {DoubleMatrix.class, MatrixFactory.ColtMatrix.class} )
|
||||||
|
void scale( Class<? extends MLMatrix> mType ) {
|
||||||
|
MatrixFactory.matrixType = mType;
|
||||||
|
|
||||||
|
MLMatrix M = MatrixFactory.create(new double[][]{
|
||||||
|
{1, 1, 1},
|
||||||
|
{2, 2, 2},
|
||||||
|
{3, 3, 3},
|
||||||
|
{4, 4, 4}
|
||||||
|
});
|
||||||
|
|
||||||
|
MLMatrix M2 = M.scaleInPlace(2.0);
|
||||||
|
|
||||||
|
assertEquals(mType, M.getClass());
|
||||||
|
assertEquals(mType, M2.getClass());
|
||||||
|
assertSame(M, M2);
|
||||||
|
|
||||||
|
for( int i = 0; i < M.rows(); i++ ) {
|
||||||
|
for( int j = 0; j < M.columns(); j++ ) {
|
||||||
|
assertEquals(
|
||||||
|
(i+1)*2.0, M2.get(i, j),
|
||||||
|
msg("(%d,%d)", "scaleInPlace", i, j)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
MLMatrix M3 = M.scaleInPlace(M);
|
||||||
|
assertSame(M, M3);
|
||||||
|
for( int i = 0; i < M.rows(); i++ ) {
|
||||||
|
for( int j = 0; j < M.columns(); j++ ) {
|
||||||
|
assertEquals(
|
||||||
|
((i+1)*2.0)*((i+1)*2.0), M.get(i, j),
|
||||||
|
msg("(%d,%d)", "addInPlace", i, j)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private String msg( String msg, String methodName, Object... args ) {
|
||||||
|
String testName = this.info.getTestMethod().get().getName();
|
||||||
|
String className = MatrixFactory.matrixType.getSimpleName();
|
||||||
|
return String.format("[" + testName + "(" + className + ") " + methodName + "()] " + msg, args);
|
||||||
|
}
|
||||||
|
|
||||||
|
//@ParameterizedTest
|
||||||
|
//@ValueSource( classes = {MatrixFactory.ColtMatrix.class, DoubleMatrix.class} )
|
||||||
|
void speed( Class<? extends MLMatrix> mType ) {
|
||||||
|
MatrixFactory.matrixType = mType;
|
||||||
|
|
||||||
|
int N = 10;
|
||||||
|
int rows = 1000;
|
||||||
|
int cols = 1000;
|
||||||
|
|
||||||
|
Timer timer = new Timer();
|
||||||
|
|
||||||
|
MLMatrix M = MatrixFactory.create(rows, cols);
|
||||||
|
timer.start();
|
||||||
|
for( int i = 0; i < N; i++ ) {
|
||||||
|
M.initializeRandom();
|
||||||
|
}
|
||||||
|
timer.stop();
|
||||||
|
System.err.println(msg("%d iterations: %d ms", "initializeRandom", N, timer.getMillis()));
|
||||||
|
|
||||||
|
timer.reset();
|
||||||
|
|
||||||
|
MLMatrix B = MatrixFactory.create(rows*2, M.columns());
|
||||||
|
B.initializeRandom();
|
||||||
|
|
||||||
|
timer.start();
|
||||||
|
for( int i = 0; i < N; i++ ) {
|
||||||
|
M.multiplyTransposed(B);
|
||||||
|
}
|
||||||
|
timer.stop();
|
||||||
|
System.err.println(msg("%d iterations: %d ms", "multiplyTransposed", N, timer.getMillis()));
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -1,57 +0,0 @@
|
|||||||
package schule.ngb.zm.ml;
|
|
||||||
|
|
||||||
import org.junit.jupiter.api.Test;
|
|
||||||
|
|
||||||
import java.util.Arrays;
|
|
||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.*;
|
|
||||||
|
|
||||||
class MatrixTest {
|
|
||||||
|
|
||||||
@Test
|
|
||||||
void initializeIdentity() {
|
|
||||||
Matrix m = new Matrix(4, 4);
|
|
||||||
m.initializeIdentity();
|
|
||||||
|
|
||||||
assertArrayEquals(new double[]{1.0, 0.0, 0.0, 0.0}, m.coefficients[0]);
|
|
||||||
assertArrayEquals(new double[]{0.0, 1.0, 0.0, 0.0}, m.coefficients[1]);
|
|
||||||
assertArrayEquals(new double[]{0.0, 0.0, 1.0, 0.0}, m.coefficients[2]);
|
|
||||||
assertArrayEquals(new double[]{0.0, 0.0, 0.0, 1.0}, m.coefficients[3]);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
void initializeOne() {
|
|
||||||
Matrix m = new Matrix(4, 4);
|
|
||||||
m.initializeOne();
|
|
||||||
|
|
||||||
double[] ones = new double[]{1.0, 1.0, 1.0, 1.0};
|
|
||||||
assertArrayEquals(ones, m.coefficients[0]);
|
|
||||||
assertArrayEquals(ones, m.coefficients[1]);
|
|
||||||
assertArrayEquals(ones, m.coefficients[2]);
|
|
||||||
assertArrayEquals(ones, m.coefficients[3]);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
void initializeZero() {
|
|
||||||
Matrix m = new Matrix(4, 4);
|
|
||||||
m.initializeZero();
|
|
||||||
|
|
||||||
double[] zeros = new double[]{0.0, 0.0, 0.0, 0.0};
|
|
||||||
assertArrayEquals(zeros, m.coefficients[0]);
|
|
||||||
assertArrayEquals(zeros, m.coefficients[1]);
|
|
||||||
assertArrayEquals(zeros, m.coefficients[2]);
|
|
||||||
assertArrayEquals(zeros, m.coefficients[3]);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
void initializeRandom() {
|
|
||||||
Matrix m = new Matrix(4, 4);
|
|
||||||
m.initializeRandom(-1, 1);
|
|
||||||
|
|
||||||
assertTrue(Arrays.stream(m.coefficients[0]).allMatch((d) -> -1.0 <= d && d < 1.0));
|
|
||||||
assertTrue(Arrays.stream(m.coefficients[1]).allMatch((d) -> -1.0 <= d && d < 1.0));
|
|
||||||
assertTrue(Arrays.stream(m.coefficients[2]).allMatch((d) -> -1.0 <= d && d < 1.0));
|
|
||||||
assertTrue(Arrays.stream(m.coefficients[3]).allMatch((d) -> -1.0 <= d && d < 1.0));
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -2,15 +2,14 @@ package schule.ngb.zm.ml;
|
|||||||
|
|
||||||
import org.junit.jupiter.api.BeforeAll;
|
import org.junit.jupiter.api.BeforeAll;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
import schule.ngb.zm.Constants;
|
||||||
import schule.ngb.zm.util.Log;
|
import schule.ngb.zm.util.Log;
|
||||||
|
import schule.ngb.zm.util.Timer;
|
||||||
|
|
||||||
import java.io.File;
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.*;
|
|
||||||
|
|
||||||
class NeuralNetworkTest {
|
class NeuralNetworkTest {
|
||||||
|
|
||||||
@BeforeAll
|
@BeforeAll
|
||||||
@@ -18,7 +17,14 @@ class NeuralNetworkTest {
|
|||||||
Log.enableGlobalDebugging();
|
Log.enableGlobalDebugging();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@BeforeAll
|
||||||
|
static void setupMatrixLibrary() {
|
||||||
|
Constants.setSeed(1001);
|
||||||
|
//MatrixFactory.matrixType = MatrixFactory.ColtMatrix.class;
|
||||||
|
MatrixFactory.matrixType = DoubleMatrix.class;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*@Test
|
||||||
void readWrite() {
|
void readWrite() {
|
||||||
// XOR Dataset
|
// XOR Dataset
|
||||||
NeuralNetwork net = new NeuralNetwork(2, 4, 1);
|
NeuralNetwork net = new NeuralNetwork(2, 4, 1);
|
||||||
@@ -53,7 +59,7 @@ class NeuralNetworkTest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
assertArrayEquals(net.predict(inputs), net2.predict(inputs));
|
assertArrayEquals(net.predict(inputs), net2.predict(inputs));
|
||||||
}
|
}*/
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void learnXor() {
|
void learnXor() {
|
||||||
@@ -78,14 +84,14 @@ class NeuralNetworkTest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// calculate predictions
|
// calculate predictions
|
||||||
double[][] predictions = net.predict(inputs);
|
MLMatrix predictions = net.predict(inputs);
|
||||||
for( int i = 0; i < 4; i++ ) {
|
for( int i = 0; i < 4; i++ ) {
|
||||||
int parsed_pred = predictions[i][0] < 0.5 ? 0 : 1;
|
int parsed_pred = predictions.get(i, 0) < 0.5 ? 0 : 1;
|
||||||
|
|
||||||
System.out.printf(
|
System.out.printf(
|
||||||
"{%.0f, %.0f} = %.4f (%d) -> %s\n",
|
"{%.0f, %.0f} = %.4f (%d) -> %s\n",
|
||||||
inputs[i][0], inputs[i][1],
|
inputs[i][0], inputs[i][1],
|
||||||
predictions[i][0],
|
predictions.get(i, 0),
|
||||||
parsed_pred,
|
parsed_pred,
|
||||||
parsed_pred == outputs[i][0] ? "correct" : "miss"
|
parsed_pred == outputs[i][0] ? "correct" : "miss"
|
||||||
);
|
);
|
||||||
@@ -109,12 +115,16 @@ class NeuralNetworkTest {
|
|||||||
for( int i = 0; i < trainingData.size(); i++ ) {
|
for( int i = 0; i < trainingData.size(); i++ ) {
|
||||||
inputs[i][0] = trainingData.get(i).a;
|
inputs[i][0] = trainingData.get(i).a;
|
||||||
inputs[i][1] = trainingData.get(i).b;
|
inputs[i][1] = trainingData.get(i).b;
|
||||||
outputs[i][0] = trainingData.get(i).result;
|
outputs[i][0] = trainingData.get(i).getResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Timer timer = new Timer();
|
||||||
|
|
||||||
System.out.println("Training the neural net to learn "+OPERATION+"...");
|
System.out.println("Training the neural net to learn "+OPERATION+"...");
|
||||||
|
timer.start();
|
||||||
net.train(inputs, outputs, TRAINING_CYCLES);
|
net.train(inputs, outputs, TRAINING_CYCLES);
|
||||||
System.out.println(" finished training");
|
timer.stop();
|
||||||
|
System.out.println(" finished training (" + timer.getMillis() + "ms)");
|
||||||
|
|
||||||
for( int i = 1; i <= net.getLayerCount(); i++ ) {
|
for( int i = 1; i <= net.getLayerCount(); i++ ) {
|
||||||
System.out.println("Layer " +i + " weights");
|
System.out.println("Layer " +i + " weights");
|
||||||
@@ -136,19 +146,18 @@ class NeuralNetworkTest {
|
|||||||
System.out.printf(
|
System.out.printf(
|
||||||
"Prediction on data (%.2f, %.2f) was %.4f, expected %.2f (of by %.4f)\n",
|
"Prediction on data (%.2f, %.2f) was %.4f, expected %.2f (of by %.4f)\n",
|
||||||
data.a, data.b,
|
data.a, data.b,
|
||||||
net.getOutput()[0][0],
|
net.getOutput().get(0, 0),
|
||||||
data.result,
|
data.getResult(),
|
||||||
net.getOutput()[0][0] - data.result
|
net.getOutput().get(0, 0) - data.getResult()
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
private List<TestData> createTrainingSet( int trainingSetSize, CalcType operation ) {
|
private List<TestData> createTrainingSet( int trainingSetSize, CalcType operation ) {
|
||||||
Random random = new Random();
|
|
||||||
List<TestData> tuples = new ArrayList<>();
|
List<TestData> tuples = new ArrayList<>();
|
||||||
|
|
||||||
for( int i = 0; i < trainingSetSize; i++ ) {
|
for( int i = 0; i < trainingSetSize; i++ ) {
|
||||||
double s1 = random.nextDouble() * 0.5;
|
double s1 = Constants.random() * 0.5;
|
||||||
double s2 = random.nextDouble() * 0.5;
|
double s2 = Constants.random() * 0.5;
|
||||||
|
|
||||||
switch( operation ) {
|
switch( operation ) {
|
||||||
case ADD:
|
case ADD:
|
||||||
@@ -181,7 +190,6 @@ class NeuralNetworkTest {
|
|||||||
|
|
||||||
double a;
|
double a;
|
||||||
double b;
|
double b;
|
||||||
double result;
|
|
||||||
CalcType type;
|
CalcType type;
|
||||||
|
|
||||||
TestData( double a, double b ) {
|
TestData( double a, double b ) {
|
||||||
@@ -189,6 +197,8 @@ class NeuralNetworkTest {
|
|||||||
this.b = b;
|
this.b = b;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
abstract double getResult();
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private static final class AddData extends TestData {
|
private static final class AddData extends TestData {
|
||||||
@@ -197,7 +207,9 @@ class NeuralNetworkTest {
|
|||||||
|
|
||||||
public AddData( double a, double b ) {
|
public AddData( double a, double b ) {
|
||||||
super(a, b);
|
super(a, b);
|
||||||
result = a + b;
|
}
|
||||||
|
double getResult() {
|
||||||
|
return a+b;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -208,7 +220,9 @@ class NeuralNetworkTest {
|
|||||||
|
|
||||||
public SubData( double a, double b ) {
|
public SubData( double a, double b ) {
|
||||||
super(a, b);
|
super(a, b);
|
||||||
result = a - b;
|
}
|
||||||
|
double getResult() {
|
||||||
|
return a-b;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -219,7 +233,9 @@ class NeuralNetworkTest {
|
|||||||
|
|
||||||
public MulData( double a, double b ) {
|
public MulData( double a, double b ) {
|
||||||
super(a, b);
|
super(a, b);
|
||||||
result = a * b;
|
}
|
||||||
|
double getResult() {
|
||||||
|
return a*b;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -233,7 +249,9 @@ class NeuralNetworkTest {
|
|||||||
if( b == 0.0 ) {
|
if( b == 0.0 ) {
|
||||||
b = .1;
|
b = .1;
|
||||||
}
|
}
|
||||||
result = a / b;
|
}
|
||||||
|
double getResult() {
|
||||||
|
return a/b;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -244,7 +262,12 @@ class NeuralNetworkTest {
|
|||||||
|
|
||||||
public ModData( double b, double a ) {
|
public ModData( double b, double a ) {
|
||||||
super(b, a);
|
super(b, a);
|
||||||
result = a % b;
|
if( b == 0.0 ) {
|
||||||
|
b = .1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
double getResult() {
|
||||||
|
return a%b;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user