mirror of
https://github.com/jneug/zeichenmaschine.git
synced 2026-04-14 14:43:33 +02:00
Tests zum Laden und SPeichern
This commit is contained in:
@@ -1,22 +1,103 @@
|
|||||||
package schule.ngb.zm.ml;
|
package schule.ngb.zm.ml;
|
||||||
|
|
||||||
|
import org.junit.jupiter.api.BeforeAll;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import schule.ngb.zm.util.Log;
|
import schule.ngb.zm.util.Log;
|
||||||
|
|
||||||
|
import java.io.File;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
class NeuralNetworkTest {
|
class NeuralNetworkTest {
|
||||||
|
|
||||||
|
@BeforeAll
|
||||||
|
static void enableDebugging() {
|
||||||
|
Log.enableGlobalDebugging();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void readWrite() {
|
||||||
|
// XOR Dataset
|
||||||
|
NeuralNetwork net = new NeuralNetwork(2, 4, 1);
|
||||||
|
double[][] inputs = new double[][]{
|
||||||
|
{0, 0}, {0, 1}, {1, 0}, {1, 1}
|
||||||
|
};
|
||||||
|
double[][] outputs = new double[][]{
|
||||||
|
{0}, {1}, {1}, {0}
|
||||||
|
};
|
||||||
|
|
||||||
|
System.out.println("Training the neural net to learn XOR...");
|
||||||
|
net.train(inputs, outputs, 10000);
|
||||||
|
System.out.println(" finished training");
|
||||||
|
|
||||||
|
NeuralNetwork.saveToFile("./ml-test.txt", net);
|
||||||
|
assertTrue(new File("./ml-test.txt").isFile());
|
||||||
|
|
||||||
|
NeuralNetwork net2 = NeuralNetwork.loadFromFile("./ml-test.txt");
|
||||||
|
assertEquals(net.getLayerCount(), net2.getLayerCount());
|
||||||
|
for( int l = 0; l < net2.getLayerCount(); l++ ) {
|
||||||
|
NeuronLayer layer = net.getLayer(l+1);
|
||||||
|
NeuronLayer layer2 = net2.getLayer(l+1);
|
||||||
|
|
||||||
|
for( int i = 0; i < layer.getInputCount(); i++ ) {
|
||||||
|
for( int j = 0; j < layer.getNeuronCount(); j++ ) {
|
||||||
|
assertEquals(layer.weights.coefficients[i][j], layer2.weights.coefficients[i][j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for( int j = 0; j < layer.getNeuronCount(); j++ ) {
|
||||||
|
assertEquals(layer.biases[j], layer2.biases[j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assertArrayEquals(net.predict(inputs), net2.predict(inputs));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void learnXor() {
|
||||||
|
int TRAINING_CYCLES = 40000;
|
||||||
|
|
||||||
|
NeuralNetwork net = new NeuralNetwork(2, 4, 1);
|
||||||
|
|
||||||
|
double[][] inputs = new double[][]{
|
||||||
|
{0, 0}, {0, 1}, {1, 0}, {1, 1}
|
||||||
|
};
|
||||||
|
double[][] outputs = new double[][]{
|
||||||
|
{0}, {1}, {1}, {0}
|
||||||
|
};
|
||||||
|
|
||||||
|
System.out.println("Training the neural net to learn XOR...");
|
||||||
|
net.train(inputs, outputs, TRAINING_CYCLES);
|
||||||
|
System.out.println(" finished training");
|
||||||
|
|
||||||
|
for( int i = 1; i <= net.getLayerCount(); i++ ) {
|
||||||
|
System.out.println("Layer " +i + " weights");
|
||||||
|
System.out.println(net.getLayer(i));
|
||||||
|
}
|
||||||
|
|
||||||
|
// calculate predictions
|
||||||
|
double[][] predictions = net.predict(inputs);
|
||||||
|
for( int i = 0; i < 4; i++ ) {
|
||||||
|
int parsed_pred = predictions[i][0] < 0.5 ? 0 : 1;
|
||||||
|
|
||||||
|
System.out.printf(
|
||||||
|
"{%.0f, %.0f} = %.4f (%d) -> %s\n",
|
||||||
|
inputs[i][0], inputs[i][1],
|
||||||
|
predictions[i][0],
|
||||||
|
parsed_pred,
|
||||||
|
parsed_pred == outputs[i][0] ? "correct" : "miss"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void learnCalc() {
|
void learnCalc() {
|
||||||
Log.enableGlobalDebugging();
|
|
||||||
|
|
||||||
int INPUT_SIZE = 50;
|
int INPUT_SIZE = 50;
|
||||||
int PREDICT_SIZE = 4;
|
int PREDICT_SIZE = 4;
|
||||||
int TRAINING_CYCLES = 40000;
|
int TRAINING_CYCLES = 40000;
|
||||||
CalcType OPERATION = CalcType.SUB;
|
CalcType OPERATION = CalcType.ADD;
|
||||||
|
|
||||||
// Create neural network with layer1: 4 neurones, layer2: 1 neuron
|
// Create neural network with layer1: 4 neurones, layer2: 1 neuron
|
||||||
NeuralNetwork net = new NeuralNetwork(2, 8, 4, 1);
|
NeuralNetwork net = new NeuralNetwork(2, 8, 4, 1);
|
||||||
@@ -37,7 +118,7 @@ class NeuralNetworkTest {
|
|||||||
|
|
||||||
for( int i = 1; i <= net.getLayerCount(); i++ ) {
|
for( int i = 1; i <= net.getLayerCount(); i++ ) {
|
||||||
System.out.println("Layer " +i + " weights");
|
System.out.println("Layer " +i + " weights");
|
||||||
System.out.println(net.getLayer(i).weights);
|
System.out.println(net.getLayer(i));
|
||||||
}
|
}
|
||||||
|
|
||||||
// calculate the predictions on unknown data
|
// calculate the predictions on unknown data
|
||||||
|
|||||||
Reference in New Issue
Block a user