feed_forward: Vec -> Vec
This commit is contained in:
parent
80fb255c36
commit
44e2cb36f4
|
@ -17,8 +17,8 @@ async fn main() {
|
||||||
};
|
};
|
||||||
set_camera(&cam);
|
set_camera(&cam);
|
||||||
let mut world = World::new();
|
let mut world = World::new();
|
||||||
let nn = NN::new(vec![2, 3, 2]);
|
let nn = NN::new(vec![2, 3, 3]);
|
||||||
nn.feed_forward(vec![2., 3.]);
|
println!("{:?}", nn.feed_forward(vec![2., 3.]));
|
||||||
loop {
|
loop {
|
||||||
// clear_background(BLACK);
|
// clear_background(BLACK);
|
||||||
// if !world.over {
|
// if !world.over {
|
||||||
|
|
|
@ -34,14 +34,14 @@ impl NN {
|
||||||
* (2. / last as f32).sqrt()
|
* (2. / last as f32).sqrt()
|
||||||
})
|
})
|
||||||
.collect(),
|
.collect(),
|
||||||
|
|
||||||
activ_func: ActivationFunc::ReLU,
|
activ_func: ActivationFunc::ReLU,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn feed_forward(&self, inputs: Vec<f32>) {
|
pub fn feed_forward(&self, inputs: Vec<f32>) -> Vec<f32> {
|
||||||
let mut y = DMatrix::from_vec(inputs.len(), 1, inputs);
|
let mut y = DMatrix::from_vec(inputs.len(), 1, inputs);
|
||||||
for i in 0..self.config.len() - 1 {
|
for i in 0..self.config.len() - 1 {
|
||||||
println!("{} {}", y, self.weights[i]);
|
|
||||||
y = (&self.weights[i] * y.insert_row(self.config[i] - 1, 1.)).map(|x| {
|
y = (&self.weights[i] * y.insert_row(self.config[i] - 1, 1.)).map(|x| {
|
||||||
match self.activ_func {
|
match self.activ_func {
|
||||||
ActivationFunc::ReLU => x.max(0.),
|
ActivationFunc::ReLU => x.max(0.),
|
||||||
|
@ -50,6 +50,6 @@ impl NN {
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
println!("{}", y);
|
y.column(0).data.into_slice().to_vec()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue