From 81963835157c295b20587f8f5a60d44e5203467b Mon Sep 17 00:00:00 2001 From: sparshg <43041139+sparshg@users.noreply.github.com> Date: Thu, 5 Jan 2023 01:27:59 +0530 Subject: [PATCH] finally works T_T im ded --- src/nn.rs | 20 +++++-- src/player.rs | 150 ++++++++++++++++++++++++++++++---------------- src/population.rs | 24 +++++--- src/world.rs | 37 ++++++------ 4 files changed, 145 insertions(+), 86 deletions(-) diff --git a/src/nn.rs b/src/nn.rs index 82610f9..e6eafcf 100644 --- a/src/nn.rs +++ b/src/nn.rs @@ -4,14 +4,16 @@ use r::Rng; use rand_distr::StandardNormal; extern crate rand as r; -#[derive(PartialEq, Debug, Clone, Copy)] +#[derive(PartialEq, Debug, Clone, Copy, Default)] + enum ActivationFunc { Sigmoid, Tanh, + #[default] ReLU, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Default)] pub struct NN { pub config: Vec, pub weights: Vec>, @@ -36,7 +38,8 @@ impl NN { .iter() .zip(config.iter().skip(1)) .map(|(&curr, &last)| { - // let a = DMatrix::::new_random(last, curr + 1); + // DMatrix::from_fn(last, curr + 1, |_, _| gen_range(-1., 1.)) + // DMatrix::::new_random(last, curr + 1) // println!("{}", a); // a DMatrix::::from_distribution(last, curr + 1, &StandardNormal, &mut rng) @@ -44,8 +47,8 @@ impl NN { }) .collect(), - activ_func: ActivationFunc::ReLU, - mut_rate: 0.02, + mut_rate: 0.04, + ..Default::default() } } @@ -61,6 +64,7 @@ impl NN { .zip(b.weights.iter()) .map(|(m1, m2)| m1.zip_map(m2, |ele1, ele2| if r::random() { ele1 } else { ele2 })) .collect(), + ..Default::default() } } @@ -69,7 +73,8 @@ impl NN { for ele in weight { if gen_range(0., 1.) < self.mut_rate { // *ele += gen_range(-1., 1.); - *ele = r::thread_rng().sample::(StandardNormal); + *ele = gen_range(-1., 1.); + // *ele = r::thread_rng().sample::(StandardNormal); // *ele = r::thread_rng().sample::(StandardNormal); } } @@ -77,6 +82,7 @@ impl NN { } pub fn feed_forward(&self, inputs: Vec) -> Vec { + // println!("inputs: {:?}", inputs); let mut y = DMatrix::from_vec(inputs.len(), 1, inputs); for i in 0..self.config.len() - 1 { y = (&self.weights[i] * y.insert_row(self.config[i] - 1, 1.)).map(|x| { @@ -86,6 +92,8 @@ impl NN { ActivationFunc::Tanh => x.tanh(), } }); + // println!("w{}: {}", i, self.weights[i]); + // println!("y: {}", y); } y.column(0).data.into_slice().to_vec() } diff --git a/src/player.rs b/src/player.rs index 66b06e9..9e4751d 100644 --- a/src/player.rs +++ b/src/player.rs @@ -1,7 +1,7 @@ use std::{f32::consts::PI, f64::consts::TAU}; use macroquad::{prelude::*, rand::gen_range}; -use nalgebra::{max, partial_max}; +use nalgebra::{max, partial_max, partial_min}; use crate::{asteroids::Asteroid, nn::NN}; #[derive(Default)] @@ -9,7 +9,7 @@ pub struct Player { pub pos: Vec2, pub vel: Vec2, acc: f32, - dir: Vec2, + pub dir: Vec2, rot: f32, drag: f32, bullets: Vec, @@ -36,7 +36,7 @@ impl Player { alive: true, debug: false, shots: 4, - raycasts: vec![0.; 8], + raycasts: vec![f32::MAX; 3], ..Default::default() } } @@ -44,39 +44,52 @@ impl Player { pub fn simulate(brain: Option) -> Self { let mut p = Player::new(); if let Some(brain) = brain { - assert_eq!( - brain.config[0] - 1, - 8 + 0, - "NN input size must match max_asteroids" - ); + // assert_eq!( + // brain.config[0] - 1, + // 8 + 5, + // "NN input size must match max_asteroids" + // ); p.brain = Some(brain); } else { - p.brain = Some(NN::new(vec![8 + 0, 16, 4])); + p.brain = Some(NN::new(vec![3, 8, 8, 4])); } p } pub fn check_player_collision(&mut self, asteroid: &mut Asteroid) -> bool { - // self.asteroids_data.extend([ - // asteroid.pos.x / screen_width() + 0.5, - // asteroid.pos.y / screen_height() + 0.5, - // asteroid.radius / 50., - // ]); - let v = asteroid.pos - self.pos; - for i in 0..4 { - let dir = Vec2::from_angle(PI / 4. * i as f32).rotate(self.dir); - let cross = v.perp_dot(dir); - let dot = v.dot(dir); - if cross.abs() <= asteroid.radius { - self.raycasts[if dot >= 0. { i } else { i + 4 }] = *partial_max( - &self.raycasts[if dot >= 0. { i } else { i + 4 }], - &(1. / (dot.abs() - - (asteroid.radius * asteroid.radius - cross * cross).sqrt())), - ) - .unwrap(); - } + // self.raycasts.extend([ + if (asteroid.pos).distance_squared(self.pos) + < vec2( + self.raycasts[0] * screen_width(), + self.raycasts[1] * screen_height(), + ) + .distance_squared(self.pos) + { + self.raycasts[0] = asteroid.pos.x / screen_width(); + self.raycasts[1] = asteroid.pos.y / screen_height(); + self.raycasts[2] = asteroid.radius / 50.; } - if asteroid.check_collision(self.pos, 8.) { + // ]); + // if self.raycasts[0] > (asteroid.pos - self.pos).length_squared() { + // self.raycasts[0] = (asteroid.pos - self.pos).length_squared(); + // self.raycasts[1] = Vec2::angle_between(asteroid.pos - self.pos, self.dir).sin(); + // self.raycasts[2] = Vec2::angle_between(asteroid.pos - self.pos, self.dir).cos(); + // } + // let v = asteroid.pos - self.pos; + // for i in 0..4 { + // let dir = Vec2::from_angle(PI / 4. * i as f32).rotate(self.dir); + // let cross = v.perp_dot(dir); + // let dot = v.dot(dir); + // if cross.abs() <= asteroid.radius { + // self.raycasts[if dot >= 0. { i } else { i + 4 }] = *partial_max( + // &self.raycasts[if dot >= 0. { i } else { i + 4 }], + // &(1. / (dot.abs() + // - (asteroid.radius * asteroid.radius - cross * cross).sqrt())), + // ) + // .unwrap(); + // } + // } + if asteroid.check_collision(self.pos, 8.) || self.lifespan > 2000 { self.alive = false; return true; } @@ -99,22 +112,39 @@ impl Player { self.last_shot += 1; self.acc = 0.; let mut keys = vec![false, false, false, false]; - // self.asteroids_data.resize(self.max_asteroids * 3, 0.); - // let mut inputs = vec![ - // self.pos.x / screen_width() + 0.5, - // self.pos.y / screen_height() + 0.5, - // self.vel.x / 11., - // self.vel.y / 11., - // self.rot / TAU as f32, - // ]; - // inputs.append(self.raycasts.as_mut()); - let inputs = self.raycasts.clone(); + let mut inputs = vec![ + (vec2( + self.raycasts[0] * screen_width(), + self.raycasts[1] * screen_height(), + ) - self.pos) + .length() + * 0.707 + / screen_width(), + // self.raycasts[0] - self.pos.x / screen_width(), + // self.raycasts[1] - self.pos.y / screen_height(), + self.dir.angle_between( + vec2( + self.raycasts[0] * screen_width(), + self.raycasts[1] * screen_height(), + ) - self.pos, + ), + // self.vel.x / 11., + // self.vel.y / 11., + self.rot, // self.rot.sin(), + // self.rot.cos(), + ]; + + // self.raycasts.resize(3, 0.); + // inputs.append(self.raycasts.clone().as_mut()); + // println!("inputs: {:?}", inputs); + + // let inputs = self.raycasts.clone(); // inputs.append(self.asteroids_data.as_mut()); if let Some(brain) = &self.brain { - // println!("{:?}", inputs); + // println!("{:?}", brain.feed_forward(inputs.clone())); + keys = brain.feed_forward(inputs).iter().map(|&x| x > 0.).collect(); } - self.raycasts = vec![0.; 8]; if is_key_down(KeyCode::Right) && self.debug || keys[0] { self.rot = (self.rot + 0.1 + TAU as f32) % TAU as f32; self.dir = vec2(self.rot.cos(), self.rot.sin()); @@ -132,8 +162,8 @@ impl Player { self.last_shot = 0; self.shots += 1; self.bullets.push(Bullet { - pos: self.pos + self.dir.rotate(vec2(20., 0.)), - vel: self.dir.rotate(vec2(8.5, 0.)) + self.vel, + pos: self.pos + self.dir * 20., + vel: self.dir * 8.5 + self.vel, alive: true, }); } @@ -158,6 +188,7 @@ impl Player { self.bullets.retain(|b| { b.alive && b.pos.x.abs() * 2. < screen_width() && b.pos.y.abs() * 2. < screen_height() }); + self.raycasts = vec![100.; 3]; } pub fn draw(&self) { @@ -181,21 +212,34 @@ impl Player { } if self.debug { - // for a in self.asteroids_data.chunks(3) { - // draw_circle_lines(a[0], a[1], a[2], 1., GRAY); - // draw_line(self.pos.x, self.pos.y, a[0], a[1], 1., GRAY) - // } - for (i, r) in self.raycasts.iter().enumerate() { - let dir = Vec2::from_angle(PI / 4. * i as f32).rotate(self.dir); - draw_line( - self.pos.x, - self.pos.y, - self.pos.x + dir.x / r, - self.pos.y + dir.y / r, + for a in self.raycasts.chunks(3) { + draw_circle_lines( + a[0] * screen_width(), + a[1] * screen_height(), + a[2] * 50., 1., GRAY, ); + draw_line( + self.pos.x, + self.pos.y, + a[0] * screen_width(), + a[1] * screen_height(), + 1., + GRAY, + ) } + // for (i, r) in self.raycasts.iter().enumerate() { + // let dir = Vec2::from_angle(PI / 4. * i as f32).rotate(self.dir); + // draw_line( + // self.pos.x, + // self.pos.y, + // self.pos.x + dir.x / r, + // self.pos.y + dir.y / r, + // 1., + // GRAY, + // ); + // } } for bullet in &self.bullets { diff --git a/src/population.rs b/src/population.rs index 8f6a214..1cd4c31 100644 --- a/src/population.rs +++ b/src/population.rs @@ -6,6 +6,7 @@ use crate::{nn::NN, world::World}; pub struct Population { size: usize, gen: i32, + best: bool, pub worlds: Vec, } @@ -30,21 +31,28 @@ impl Population { self.gen += 1; self.next_gen(); } + if is_key_pressed(KeyCode::Z) { + self.best = !self.best; + } } pub fn draw(&self) { for world in self.worlds.iter().rev() { - if !world.over { + if self.best { + if world.player.color.is_some() { + world.draw(); + } + } else if !world.over { world.draw(); - draw_text( - &format!("Gen: {}", self.gen), - -150. + screen_width() * 0.5, - 30. - screen_height() * 0.5, - 32., - WHITE, - ); } } + draw_text( + &format!("Gen: {}", self.gen), + -150. + screen_width() * 0.5, + 30. - screen_height() * 0.5, + 32., + WHITE, + ); } pub fn next_gen(&mut self) { diff --git a/src/world.rs b/src/world.rs index afda072..9b5860d 100644 --- a/src/world.rs +++ b/src/world.rs @@ -7,9 +7,9 @@ use macroquad::{prelude::*, rand::gen_range}; #[derive(Default)] pub struct World { - player: Player, + pub player: Player, asteroids: Vec, - pub score: u32, + pub score: f32, pub over: bool, max_asteroids: usize, pub fitness: f32, @@ -20,7 +20,7 @@ impl World { Self { player: Player::new(), max_asteroids: 28, - score: 1, + score: 0., ..Default::default() } } @@ -29,7 +29,7 @@ impl World { Self { player: Player::simulate(brain), max_asteroids: 28, - score: 1, + score: 0., asteroids: vec![ Asteroid::new_to(vec2(0., 0.), 1.5, AsteroidSize::Large), Asteroid::new(AsteroidSize::Large), @@ -79,17 +79,13 @@ impl World { asteroid.update(); if self.player.check_player_collision(asteroid) { self.over = true; - self.fitness = (self.score as f32 + 1.) - * (self.score as f32 / self.player.shots as f32) - * (self.score as f32 / self.player.shots as f32) - * self.player.lifespan as f32 - * 0.01; - // self.fitness = self.player.lifespan as f32 * self.player.lifespan as f32 * 0.001; + self.fitness = + (self.score / self.player.shots as f32).powi(2) * self.player.lifespan as f32; // println!("{} {} {}", self.score, self.player.lifespan, self.fitness); } if self.player.check_bullet_collisions(asteroid) { - self.score += 1; + self.score += 1.; match asteroid.size { AsteroidSize::Large => { let rand = vec2(gen_range(-0.8, 0.8), gen_range(-0.8, 0.8)); @@ -130,7 +126,8 @@ impl World { // AsteroidSize::Small => 1, // } // }) < self.max_asteroids - if self.player.lifespan % 400 == 0 { + // { + if self.player.lifespan % 200 == 0 { self.asteroids .push(Asteroid::new_to(self.player.pos, 1.5, AsteroidSize::Large)); } @@ -141,12 +138,14 @@ impl World { for asteroid in &self.asteroids { asteroid.draw(); } - // draw_text( - // &format!("Score {}", self.score), - // 20. - screen_width() * 0.5, - // 30. - screen_height() * 0.5, - // 32., - // WHITE, - // ); + draw_text( + &format!("{}", (self.score / self.player.shots as f32).powi(2) as f32), + // 20. - screen_width() * 0.5, + // 30. - screen_height() * 0.5, + self.player.pos.x - 20., + self.player.pos.y - 20., + 12., + WHITE, + ); } }