From 3a486d47bd5c6fd3823c7d14d000072df6f53784 Mon Sep 17 00:00:00 2001 From: Braden Everson Date: Mon, 12 Feb 2024 06:30:31 -0600 Subject: [PATCH] Add more inputs to test --- src/main.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/main.rs b/src/main.rs index 978661b..098aaff 100644 --- a/src/main.rs +++ b/src/main.rs @@ -16,7 +16,7 @@ async fn main() { println!("Done Generating MNIST"); outputs = to_categorical(outputs_uncat); - for i in 0..600{ + for i in 0..3000{ inputs.push(&inputs_undyn[i]); true_outputs.push(outputs[i].clone()); } @@ -24,15 +24,15 @@ async fn main() { let mut network = Network::new(128); network.set_input(InputTypes::DENSE(784)); + network.add_layer(LayerTypes::DENSE(256, Activations::RELU, 0.001)); network.add_layer(LayerTypes::DENSE(64, Activations::RELU, 0.001)); - network.add_layer(LayerTypes::DENSE(32, Activations::RELU, 0.001)); network.add_layer(LayerTypes::DENSE(10, Activations::SOFTMAX, 0.001)); //network.set_log(false); network.compile(); - network.fit(&inputs, &true_outputs, 5, ErrorTypes::CategoricalCrossEntropy); + network.fit(&inputs, &true_outputs, 1, ErrorTypes::CategoricalCrossEntropy); for i in 0..10{ println!("predicted: {:?} \n\n actual: {:?}\n", network.predict(inputs[i]), true_outputs[i]); }