feed_forward: Vec -> Vec

This commit is contained in:
sparshg 2022-10-10 00:03:00 +05:30
parent 80fb255c36
commit 44e2cb36f4
2 changed files with 5 additions and 5 deletions

View File

@ -17,8 +17,8 @@ async fn main() {
};
set_camera(&cam);
let mut world = World::new();
let nn = NN::new(vec![2, 3, 2]);
nn.feed_forward(vec![2., 3.]);
let nn = NN::new(vec![2, 3, 3]);
println!("{:?}", nn.feed_forward(vec![2., 3.]));
loop {
// clear_background(BLACK);
// if !world.over {

View File

@ -34,14 +34,14 @@ impl NN {
* (2. / last as f32).sqrt()
})
.collect(),
activ_func: ActivationFunc::ReLU,
}
}
pub fn feed_forward(&self, inputs: Vec<f32>) {
pub fn feed_forward(&self, inputs: Vec<f32>) -> Vec<f32> {
let mut y = DMatrix::from_vec(inputs.len(), 1, inputs);
for i in 0..self.config.len() - 1 {
println!("{} {}", y, self.weights[i]);
y = (&self.weights[i] * y.insert_row(self.config[i] - 1, 1.)).map(|x| {
match self.activ_func {
ActivationFunc::ReLU => x.max(0.),
@ -50,6 +50,6 @@ impl NN {
}
});
}
println!("{}", y);
y.column(0).data.into_slice().to_vec()
}
}