Laden und speichern von Netzen ermöglicht

This commit is contained in:
ngb
2022-07-15 19:42:13 +02:00
parent d5abd4ef68
commit b24eec5063
5 changed files with 166 additions and 52 deletions

View File

@@ -1,9 +1,110 @@
package schule.ngb.zm.ml; package schule.ngb.zm.ml;
import schule.ngb.zm.util.Log; import schule.ngb.zm.util.Log;
import schule.ngb.zm.util.ResourceStreamProvider;
import java.io.*;
import java.util.LinkedList;
import java.util.List;
import java.util.NoSuchElementException;
public class NeuralNetwork { public class NeuralNetwork {
public static void saveToFile( String source, NeuralNetwork network ) {
try(
Writer writer = ResourceStreamProvider.getWriter(source);
PrintWriter out = new PrintWriter(writer)
) {
for( NeuronLayer layer: network.layers ) {
out.print(layer.getNeuronCount());
out.print(' ');
out.print(layer.getInputCount());
out.println();
for( int i = 0; i < layer.getInputCount(); i++ ) {
for( int j = 0; j < layer.getNeuronCount(); j++ ) {
out.print(layer.weights.coefficients[i][j]);
out.print(' ');
}
out.println();
}
for( int j = 0; j < layer.getNeuronCount(); j++ ) {
out.print(layer.biases[j]);
out.print(' ');
}
out.println();
}
out.flush();
} catch( IOException ex ) {
LOG.warn(ex, "");
}
}
public static NeuralNetwork loadFromFile( String source ) {
try(
Reader reader = ResourceStreamProvider.getReader(source);
BufferedReader in = new BufferedReader(reader)
) {
List<NeuronLayer> layers = new LinkedList<>();
String line;
while( (line = in.readLine()) != null ) {
String[] split = line.split(" ");
int neurons = Integer.parseInt(split[0]);
int inputs = Integer.parseInt(split[1]);
NeuronLayer layer = new NeuronLayer(neurons, inputs);
for( int i = 0; i < inputs; i++ ) {
split = in.readLine().split(" ");
for( int j = 0; j < neurons; j++ ) {
layer.weights.coefficients[i][j] = Double.parseDouble(split[j]);
}
}
// Load Biases
split = in.readLine().split(" ");
for( int j = 0; j < neurons; j++ ) {
layer.biases[j] = Double.parseDouble(split[j]);
}
layers.add(layer);
}
return new NeuralNetwork(layers);
} catch( IOException | NoSuchElementException ex ) {
LOG.warn(ex, "Could not load neural network from source <%s>", source);
}
return null;
}
/*public static NeuralNetwork loadFromFile( String source ) {
try(
InputStream stream = ResourceStreamProvider.getInputStream(source);
Scanner in = new Scanner(stream)
) {
List<NeuronLayer> layers = new LinkedList<>();
while( in.hasNext() ) {
int neurons = in.nextInt();
int inputs = in.nextInt();
NeuronLayer layer = new NeuronLayer(neurons, inputs);
for( int i = 0; i < inputs; i++ ) {
for( int j = 0; j < neurons; j++ ) {
layer.weights.coefficients[i][j] = in.nextDouble();
}
}
for( int j = 0; j < neurons; j++ ) {
layer.biases[j] = in.nextDouble();
}
layers.add(layer);
}
return new NeuralNetwork(layers);
} catch( IOException | NoSuchElementException ex ) {
LOG.warn(ex, "Could not load neural network from source <%s>", source);
}
return null;
}*/
private NeuronLayer[] layers; private NeuronLayer[] layers;
private double[][] output; private double[][] output;
@@ -18,27 +119,36 @@ public class NeuralNetwork {
this(new NeuronLayer(layer1, inputs), new NeuronLayer(layer2, layer1), new NeuronLayer(outputs, layer2)); this(new NeuronLayer(layer1, inputs), new NeuronLayer(layer2, layer1), new NeuronLayer(outputs, layer2));
} }
public NeuralNetwork( NeuronLayer layer1, NeuronLayer layer2 ) { public NeuralNetwork( int inputs, int layer1, int layer2, int layer3, int outputs ) {
this.layers = new NeuronLayer[2]; this(new NeuronLayer(layer1, inputs), new NeuronLayer(layer2, layer1), new NeuronLayer(layer3, layer2), new NeuronLayer(outputs, layer3));
this.layers[0] = layer1;
this.layers[1] = layer2;
layer1.connect(null, layer2);
layer2.connect(layer1, null);
} }
public NeuralNetwork( NeuronLayer layer1, NeuronLayer layer2, NeuronLayer layer3 ) { public NeuralNetwork( List<NeuronLayer> layers ) {
this.layers = new NeuronLayer[3]; this.layers = new NeuronLayer[layers.size()];
this.layers[0] = layer1; for( int i = 0; i < layers.size(); i++ ) {
this.layers[1] = layer2; this.layers[i] = layers.get(i);
this.layers[2] = layer3; if( i > 0 ) {
layer1.connect(null, layer2); this.layers[i-1].setNextLayer(this.layers[i]);
layer2.connect(layer1, layer3); }
layer3.connect(layer2, null); }
}
public NeuralNetwork( NeuronLayer... layers ) {
this.layers = new NeuronLayer[layers.length];
for( int i = 0; i < layers.length; i++ ) {
this.layers[i] = layers[i];
if( i > 0 ) {
this.layers[i-1].setNextLayer(this.layers[i]);
}
}
} }
public int getLayerCount() { public int getLayerCount() {
return layers.length; return layers.length;
} }
public NeuronLayer[] getLayers() {
return layers;
}
public NeuronLayer getLayer( int i ) { public NeuronLayer getLayer( int i ) {
return layers[i - 1]; return layers[i - 1];

View File

@@ -1,6 +1,6 @@
package schule.ngb.zm.util; package schule.ngb.zm.util;
import java.io.*; import java.io.IOException;
import java.net.URISyntaxException; import java.net.URISyntaxException;
import java.nio.charset.Charset; import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
@@ -8,7 +8,6 @@ import java.nio.file.Files;
import java.nio.file.Paths; import java.nio.file.Paths;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.LinkedList;
import java.util.List; import java.util.List;
public final class FileLoader { public final class FileLoader {

View File

@@ -34,7 +34,7 @@ public class FontLoader {
} }
// Load userfonts // Load userfonts
try( InputStream in = ResourceStreamProvider.getResourceStream(source) ) { try( InputStream in = ResourceStreamProvider.getInputStream(source) ) {
font = Font.createFont(Font.TRUETYPE_FONT, in).deriveFont(Font.PLAIN); font = Font.createFont(Font.TRUETYPE_FONT, in).deriveFont(Font.PLAIN);
if( font != null ) { if( font != null ) {

View File

@@ -13,9 +13,7 @@ import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.lang.ref.SoftReference; import java.lang.ref.SoftReference;
import java.util.Map; import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.logging.Logger;
public final class ImageLoader { public final class ImageLoader {
@@ -29,7 +27,7 @@ public final class ImageLoader {
* Lädt ein Bild von der angegebenen Quelle {@code source}. * Lädt ein Bild von der angegebenen Quelle {@code source}.
* <p> * <p>
* Die Bilddatei wird nach den Regeln von * Die Bilddatei wird nach den Regeln von
* {@link ResourceStreamProvider#getResourceStream(String)} gesucht und * {@link ResourceStreamProvider#getInputStream(String)} gesucht und
* geöffnet. Tritt dabei ein Fehler auf oder konnte keine passende Datei * geöffnet. Tritt dabei ein Fehler auf oder konnte keine passende Datei
* gefunden werden, wird {@code null} zurückgegeben. * gefunden werden, wird {@code null} zurückgegeben.
* <p> * <p>
@@ -70,7 +68,7 @@ public final class ImageLoader {
} }
BufferedImage img = null; BufferedImage img = null;
try( InputStream in = ResourceStreamProvider.getResourceStream(source) ) { try( InputStream in = ResourceStreamProvider.getInputStream(source) ) {
//URL url = ResourceStreamProvider.getResourceURL(source); //URL url = ResourceStreamProvider.getResourceURL(source);
//BufferedImage img = ImageIO.read(url); //BufferedImage img = ImageIO.read(url);

View File

@@ -3,10 +3,9 @@ package schule.ngb.zm.util;
import schule.ngb.zm.Zeichenmaschine; import schule.ngb.zm.Zeichenmaschine;
import java.io.*; import java.io.*;
import java.net.URISyntaxException;
import java.net.URL; import java.net.URL;
import java.util.logging.Level; import java.util.stream.StreamSupport;
import java.util.logging.LogManager;
import java.util.logging.Logger;
/** /**
* Helferklasse, um {@link InputStream}s für Resourcen zu erhalten. * Helferklasse, um {@link InputStream}s für Resourcen zu erhalten.
@@ -50,15 +49,11 @@ public class ResourceStreamProvider {
* einer bestehenden Resource oder falls * einer bestehenden Resource oder falls
* keine passende Resource gefunden wurde. * keine passende Resource gefunden wurde.
*/ */
public static InputStream getResourceStream( String source ) throws NullPointerException, IllegalArgumentException, IOException { public static InputStream getInputStream( String source ) throws NullPointerException, IllegalArgumentException, IOException {
if( source == null ) { Validator.requireNotNull(source, "Resource source may not be null");
throw new NullPointerException("Resource source may not be null"); Validator.requireNotEmpty(source, "Resource source may not be empty.");
}
if( source.length() == 0 ) {
throw new IllegalArgumentException("Resource source may not be empty.");
}
InputStream in; InputStream in = null;
// See if source is a readable file // See if source is a readable file
File file = new File(source); File file = new File(source);
@@ -72,7 +67,9 @@ public class ResourceStreamProvider {
} }
// File does not exist, try other means // File does not exist, try other means
// load ressource relative to .class-file // load ressource relative to .class-file
if( in == null ) {
in = Zeichenmaschine.class.getResourceAsStream(source); in = Zeichenmaschine.class.getResourceAsStream(source);
}
// relative to ClassLoader // relative to ClassLoader
if( in == null ) { if( in == null ) {
@@ -89,6 +86,16 @@ public class ResourceStreamProvider {
return in; return in;
} }
public static InputStream getInputStream( File file ) throws IOException {
Validator.requireNotNull(file, "Provided file can't be null.");
return new FileInputStream(file);
}
public static InputStream getInputStream( URL url ) throws IOException {
Validator.requireNotNull(url, "Provided URL can't be null.");
return url.openStream();
}
/** /**
* Ermittelt zur angegebenen Quelle einen passenden {@link URL} (<em>Unified * Ermittelt zur angegebenen Quelle einen passenden {@link URL} (<em>Unified
* Resource Locator</em>). Eine passende Datei-Resource wird wie folgt * Resource Locator</em>). Eine passende Datei-Resource wird wie folgt
@@ -119,12 +126,8 @@ public class ResourceStreamProvider {
* einer bestehenden Resource. * einer bestehenden Resource.
*/ */
public static URL getResourceURL( String source ) throws NullPointerException, IllegalArgumentException, IOException { public static URL getResourceURL( String source ) throws NullPointerException, IllegalArgumentException, IOException {
if( source == null ) { Validator.requireNotNull(source, "Resource source may not be null");
throw new NullPointerException("Resource source may not be null"); Validator.requireNotEmpty(source, "Resource source may not be empty.");
}
if( source.length() == 0 ) {
throw new IllegalArgumentException("Resource source may not be empty.");
}
File file = new File(source); File file = new File(source);
if( file.isFile() ) { if( file.isFile() ) {
@@ -146,20 +149,24 @@ public class ResourceStreamProvider {
return new URL(source); return new URL(source);
} }
public static InputStream getResourceStream( File file ) throws FileNotFoundException, SecurityException { public static OutputStream getOutputStream( String source ) throws IOException {
if( file == null ) { Validator.requireNotNull(source, "Resource source may not be null");
throw new NullPointerException("Provided file can't be null."); Validator.requireNotEmpty(source, "Resource source may not be empty.");
return getOutputStream(new File(source));
} }
return new FileInputStream(file); public static OutputStream getOutputStream( File file ) throws IOException {
Validator.requireNotNull(file, "Provided file can't be null.");
return new FileOutputStream(file);
} }
public static InputStream getResourceStream( URL url ) throws IOException { public static Reader getReader( String source ) throws IOException {
if( url == null ) { return new InputStreamReader(getInputStream(source));
throw new NullPointerException("Provided URL can't be null.");
} }
return url.openStream(); public static Writer getWriter( String source ) throws IOException {
return new OutputStreamWriter(getOutputStream(source));
} }
private ResourceStreamProvider() { private ResourceStreamProvider() {