feed_forward: Vec -> Vec
This commit is contained in:
parent
80fb255c36
commit
44e2cb36f4
|
@ -17,8 +17,8 @@ async fn main() {
|
|||
};
|
||||
set_camera(&cam);
|
||||
let mut world = World::new();
|
||||
let nn = NN::new(vec![2, 3, 2]);
|
||||
nn.feed_forward(vec![2., 3.]);
|
||||
let nn = NN::new(vec![2, 3, 3]);
|
||||
println!("{:?}", nn.feed_forward(vec![2., 3.]));
|
||||
loop {
|
||||
// clear_background(BLACK);
|
||||
// if !world.over {
|
||||
|
|
|
@ -34,14 +34,14 @@ impl NN {
|
|||
* (2. / last as f32).sqrt()
|
||||
})
|
||||
.collect(),
|
||||
|
||||
activ_func: ActivationFunc::ReLU,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn feed_forward(&self, inputs: Vec<f32>) {
|
||||
pub fn feed_forward(&self, inputs: Vec<f32>) -> Vec<f32> {
|
||||
let mut y = DMatrix::from_vec(inputs.len(), 1, inputs);
|
||||
for i in 0..self.config.len() - 1 {
|
||||
println!("{} {}", y, self.weights[i]);
|
||||
y = (&self.weights[i] * y.insert_row(self.config[i] - 1, 1.)).map(|x| {
|
||||
match self.activ_func {
|
||||
ActivationFunc::ReLU => x.max(0.),
|
||||
|
@ -50,6 +50,6 @@ impl NN {
|
|||
}
|
||||
});
|
||||
}
|
||||
println!("{}", y);
|
||||
y.column(0).data.into_slice().to_vec()
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue