Laden und speichern von Netzen ermöglicht

This commit is contained in:
ngb
2022-07-15 19:42:13 +02:00
parent d5abd4ef68
commit b24eec5063
5 changed files with 166 additions and 52 deletions

View File

@@ -1,9 +1,110 @@
package schule.ngb.zm.ml;
import schule.ngb.zm.util.Log;
import schule.ngb.zm.util.ResourceStreamProvider;
import java.io.*;
import java.util.LinkedList;
import java.util.List;
import java.util.NoSuchElementException;
public class NeuralNetwork {
public static void saveToFile( String source, NeuralNetwork network ) {
try(
Writer writer = ResourceStreamProvider.getWriter(source);
PrintWriter out = new PrintWriter(writer)
) {
for( NeuronLayer layer: network.layers ) {
out.print(layer.getNeuronCount());
out.print(' ');
out.print(layer.getInputCount());
out.println();
for( int i = 0; i < layer.getInputCount(); i++ ) {
for( int j = 0; j < layer.getNeuronCount(); j++ ) {
out.print(layer.weights.coefficients[i][j]);
out.print(' ');
}
out.println();
}
for( int j = 0; j < layer.getNeuronCount(); j++ ) {
out.print(layer.biases[j]);
out.print(' ');
}
out.println();
}
out.flush();
} catch( IOException ex ) {
LOG.warn(ex, "");
}
}
public static NeuralNetwork loadFromFile( String source ) {
try(
Reader reader = ResourceStreamProvider.getReader(source);
BufferedReader in = new BufferedReader(reader)
) {
List<NeuronLayer> layers = new LinkedList<>();
String line;
while( (line = in.readLine()) != null ) {
String[] split = line.split(" ");
int neurons = Integer.parseInt(split[0]);
int inputs = Integer.parseInt(split[1]);
NeuronLayer layer = new NeuronLayer(neurons, inputs);
for( int i = 0; i < inputs; i++ ) {
split = in.readLine().split(" ");
for( int j = 0; j < neurons; j++ ) {
layer.weights.coefficients[i][j] = Double.parseDouble(split[j]);
}
}
// Load Biases
split = in.readLine().split(" ");
for( int j = 0; j < neurons; j++ ) {
layer.biases[j] = Double.parseDouble(split[j]);
}
layers.add(layer);
}
return new NeuralNetwork(layers);
} catch( IOException | NoSuchElementException ex ) {
LOG.warn(ex, "Could not load neural network from source <%s>", source);
}
return null;
}
/*public static NeuralNetwork loadFromFile( String source ) {
try(
InputStream stream = ResourceStreamProvider.getInputStream(source);
Scanner in = new Scanner(stream)
) {
List<NeuronLayer> layers = new LinkedList<>();
while( in.hasNext() ) {
int neurons = in.nextInt();
int inputs = in.nextInt();
NeuronLayer layer = new NeuronLayer(neurons, inputs);
for( int i = 0; i < inputs; i++ ) {
for( int j = 0; j < neurons; j++ ) {
layer.weights.coefficients[i][j] = in.nextDouble();
}
}
for( int j = 0; j < neurons; j++ ) {
layer.biases[j] = in.nextDouble();
}
layers.add(layer);
}
return new NeuralNetwork(layers);
} catch( IOException | NoSuchElementException ex ) {
LOG.warn(ex, "Could not load neural network from source <%s>", source);
}
return null;
}*/
private NeuronLayer[] layers;
private double[][] output;
@@ -18,27 +119,36 @@ public class NeuralNetwork {
this(new NeuronLayer(layer1, inputs), new NeuronLayer(layer2, layer1), new NeuronLayer(outputs, layer2));
}
public NeuralNetwork( NeuronLayer layer1, NeuronLayer layer2 ) {
this.layers = new NeuronLayer[2];
this.layers[0] = layer1;
this.layers[1] = layer2;
layer1.connect(null, layer2);
layer2.connect(layer1, null);
public NeuralNetwork( int inputs, int layer1, int layer2, int layer3, int outputs ) {
this(new NeuronLayer(layer1, inputs), new NeuronLayer(layer2, layer1), new NeuronLayer(layer3, layer2), new NeuronLayer(outputs, layer3));
}
public NeuralNetwork( NeuronLayer layer1, NeuronLayer layer2, NeuronLayer layer3 ) {
this.layers = new NeuronLayer[3];
this.layers[0] = layer1;
this.layers[1] = layer2;
this.layers[2] = layer3;
layer1.connect(null, layer2);
layer2.connect(layer1, layer3);
layer3.connect(layer2, null);
public NeuralNetwork( List<NeuronLayer> layers ) {
this.layers = new NeuronLayer[layers.size()];
for( int i = 0; i < layers.size(); i++ ) {
this.layers[i] = layers.get(i);
if( i > 0 ) {
this.layers[i-1].setNextLayer(this.layers[i]);
}
}
}
public NeuralNetwork( NeuronLayer... layers ) {
this.layers = new NeuronLayer[layers.length];
for( int i = 0; i < layers.length; i++ ) {
this.layers[i] = layers[i];
if( i > 0 ) {
this.layers[i-1].setNextLayer(this.layers[i]);
}
}
}
public int getLayerCount() {
return layers.length;
}
public NeuronLayer[] getLayers() {
return layers;
}
public NeuronLayer getLayer( int i ) {
return layers[i - 1];
@@ -63,7 +173,7 @@ public class NeuralNetwork {
}
public void learn( double[][] expected ) {
layers[layers.length-1].backprop(expected, learningRate);
layers[layers.length - 1].backprop(expected, learningRate);
}
public void train( double[][] inputs, double[][] expected, int iterations/*, double minChange, int timeout */ ) {

View File

@@ -1,6 +1,6 @@
package schule.ngb.zm.util;
import java.io.*;
import java.io.IOException;
import java.net.URISyntaxException;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
@@ -8,7 +8,6 @@ import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
public final class FileLoader {
@@ -61,10 +60,10 @@ public final class FileLoader {
.lines(Paths.get(ResourceStreamProvider.getResourceURL(source).toURI()), charset)
.skip(n)
.map(
(line) -> Arrays
( line ) -> Arrays
.stream(line.split(Character.toString(separator)))
.mapToDouble(
(value) -> {
( value ) -> {
try {
return Double.parseDouble(value);
} catch( NumberFormatException nfe ) {

View File

@@ -34,7 +34,7 @@ public class FontLoader {
}
// Load userfonts
try( InputStream in = ResourceStreamProvider.getResourceStream(source) ) {
try( InputStream in = ResourceStreamProvider.getInputStream(source) ) {
font = Font.createFont(Font.TRUETYPE_FONT, in).deriveFont(Font.PLAIN);
if( font != null ) {

View File

@@ -13,9 +13,7 @@ import java.io.IOException;
import java.io.InputStream;
import java.lang.ref.SoftReference;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.logging.Logger;
public final class ImageLoader {
@@ -29,7 +27,7 @@ public final class ImageLoader {
* Lädt ein Bild von der angegebenen Quelle {@code source}.
* <p>
* Die Bilddatei wird nach den Regeln von
* {@link ResourceStreamProvider#getResourceStream(String)} gesucht und
* {@link ResourceStreamProvider#getInputStream(String)} gesucht und
* geöffnet. Tritt dabei ein Fehler auf oder konnte keine passende Datei
* gefunden werden, wird {@code null} zurückgegeben.
* <p>
@@ -70,7 +68,7 @@ public final class ImageLoader {
}
BufferedImage img = null;
try( InputStream in = ResourceStreamProvider.getResourceStream(source) ) {
try( InputStream in = ResourceStreamProvider.getInputStream(source) ) {
//URL url = ResourceStreamProvider.getResourceURL(source);
//BufferedImage img = ImageIO.read(url);

View File

@@ -3,10 +3,9 @@ package schule.ngb.zm.util;
import schule.ngb.zm.Zeichenmaschine;
import java.io.*;
import java.net.URISyntaxException;
import java.net.URL;
import java.util.logging.Level;
import java.util.logging.LogManager;
import java.util.logging.Logger;
import java.util.stream.StreamSupport;
/**
* Helferklasse, um {@link InputStream}s für Resourcen zu erhalten.
@@ -50,15 +49,11 @@ public class ResourceStreamProvider {
* einer bestehenden Resource oder falls
* keine passende Resource gefunden wurde.
*/
public static InputStream getResourceStream( String source ) throws NullPointerException, IllegalArgumentException, IOException {
if( source == null ) {
throw new NullPointerException("Resource source may not be null");
}
if( source.length() == 0 ) {
throw new IllegalArgumentException("Resource source may not be empty.");
}
public static InputStream getInputStream( String source ) throws NullPointerException, IllegalArgumentException, IOException {
Validator.requireNotNull(source, "Resource source may not be null");
Validator.requireNotEmpty(source, "Resource source may not be empty.");
InputStream in;
InputStream in = null;
// See if source is a readable file
File file = new File(source);
@@ -72,7 +67,9 @@ public class ResourceStreamProvider {
}
// File does not exist, try other means
// load ressource relative to .class-file
if( in == null ) {
in = Zeichenmaschine.class.getResourceAsStream(source);
}
// relative to ClassLoader
if( in == null ) {
@@ -89,6 +86,16 @@ public class ResourceStreamProvider {
return in;
}
public static InputStream getInputStream( File file ) throws IOException {
Validator.requireNotNull(file, "Provided file can't be null.");
return new FileInputStream(file);
}
public static InputStream getInputStream( URL url ) throws IOException {
Validator.requireNotNull(url, "Provided URL can't be null.");
return url.openStream();
}
/**
* Ermittelt zur angegebenen Quelle einen passenden {@link URL} (<em>Unified
* Resource Locator</em>). Eine passende Datei-Resource wird wie folgt
@@ -119,12 +126,8 @@ public class ResourceStreamProvider {
* einer bestehenden Resource.
*/
public static URL getResourceURL( String source ) throws NullPointerException, IllegalArgumentException, IOException {
if( source == null ) {
throw new NullPointerException("Resource source may not be null");
}
if( source.length() == 0 ) {
throw new IllegalArgumentException("Resource source may not be empty.");
}
Validator.requireNotNull(source, "Resource source may not be null");
Validator.requireNotEmpty(source, "Resource source may not be empty.");
File file = new File(source);
if( file.isFile() ) {
@@ -146,20 +149,24 @@ public class ResourceStreamProvider {
return new URL(source);
}
public static InputStream getResourceStream( File file ) throws FileNotFoundException, SecurityException {
if( file == null ) {
throw new NullPointerException("Provided file can't be null.");
public static OutputStream getOutputStream( String source ) throws IOException {
Validator.requireNotNull(source, "Resource source may not be null");
Validator.requireNotEmpty(source, "Resource source may not be empty.");
return getOutputStream(new File(source));
}
return new FileInputStream(file);
public static OutputStream getOutputStream( File file ) throws IOException {
Validator.requireNotNull(file, "Provided file can't be null.");
return new FileOutputStream(file);
}
public static InputStream getResourceStream( URL url ) throws IOException {
if( url == null ) {
throw new NullPointerException("Provided URL can't be null.");
public static Reader getReader( String source ) throws IOException {
return new InputStreamReader(getInputStream(source));
}
return url.openStream();
public static Writer getWriter( String source ) throws IOException {
return new OutputStreamWriter(getOutputStream(source));
}
private ResourceStreamProvider() {