From 1c29b6d4191ced87a0e151f02678963203c94b63 Mon Sep 17 00:00:00 2001 From: sparshg <43041139+sparshg@users.noreply.github.com> Date: Sun, 23 Oct 2022 01:10:19 +0530 Subject: [PATCH] bug fixes --- src/asteroids.rs | 2 +- src/main.rs | 4 +-- src/nn.rs | 9 +++++-- src/player.rs | 62 ++++++++++++++++++++++++++--------------------- src/population.rs | 40 ++++++++++++++++++------------ src/world.rs | 61 +++++++++++++++++++++++++++++----------------- 6 files changed, 109 insertions(+), 69 deletions(-) diff --git a/src/asteroids.rs b/src/asteroids.rs index 1cb7e7d..ea667a3 100644 --- a/src/asteroids.rs +++ b/src/asteroids.rs @@ -89,7 +89,7 @@ impl Asteroid { AsteroidSize::Medium => 1.2, AsteroidSize::Small => 0.8, }, - Color::new(1., 1., 1., 0.4), + Color::new(1., 1., 1., 1.), ); } } diff --git a/src/main.rs b/src/main.rs index 07fd409..17f70cf 100644 --- a/src/main.rs +++ b/src/main.rs @@ -16,9 +16,9 @@ async fn main() { ..Default::default() }; set_camera(&cam); - let mut pop = Population::new(10); + let mut pop = Population::new(100); let mut speedup = false; - // for _ in 0..100000 * 5 { + // for _ in 0..10000 * 10 { // pop.update(); // } loop { diff --git a/src/nn.rs b/src/nn.rs index bb2b7c9..82610f9 100644 --- a/src/nn.rs +++ b/src/nn.rs @@ -36,13 +36,16 @@ impl NN { .iter() .zip(config.iter().skip(1)) .map(|(&curr, &last)| { + // let a = DMatrix::::new_random(last, curr + 1); + // println!("{}", a); + // a DMatrix::::from_distribution(last, curr + 1, &StandardNormal, &mut rng) * (2. / last as f32).sqrt() }) .collect(), activ_func: ActivationFunc::ReLU, - mut_rate: 0.05, + mut_rate: 0.02, } } @@ -64,8 +67,10 @@ impl NN { pub fn mutate(&mut self) { for weight in &mut self.weights { for ele in weight { - if gen_range(0., 1.) < 0.05 { + if gen_range(0., 1.) < self.mut_rate { + // *ele += gen_range(-1., 1.); *ele = r::thread_rng().sample::(StandardNormal); + // *ele = r::thread_rng().sample::(StandardNormal); } } } diff --git a/src/player.rs b/src/player.rs index 33a6aac..e8bd9cb 100644 --- a/src/player.rs +++ b/src/player.rs @@ -5,19 +5,21 @@ use macroquad::{prelude::*, rand::gen_range}; use crate::{asteroids::Asteroid, nn::NN}; #[derive(Default)] pub struct Player { - pos: Vec2, - vel: Vec2, + pub pos: Vec2, + pub vel: Vec2, + acc: f32, dir: Vec2, rot: f32, drag: f32, bullets: Vec, - last_shot: f32, - shot_interval: f32, + last_shot: u8, + shot_interval: u8, pub brain: Option, asteroids_data: Vec, max_asteroids: usize, debug: bool, alive: bool, + pub color: Option, pub lifespan: u32, pub shots: u32, } @@ -30,21 +32,25 @@ impl Player { // Change scaling when passing inputs if this is changed drag: 0.001, - shot_interval: 0.3, + shot_interval: 18, alive: true, debug: false, ..Default::default() } } - pub fn simulate(brain: NN, max_asteroids: usize) -> Self { - assert_eq!( - brain.config[0] - 1, - max_asteroids * 3 + 5, - "NN input size must match max_asteroids" - ); + pub fn simulate(brain: Option, max_asteroids: usize) -> Self { let mut p = Player::new(); - p.brain = Some(brain); + if let Some(brain) = brain { + assert_eq!( + brain.config[0] - 1, + max_asteroids * 3 + 5, + "NN input size must match max_asteroids" + ); + p.brain = Some(brain); + } else { + p.brain = Some(NN::new(vec![max_asteroids * 3 + 5, 16, 4])); + } p.max_asteroids = max_asteroids; p } @@ -75,7 +81,8 @@ impl Player { pub fn update(&mut self) { self.lifespan += 1; - let mut mag = 0.; + self.last_shot += 1; + self.acc = 0.; let mut keys = vec![false, false, false, false]; self.asteroids_data.resize(self.max_asteroids * 3, 0.); let mut inputs = vec![ @@ -87,11 +94,8 @@ impl Player { ]; inputs.append(self.asteroids_data.as_mut()); if let Some(brain) = &self.brain { - keys = brain - .feed_forward(inputs) - .iter() - .map(|&x| if x > 0. { true } else { false }) - .collect(); + // println!("{:?}", inputs); + keys = brain.feed_forward(inputs).iter().map(|&x| x > 0.).collect(); } if is_key_down(KeyCode::Right) || keys[0] { self.rot = (self.rot + 0.1 + TAU as f32) % TAU as f32; @@ -103,11 +107,11 @@ impl Player { } if is_key_down(KeyCode::Up) || keys[2] { // Change scaling when passing inputs if this is changed - mag = 0.14; + self.acc = 0.14; } if is_key_down(KeyCode::Space) || keys[3] { - if self.shot_interval + self.last_shot < get_time() as f32 { - self.last_shot = get_time() as f32; + if self.last_shot > self.shot_interval { + self.last_shot = 0; self.shots += 1; self.bullets.push(Bullet { pos: self.pos + self.dir.rotate(vec2(20., 0.)), @@ -121,7 +125,7 @@ impl Player { self.debug = !self.debug; } - self.vel += mag * self.dir - self.drag * self.vel.length() * self.vel; + self.vel += self.acc * self.dir - self.drag * self.vel.length() * self.vel; self.pos += self.vel; if self.pos.x.abs() > screen_width() * 0.5 + 10. { self.pos.x *= -1.; @@ -139,6 +143,10 @@ impl Player { } pub fn draw(&self) { + let color = match self.color { + Some(c) => c, + None => Color::new(1., 1., 1., 0.1), + }; let p1 = self.pos + self.dir.rotate(vec2(20., 0.)); let p2 = self.pos + self.dir.rotate(vec2(-18., -12.667)); let p3 = self.pos + self.dir.rotate(vec2(-18., 12.667)); @@ -147,11 +155,11 @@ impl Player { let p6 = self.pos + self.dir.rotate(vec2(-25., 0.)); let p7 = self.pos + self.dir.rotate(vec2(-10., -6.)); let p8 = self.pos + self.dir.rotate(vec2(-10., 6.)); - draw_line(p1.x, p1.y, p2.x, p2.y, 2., WHITE); - draw_line(p1.x, p1.y, p3.x, p3.y, 2., WHITE); - draw_line(p4.x, p4.y, p5.x, p5.y, 2., WHITE); - if is_key_down(KeyCode::Up) && gen_range(0., 1.) < 0.4 { - draw_triangle_lines(p6, p7, p8, 2., WHITE); + draw_line(p1.x, p1.y, p2.x, p2.y, 2., color); + draw_line(p1.x, p1.y, p3.x, p3.y, 2., color); + draw_line(p4.x, p4.y, p5.x, p5.y, 2., color); + if self.acc > 0. && gen_range(0., 1.) < 0.4 { + draw_triangle_lines(p6, p7, p8, 2., color); } if self.debug { diff --git a/src/population.rs b/src/population.rs index 483f375..8f6a214 100644 --- a/src/population.rs +++ b/src/population.rs @@ -13,9 +13,7 @@ impl Population { pub fn new(size: usize) -> Self { Self { size, - worlds: (0..size) - .map(|_| World::simulate(NN::new(vec![89, 16, 4]))) - .collect(), + worlds: (0..size).map(|_| World::simulate(None)).collect(), ..Default::default() } } @@ -30,18 +28,17 @@ impl Population { } if !alive { self.gen += 1; - println!("{}", self.gen); self.next_gen(); } } pub fn draw(&self) { - for world in &self.worlds { + for world in self.worlds.iter().rev() { if !world.over { world.draw(); draw_text( &format!("Gen: {}", self.gen), - -100. + screen_width() * 0.5, + -150. + screen_width() * 0.5, 30. - screen_height() * 0.5, 32., WHITE, @@ -51,26 +48,33 @@ impl Population { } pub fn next_gen(&mut self) { - let total = self.worlds.iter().fold(0., |acc, x| acc + x.fitness()); + let total = self.worlds.iter().fold(0., |acc, x| acc + x.fitness); self.worlds - .sort_by(|a, b| b.fitness().partial_cmp(&a.fitness()).unwrap()); - let mut new_worlds = (0..self.size / 10) - .map(|i| World::simulate(self.worlds[i].see_brain().to_owned())) + .sort_by(|a, b| b.fitness.partial_cmp(&a.fitness).unwrap()); + for i in &self.worlds { + println!("Fitness: {}", i.fitness); + } + println!("Gen: {}, Fitness: {}", self.gen, self.worlds[0].fitness); + // let mut new_worlds = vec![World::simulate(Some(self.worlds[0].see_brain().to_owned()))]; + let mut new_worlds = (0..self.size / 20) + .map(|i| World::simulate(Some(self.worlds[i].see_brain().to_owned()))) .collect::>(); + // if is_key_down(KeyCode::K) { + new_worlds[0].set_best(); + // } // println!( // "Total fitness: {} {} {}", // total, // self.worlds[0].fitness(), // self.worlds[1].fitness() // ); - while new_worlds.len() < self.size { let rands = (gen_range(0., total), gen_range(0., total)); - // println!("rands: {} {}", rands.0, rands.1); + // println!("rands: {} {} {}", rands.0, rands.1, total); let mut sum = 0.; let (mut a, mut b) = (None, None); for world in &self.worlds { - sum += world.fitness(); + sum += world.fitness; if a.is_none() && sum >= rands.0 { a = Some(world.see_brain()); } @@ -80,12 +84,18 @@ impl Population { } // println!("{}", &a.unwrap().weights[0]); // println!("{}", &b.unwrap().weights[0]); + if a.is_none() { + a = Some(self.worlds.last().unwrap().see_brain()); + } + if b.is_none() { + b = Some(self.worlds.last().unwrap().see_brain()); + } let mut new_brain = NN::crossover(a.unwrap(), b.unwrap()); // println!("{}", &a.unwrap().weights[0]); // println!("{}", &b.unwrap().weights[0]); - // println!("{}", &new_brain.weights[0]); new_brain.mutate(); - new_worlds.push(World::simulate(new_brain)); + // println!("{}", &new_brain.weights[0]); + new_worlds.push(World::simulate(Some(new_brain))); } self.worlds = new_worlds; } diff --git a/src/world.rs b/src/world.rs index aa17faf..234f84e 100644 --- a/src/world.rs +++ b/src/world.rs @@ -12,6 +12,7 @@ pub struct World { pub score: u32, pub over: bool, max_asteroids: usize, + pub fitness: f32, } impl World { @@ -23,7 +24,7 @@ impl World { } } - pub fn simulate(brain: NN) -> Self { + pub fn simulate(brain: Option) -> Self { Self { player: Player::simulate(brain, 28), max_asteroids: 28, @@ -31,39 +32,55 @@ impl World { } } + pub fn set_best(&mut self) { + self.player.color = Some(RED); + } + pub fn see_brain(&self) -> &NN { self.player.brain.as_ref().unwrap() } - pub fn fitness(&self) -> f32 { - // println!( - // "{} {} {}", - // self.score as f32, - // self.player.lifespan as f32 * 0.001, - // if self.player.shots > 0 { - // self.score as f32 / self.player.shots as f32 * 5. - // } else { - // 0. - // } - // ); - (self.score + 1) as f32 - * 10. - * self.player.lifespan as f32 - * if self.player.shots > 0 { - (self.score as f32 / self.player.shots as f32) - * (self.score as f32 / self.player.shots as f32) - } else { - 1. - } - } + // fn calc_fitness(&mut self) { + // println!( + // "{} {} {}", + // self.score as f32, + // self.player.lifespan as f32 * 0.001, + // if self.player.shots > 0 { + // self.score as f32 / self.player.shots as f32 * 5. + // } else { + // 0. + // } + // ); + // } pub fn update(&mut self) { self.player.update(); + // if self.player.lifespan > 150 { + // self.fitness = 1. + // / ((self.player.pos * vec2(2. / screen_width(), 2. / screen_height())) + // .distance_squared(vec2(0., -1.)) + // + self.player.vel.length_squared() + // * self.player.vel.length_squared() + // * 0.00006830134554 + // + 1.); + // self.over = true; + // } let mut to_add: Vec = Vec::new(); for asteroid in &mut self.asteroids { asteroid.update(); if self.player.check_player_collision(asteroid) { self.over = true; + self.fitness = (self.score as f32 + 1.) + * if self.player.shots > 0 { + (self.score as f32 / self.player.shots as f32) + * (self.score as f32 / self.player.shots as f32) + } else { + 1. + } + * self.player.lifespan as f32; + // self.fitness = self.player.lifespan as f32 * self.player.lifespan as f32 * 0.001; + + // println!("{} {} {}", self.score, self.player.lifespan, self.fitness); } if self.player.check_bullet_collisions(asteroid) { self.score += 1;