diff --git a/src/main.rs b/src/main.rs index f15b63b..5506004 100644 --- a/src/main.rs +++ b/src/main.rs @@ -17,8 +17,8 @@ async fn main() { }; set_camera(&cam); let mut world = World::new(); - let nn = NN::new(vec![2, 3, 2]); - nn.feed_forward(vec![2., 3.]); + let nn = NN::new(vec![2, 3, 3]); + println!("{:?}", nn.feed_forward(vec![2., 3.])); loop { // clear_background(BLACK); // if !world.over { diff --git a/src/nn.rs b/src/nn.rs index 2be989b..55f094b 100644 --- a/src/nn.rs +++ b/src/nn.rs @@ -34,14 +34,14 @@ impl NN { * (2. / last as f32).sqrt() }) .collect(), + activ_func: ActivationFunc::ReLU, } } - pub fn feed_forward(&self, inputs: Vec) { + pub fn feed_forward(&self, inputs: Vec) -> Vec { let mut y = DMatrix::from_vec(inputs.len(), 1, inputs); for i in 0..self.config.len() - 1 { - println!("{} {}", y, self.weights[i]); y = (&self.weights[i] * y.insert_row(self.config[i] - 1, 1.)).map(|x| { match self.activ_func { ActivationFunc::ReLU => x.max(0.), @@ -50,6 +50,6 @@ impl NN { } }); } - println!("{}", y); + y.column(0).data.into_slice().to_vec() } }