diff --git a/src/main/java/schule/ngb/zm/ml/NeuralNetwork.java b/src/main/java/schule/ngb/zm/ml/NeuralNetwork.java index 02e4935..b6464e5 100644 --- a/src/main/java/schule/ngb/zm/ml/NeuralNetwork.java +++ b/src/main/java/schule/ngb/zm/ml/NeuralNetwork.java @@ -1,9 +1,110 @@ package schule.ngb.zm.ml; import schule.ngb.zm.util.Log; +import schule.ngb.zm.util.ResourceStreamProvider; + +import java.io.*; +import java.util.LinkedList; +import java.util.List; +import java.util.NoSuchElementException; public class NeuralNetwork { + public static void saveToFile( String source, NeuralNetwork network ) { + try( + Writer writer = ResourceStreamProvider.getWriter(source); + PrintWriter out = new PrintWriter(writer) + ) { + for( NeuronLayer layer: network.layers ) { + out.print(layer.getNeuronCount()); + out.print(' '); + out.print(layer.getInputCount()); + out.println(); + + for( int i = 0; i < layer.getInputCount(); i++ ) { + for( int j = 0; j < layer.getNeuronCount(); j++ ) { + out.print(layer.weights.coefficients[i][j]); + out.print(' '); + } + out.println(); + } + for( int j = 0; j < layer.getNeuronCount(); j++ ) { + out.print(layer.biases[j]); + out.print(' '); + } + out.println(); + } + out.flush(); + } catch( IOException ex ) { + LOG.warn(ex, ""); + } + } + + public static NeuralNetwork loadFromFile( String source ) { + try( + Reader reader = ResourceStreamProvider.getReader(source); + BufferedReader in = new BufferedReader(reader) + ) { + List layers = new LinkedList<>(); + String line; + while( (line = in.readLine()) != null ) { + String[] split = line.split(" "); + int neurons = Integer.parseInt(split[0]); + int inputs = Integer.parseInt(split[1]); + + NeuronLayer layer = new NeuronLayer(neurons, inputs); + for( int i = 0; i < inputs; i++ ) { + split = in.readLine().split(" "); + for( int j = 0; j < neurons; j++ ) { + layer.weights.coefficients[i][j] = Double.parseDouble(split[j]); + } + } + // Load Biases + split = in.readLine().split(" "); + for( int j = 0; j < neurons; j++ ) { + layer.biases[j] = Double.parseDouble(split[j]); + } + + layers.add(layer); + } + + return new NeuralNetwork(layers); + } catch( IOException | NoSuchElementException ex ) { + LOG.warn(ex, "Could not load neural network from source <%s>", source); + } + return null; + } + + /*public static NeuralNetwork loadFromFile( String source ) { + try( + InputStream stream = ResourceStreamProvider.getInputStream(source); + Scanner in = new Scanner(stream) + ) { + List layers = new LinkedList<>(); + while( in.hasNext() ) { + int neurons = in.nextInt(); + int inputs = in.nextInt(); + + NeuronLayer layer = new NeuronLayer(neurons, inputs); + for( int i = 0; i < inputs; i++ ) { + for( int j = 0; j < neurons; j++ ) { + layer.weights.coefficients[i][j] = in.nextDouble(); + } + } + for( int j = 0; j < neurons; j++ ) { + layer.biases[j] = in.nextDouble(); + } + + layers.add(layer); + } + + return new NeuralNetwork(layers); + } catch( IOException | NoSuchElementException ex ) { + LOG.warn(ex, "Could not load neural network from source <%s>", source); + } + return null; + }*/ + private NeuronLayer[] layers; private double[][] output; @@ -18,27 +119,36 @@ public class NeuralNetwork { this(new NeuronLayer(layer1, inputs), new NeuronLayer(layer2, layer1), new NeuronLayer(outputs, layer2)); } - public NeuralNetwork( NeuronLayer layer1, NeuronLayer layer2 ) { - this.layers = new NeuronLayer[2]; - this.layers[0] = layer1; - this.layers[1] = layer2; - layer1.connect(null, layer2); - layer2.connect(layer1, null); + public NeuralNetwork( int inputs, int layer1, int layer2, int layer3, int outputs ) { + this(new NeuronLayer(layer1, inputs), new NeuronLayer(layer2, layer1), new NeuronLayer(layer3, layer2), new NeuronLayer(outputs, layer3)); } - public NeuralNetwork( NeuronLayer layer1, NeuronLayer layer2, NeuronLayer layer3 ) { - this.layers = new NeuronLayer[3]; - this.layers[0] = layer1; - this.layers[1] = layer2; - this.layers[2] = layer3; - layer1.connect(null, layer2); - layer2.connect(layer1, layer3); - layer3.connect(layer2, null); + public NeuralNetwork( List layers ) { + this.layers = new NeuronLayer[layers.size()]; + for( int i = 0; i < layers.size(); i++ ) { + this.layers[i] = layers.get(i); + if( i > 0 ) { + this.layers[i-1].setNextLayer(this.layers[i]); + } + } + } + + public NeuralNetwork( NeuronLayer... layers ) { + this.layers = new NeuronLayer[layers.length]; + for( int i = 0; i < layers.length; i++ ) { + this.layers[i] = layers[i]; + if( i > 0 ) { + this.layers[i-1].setNextLayer(this.layers[i]); + } + } } public int getLayerCount() { return layers.length; } + public NeuronLayer[] getLayers() { + return layers; + } public NeuronLayer getLayer( int i ) { return layers[i - 1]; @@ -63,7 +173,7 @@ public class NeuralNetwork { } public void learn( double[][] expected ) { - layers[layers.length-1].backprop(expected, learningRate); + layers[layers.length - 1].backprop(expected, learningRate); } public void train( double[][] inputs, double[][] expected, int iterations/*, double minChange, int timeout */ ) { diff --git a/src/main/java/schule/ngb/zm/util/FileLoader.java b/src/main/java/schule/ngb/zm/util/FileLoader.java index 3ffa18e..0d680bd 100644 --- a/src/main/java/schule/ngb/zm/util/FileLoader.java +++ b/src/main/java/schule/ngb/zm/util/FileLoader.java @@ -1,6 +1,6 @@ package schule.ngb.zm.util; -import java.io.*; +import java.io.IOException; import java.net.URISyntaxException; import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; @@ -8,7 +8,6 @@ import java.nio.file.Files; import java.nio.file.Paths; import java.util.Arrays; import java.util.Collections; -import java.util.LinkedList; import java.util.List; public final class FileLoader { @@ -61,10 +60,10 @@ public final class FileLoader { .lines(Paths.get(ResourceStreamProvider.getResourceURL(source).toURI()), charset) .skip(n) .map( - (line) -> Arrays + ( line ) -> Arrays .stream(line.split(Character.toString(separator))) .mapToDouble( - (value) -> { + ( value ) -> { try { return Double.parseDouble(value); } catch( NumberFormatException nfe ) { diff --git a/src/main/java/schule/ngb/zm/util/FontLoader.java b/src/main/java/schule/ngb/zm/util/FontLoader.java index a71dbe0..7b85553 100644 --- a/src/main/java/schule/ngb/zm/util/FontLoader.java +++ b/src/main/java/schule/ngb/zm/util/FontLoader.java @@ -34,7 +34,7 @@ public class FontLoader { } // Load userfonts - try( InputStream in = ResourceStreamProvider.getResourceStream(source) ) { + try( InputStream in = ResourceStreamProvider.getInputStream(source) ) { font = Font.createFont(Font.TRUETYPE_FONT, in).deriveFont(Font.PLAIN); if( font != null ) { diff --git a/src/main/java/schule/ngb/zm/util/ImageLoader.java b/src/main/java/schule/ngb/zm/util/ImageLoader.java index 07010a8..cef8257 100644 --- a/src/main/java/schule/ngb/zm/util/ImageLoader.java +++ b/src/main/java/schule/ngb/zm/util/ImageLoader.java @@ -13,9 +13,7 @@ import java.io.IOException; import java.io.InputStream; import java.lang.ref.SoftReference; import java.util.Map; -import java.util.Objects; import java.util.concurrent.ConcurrentHashMap; -import java.util.logging.Logger; public final class ImageLoader { @@ -29,7 +27,7 @@ public final class ImageLoader { * Lädt ein Bild von der angegebenen Quelle {@code source}. *

* Die Bilddatei wird nach den Regeln von - * {@link ResourceStreamProvider#getResourceStream(String)} gesucht und + * {@link ResourceStreamProvider#getInputStream(String)} gesucht und * geöffnet. Tritt dabei ein Fehler auf oder konnte keine passende Datei * gefunden werden, wird {@code null} zurückgegeben. *

@@ -70,7 +68,7 @@ public final class ImageLoader { } BufferedImage img = null; - try( InputStream in = ResourceStreamProvider.getResourceStream(source) ) { + try( InputStream in = ResourceStreamProvider.getInputStream(source) ) { //URL url = ResourceStreamProvider.getResourceURL(source); //BufferedImage img = ImageIO.read(url); diff --git a/src/main/java/schule/ngb/zm/util/ResourceStreamProvider.java b/src/main/java/schule/ngb/zm/util/ResourceStreamProvider.java index cf231f9..fe3024d 100644 --- a/src/main/java/schule/ngb/zm/util/ResourceStreamProvider.java +++ b/src/main/java/schule/ngb/zm/util/ResourceStreamProvider.java @@ -3,10 +3,9 @@ package schule.ngb.zm.util; import schule.ngb.zm.Zeichenmaschine; import java.io.*; +import java.net.URISyntaxException; import java.net.URL; -import java.util.logging.Level; -import java.util.logging.LogManager; -import java.util.logging.Logger; +import java.util.stream.StreamSupport; /** * Helferklasse, um {@link InputStream}s für Resourcen zu erhalten. @@ -50,15 +49,11 @@ public class ResourceStreamProvider { * einer bestehenden Resource oder falls * keine passende Resource gefunden wurde. */ - public static InputStream getResourceStream( String source ) throws NullPointerException, IllegalArgumentException, IOException { - if( source == null ) { - throw new NullPointerException("Resource source may not be null"); - } - if( source.length() == 0 ) { - throw new IllegalArgumentException("Resource source may not be empty."); - } + public static InputStream getInputStream( String source ) throws NullPointerException, IllegalArgumentException, IOException { + Validator.requireNotNull(source, "Resource source may not be null"); + Validator.requireNotEmpty(source, "Resource source may not be empty."); - InputStream in; + InputStream in = null; // See if source is a readable file File file = new File(source); @@ -72,7 +67,9 @@ public class ResourceStreamProvider { } // File does not exist, try other means // load ressource relative to .class-file - in = Zeichenmaschine.class.getResourceAsStream(source); + if( in == null ) { + in = Zeichenmaschine.class.getResourceAsStream(source); + } // relative to ClassLoader if( in == null ) { @@ -89,6 +86,16 @@ public class ResourceStreamProvider { return in; } + public static InputStream getInputStream( File file ) throws IOException { + Validator.requireNotNull(file, "Provided file can't be null."); + return new FileInputStream(file); + } + + public static InputStream getInputStream( URL url ) throws IOException { + Validator.requireNotNull(url, "Provided URL can't be null."); + return url.openStream(); + } + /** * Ermittelt zur angegebenen Quelle einen passenden {@link URL} (Unified * Resource Locator). Eine passende Datei-Resource wird wie folgt @@ -119,12 +126,8 @@ public class ResourceStreamProvider { * einer bestehenden Resource. */ public static URL getResourceURL( String source ) throws NullPointerException, IllegalArgumentException, IOException { - if( source == null ) { - throw new NullPointerException("Resource source may not be null"); - } - if( source.length() == 0 ) { - throw new IllegalArgumentException("Resource source may not be empty."); - } + Validator.requireNotNull(source, "Resource source may not be null"); + Validator.requireNotEmpty(source, "Resource source may not be empty."); File file = new File(source); if( file.isFile() ) { @@ -146,20 +149,24 @@ public class ResourceStreamProvider { return new URL(source); } - public static InputStream getResourceStream( File file ) throws FileNotFoundException, SecurityException { - if( file == null ) { - throw new NullPointerException("Provided file can't be null."); - } + public static OutputStream getOutputStream( String source ) throws IOException { + Validator.requireNotNull(source, "Resource source may not be null"); + Validator.requireNotEmpty(source, "Resource source may not be empty."); - return new FileInputStream(file); + return getOutputStream(new File(source)); } - public static InputStream getResourceStream( URL url ) throws IOException { - if( url == null ) { - throw new NullPointerException("Provided URL can't be null."); - } + public static OutputStream getOutputStream( File file ) throws IOException { + Validator.requireNotNull(file, "Provided file can't be null."); + return new FileOutputStream(file); + } - return url.openStream(); + public static Reader getReader( String source ) throws IOException { + return new InputStreamReader(getInputStream(source)); + } + + public static Writer getWriter( String source ) throws IOException { + return new OutputStreamWriter(getOutputStream(source)); } private ResourceStreamProvider() {