diff --git a/src/main/java/schule/ngb/zm/ml/NeuralNetwork.java b/src/main/java/schule/ngb/zm/ml/NeuralNetwork.java index b1cd4d6..6babdab 100644 --- a/src/main/java/schule/ngb/zm/ml/NeuralNetwork.java +++ b/src/main/java/schule/ngb/zm/ml/NeuralNetwork.java @@ -192,8 +192,11 @@ public class NeuralNetwork { return output; } + public MLMatrix predict( double[] inputs ) { + return predict(MatrixFactory.create(new double[][]{inputs})); + } + public MLMatrix predict( double[][] inputs ) { - //this.output = layers[1].apply(layers[0].apply(inputs)); return predict(MatrixFactory.create(inputs)); }