From 91bc4e27a40e8eadf968fdb56c3fc68042cf38eb Mon Sep 17 00:00:00 2001 From: sparshg <43041139+sparshg@users.noreply.github.com> Date: Mon, 10 Oct 2022 01:16:27 +0530 Subject: [PATCH] crossover, mutation --- Cargo.lock | 1 + Cargo.toml | 2 +- src/main.rs | 8 +++++--- src/nn.rs | 31 ++++++++++++++++++++++++++++++- src/utils.rs | 30 ------------------------------ 5 files changed, 37 insertions(+), 35 deletions(-) delete mode 100644 src/utils.rs diff --git a/Cargo.lock b/Cargo.lock index 556a31d..8b3499f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -291,6 +291,7 @@ dependencies = [ "num-rational", "num-traits", "rand", + "rand_distr", "simba", "typenum", ] diff --git a/Cargo.toml b/Cargo.toml index 6632eba..a09a5b9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,6 @@ edition = "2021" [dependencies] macroquad = "0.3.24" -nalgebra = { version = "0.31.1", features = ["rand-no-std"] } +nalgebra = { version = "0.31.1", features = ["rand-no-std", "rand"] } rand = "0.8.5" rand_distr = "0.4.3" diff --git a/src/main.rs b/src/main.rs index 5506004..de7f6de 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,6 @@ mod asteroids; mod nn; mod player; -mod utils; mod world; use macroquad::prelude::*; @@ -17,8 +16,11 @@ async fn main() { }; set_camera(&cam); let mut world = World::new(); - let nn = NN::new(vec![2, 3, 3]); - println!("{:?}", nn.feed_forward(vec![2., 3.])); + let mut nn = NN::new(vec![1, 2, 1]); + println!("{} {}", nn.weights[0], nn.weights[1]); + nn.mutation(); + println!("{} {}", nn.weights[0], nn.weights[1]); + loop { // clear_background(BLACK); // if !world.over { diff --git a/src/nn.rs b/src/nn.rs index 55f094b..f602dff 100644 --- a/src/nn.rs +++ b/src/nn.rs @@ -1,7 +1,9 @@ use nalgebra::*; +use r::Rng; use rand_distr::StandardNormal; extern crate rand as r; +#[derive(PartialEq, Debug, Clone, Copy)] enum ActivationFunc { Sigmoid, Tanh, @@ -9,8 +11,9 @@ enum ActivationFunc { } pub struct NN { config: Vec, - weights: Vec>, + pub weights: Vec>, activ_func: ActivationFunc, + mut_rate: f32, } impl NN { @@ -36,6 +39,32 @@ impl NN { .collect(), activ_func: ActivationFunc::ReLU, + mut_rate: 0.05, + } + } + + pub fn crossover(a: &NN, b: &NN) -> Self { + assert_eq!(a.config, b.config, "NN configs not same."); + Self { + config: a.config.to_owned(), + activ_func: a.activ_func, + mut_rate: a.mut_rate, + weights: a + .weights + .iter() + .zip(b.weights.iter()) + .map(|(m1, m2)| m1.zip_map(m2, |ele1, ele2| if r::random() { ele1 } else { ele2 })) + .collect(), + } + } + + pub fn mutation(&mut self) { + for weight in &mut self.weights { + for ele in weight { + if r::random() { + *ele = r::thread_rng().sample::(StandardNormal) * 0.05; + } + } } } diff --git a/src/utils.rs b/src/utils.rs deleted file mode 100644 index 970206b..0000000 --- a/src/utils.rs +++ /dev/null @@ -1,30 +0,0 @@ -use macroquad::prelude::*; - -pub fn rotate_vec(vec: Vec2, angle: f32) -> Vec2 { - vec2(angle.cos(), angle.sin()).rotate(vec) -} - -pub fn draw_polygon(x: f32, y: f32, points: Vec, color: Color) { - let points_length = points.len(); - let mut vertices = Vec::with_capacity(points_length as usize + 2); - let mut indices = Vec::::with_capacity(points_length as usize * 3); - - for (i, point) in points.iter().enumerate() { - let vertex = macroquad::models::Vertex { - position: Vec3::new(x + point.x, y + point.y, 0.0), - uv: Vec2::default(), - color, - }; - - vertices.push(vertex); - indices.extend_from_slice(&[0, i as u16 + 1, i as u16 + 2]); - } - - let mesh = Mesh { - vertices, - indices, - texture: None, - }; - - draw_mesh(&mesh); -}