Tests zum Laden und SPeichern

This commit is contained in:
ngb
2022-07-15 19:42:48 +02:00
parent 7b84570d18
commit 6040545274

View File

@@ -1,22 +1,103 @@
package schule.ngb.zm.ml;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import schule.ngb.zm.util.Log;
import java.io.File;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import static org.junit.jupiter.api.Assertions.*;
class NeuralNetworkTest {
@BeforeAll
static void enableDebugging() {
Log.enableGlobalDebugging();
}
@Test
void readWrite() {
// XOR Dataset
NeuralNetwork net = new NeuralNetwork(2, 4, 1);
double[][] inputs = new double[][]{
{0, 0}, {0, 1}, {1, 0}, {1, 1}
};
double[][] outputs = new double[][]{
{0}, {1}, {1}, {0}
};
System.out.println("Training the neural net to learn XOR...");
net.train(inputs, outputs, 10000);
System.out.println(" finished training");
NeuralNetwork.saveToFile("./ml-test.txt", net);
assertTrue(new File("./ml-test.txt").isFile());
NeuralNetwork net2 = NeuralNetwork.loadFromFile("./ml-test.txt");
assertEquals(net.getLayerCount(), net2.getLayerCount());
for( int l = 0; l < net2.getLayerCount(); l++ ) {
NeuronLayer layer = net.getLayer(l+1);
NeuronLayer layer2 = net2.getLayer(l+1);
for( int i = 0; i < layer.getInputCount(); i++ ) {
for( int j = 0; j < layer.getNeuronCount(); j++ ) {
assertEquals(layer.weights.coefficients[i][j], layer2.weights.coefficients[i][j]);
}
}
for( int j = 0; j < layer.getNeuronCount(); j++ ) {
assertEquals(layer.biases[j], layer2.biases[j]);
}
}
assertArrayEquals(net.predict(inputs), net2.predict(inputs));
}
@Test
void learnXor() {
int TRAINING_CYCLES = 40000;
NeuralNetwork net = new NeuralNetwork(2, 4, 1);
double[][] inputs = new double[][]{
{0, 0}, {0, 1}, {1, 0}, {1, 1}
};
double[][] outputs = new double[][]{
{0}, {1}, {1}, {0}
};
System.out.println("Training the neural net to learn XOR...");
net.train(inputs, outputs, TRAINING_CYCLES);
System.out.println(" finished training");
for( int i = 1; i <= net.getLayerCount(); i++ ) {
System.out.println("Layer " +i + " weights");
System.out.println(net.getLayer(i));
}
// calculate predictions
double[][] predictions = net.predict(inputs);
for( int i = 0; i < 4; i++ ) {
int parsed_pred = predictions[i][0] < 0.5 ? 0 : 1;
System.out.printf(
"{%.0f, %.0f} = %.4f (%d) -> %s\n",
inputs[i][0], inputs[i][1],
predictions[i][0],
parsed_pred,
parsed_pred == outputs[i][0] ? "correct" : "miss"
);
}
}
@Test
void learnCalc() {
Log.enableGlobalDebugging();
int INPUT_SIZE = 50;
int PREDICT_SIZE = 4;
int TRAINING_CYCLES = 40000;
CalcType OPERATION = CalcType.SUB;
CalcType OPERATION = CalcType.ADD;
// Create neural network with layer1: 4 neurones, layer2: 1 neuron
NeuralNetwork net = new NeuralNetwork(2, 8, 4, 1);
@@ -37,7 +118,7 @@ class NeuralNetworkTest {
for( int i = 1; i <= net.getLayerCount(); i++ ) {
System.out.println("Layer " +i + " weights");
System.out.println(net.getLayer(i).weights);
System.out.println(net.getLayer(i));
}
// calculate the predictions on unknown data