From 6040545274a8937e4a23549d335b3a5854242a8d Mon Sep 17 00:00:00 2001 From: "J. Neugebauer" Date: Fri, 15 Jul 2022 19:42:48 +0200 Subject: [PATCH] Tests zum Laden und SPeichern --- .../schule/ngb/zm/ml/NeuralNetworkTest.java | 89 ++++++++++++++++++- 1 file changed, 85 insertions(+), 4 deletions(-) diff --git a/src/test/java/schule/ngb/zm/ml/NeuralNetworkTest.java b/src/test/java/schule/ngb/zm/ml/NeuralNetworkTest.java index eb9afb1..f81fe4d 100644 --- a/src/test/java/schule/ngb/zm/ml/NeuralNetworkTest.java +++ b/src/test/java/schule/ngb/zm/ml/NeuralNetworkTest.java @@ -1,22 +1,103 @@ package schule.ngb.zm.ml; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import schule.ngb.zm.util.Log; +import java.io.File; import java.util.ArrayList; import java.util.List; import java.util.Random; +import static org.junit.jupiter.api.Assertions.*; + class NeuralNetworkTest { + @BeforeAll + static void enableDebugging() { + Log.enableGlobalDebugging(); + } + + @Test + void readWrite() { + // XOR Dataset + NeuralNetwork net = new NeuralNetwork(2, 4, 1); + double[][] inputs = new double[][]{ + {0, 0}, {0, 1}, {1, 0}, {1, 1} + }; + double[][] outputs = new double[][]{ + {0}, {1}, {1}, {0} + }; + + System.out.println("Training the neural net to learn XOR..."); + net.train(inputs, outputs, 10000); + System.out.println(" finished training"); + + NeuralNetwork.saveToFile("./ml-test.txt", net); + assertTrue(new File("./ml-test.txt").isFile()); + + NeuralNetwork net2 = NeuralNetwork.loadFromFile("./ml-test.txt"); + assertEquals(net.getLayerCount(), net2.getLayerCount()); + for( int l = 0; l < net2.getLayerCount(); l++ ) { + NeuronLayer layer = net.getLayer(l+1); + NeuronLayer layer2 = net2.getLayer(l+1); + + for( int i = 0; i < layer.getInputCount(); i++ ) { + for( int j = 0; j < layer.getNeuronCount(); j++ ) { + assertEquals(layer.weights.coefficients[i][j], layer2.weights.coefficients[i][j]); + } + } + for( int j = 0; j < layer.getNeuronCount(); j++ ) { + assertEquals(layer.biases[j], layer2.biases[j]); + } + } + + assertArrayEquals(net.predict(inputs), net2.predict(inputs)); + } + + @Test + void learnXor() { + int TRAINING_CYCLES = 40000; + + NeuralNetwork net = new NeuralNetwork(2, 4, 1); + + double[][] inputs = new double[][]{ + {0, 0}, {0, 1}, {1, 0}, {1, 1} + }; + double[][] outputs = new double[][]{ + {0}, {1}, {1}, {0} + }; + + System.out.println("Training the neural net to learn XOR..."); + net.train(inputs, outputs, TRAINING_CYCLES); + System.out.println(" finished training"); + + for( int i = 1; i <= net.getLayerCount(); i++ ) { + System.out.println("Layer " +i + " weights"); + System.out.println(net.getLayer(i)); + } + + // calculate predictions + double[][] predictions = net.predict(inputs); + for( int i = 0; i < 4; i++ ) { + int parsed_pred = predictions[i][0] < 0.5 ? 0 : 1; + + System.out.printf( + "{%.0f, %.0f} = %.4f (%d) -> %s\n", + inputs[i][0], inputs[i][1], + predictions[i][0], + parsed_pred, + parsed_pred == outputs[i][0] ? "correct" : "miss" + ); + } + } + @Test void learnCalc() { - Log.enableGlobalDebugging(); - int INPUT_SIZE = 50; int PREDICT_SIZE = 4; int TRAINING_CYCLES = 40000; - CalcType OPERATION = CalcType.SUB; + CalcType OPERATION = CalcType.ADD; // Create neural network with layer1: 4 neurones, layer2: 1 neuron NeuralNetwork net = new NeuralNetwork(2, 8, 4, 1); @@ -37,7 +118,7 @@ class NeuralNetworkTest { for( int i = 1; i <= net.getLayerCount(); i++ ) { System.out.println("Layer " +i + " weights"); - System.out.println(net.getLayer(i).weights); + System.out.println(net.getLayer(i)); } // calculate the predictions on unknown data