From 44e2cb36f480096b5bfab0b1f264e17606970e94 Mon Sep 17 00:00:00 2001 From: sparshg <43041139+sparshg@users.noreply.github.com> Date: Mon, 10 Oct 2022 00:03:00 +0530 Subject: [PATCH] feed_forward: Vec -> Vec --- src/main.rs | 4 ++-- src/nn.rs | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/main.rs b/src/main.rs index f15b63b..5506004 100644 --- a/src/main.rs +++ b/src/main.rs @@ -17,8 +17,8 @@ async fn main() { }; set_camera(&cam); let mut world = World::new(); - let nn = NN::new(vec![2, 3, 2]); - nn.feed_forward(vec![2., 3.]); + let nn = NN::new(vec![2, 3, 3]); + println!("{:?}", nn.feed_forward(vec![2., 3.])); loop { // clear_background(BLACK); // if !world.over { diff --git a/src/nn.rs b/src/nn.rs index 2be989b..55f094b 100644 --- a/src/nn.rs +++ b/src/nn.rs @@ -34,14 +34,14 @@ impl NN { * (2. / last as f32).sqrt() }) .collect(), + activ_func: ActivationFunc::ReLU, } } - pub fn feed_forward(&self, inputs: Vec) { + pub fn feed_forward(&self, inputs: Vec) -> Vec { let mut y = DMatrix::from_vec(inputs.len(), 1, inputs); for i in 0..self.config.len() - 1 { - println!("{} {}", y, self.weights[i]); y = (&self.weights[i] * y.insert_row(self.config[i] - 1, 1.)).map(|x| { match self.activ_func { ActivationFunc::ReLU => x.max(0.), @@ -50,6 +50,6 @@ impl NN { } }); } - println!("{}", y); + y.column(0).data.into_slice().to_vec() } }