From 1f0b6e06e0ef448ccfbf3bcddf31081b5d58a112 Mon Sep 17 00:00:00 2001 From: sparshg <43041139+sparshg@users.noreply.github.com> Date: Tue, 11 Oct 2022 00:06:14 +0530 Subject: [PATCH] genetic algorithm --- src/asteroids.rs | 2 +- src/main.rs | 21 +++++++----- src/nn.rs | 6 ++-- src/player.rs | 58 ++++++++++++++++++++++++++------- src/population.rs | 82 ++++++++++++++++++++++++++++++++++++++++------- src/world.rs | 27 ++++++++++------ 6 files changed, 153 insertions(+), 43 deletions(-) diff --git a/src/asteroids.rs b/src/asteroids.rs index 8b7196e..97f9e76 100644 --- a/src/asteroids.rs +++ b/src/asteroids.rs @@ -10,7 +10,7 @@ pub struct Asteroid { pub vel: Vec2, pub size: AsteroidSize, sides: u8, - radius: f32, + pub radius: f32, rot: f32, omega: f32, pub alive: bool, diff --git a/src/main.rs b/src/main.rs index b00fd58..9074fb7 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,7 +5,7 @@ mod population; mod world; use macroquad::prelude::*; -use nn::NN; +use population::Population; use world::World; #[macroquad::main("Camera")] @@ -16,15 +16,20 @@ async fn main() { ..Default::default() }; set_camera(&cam); - let mut world = World::new(); - // let mut nn = NN::new(vec![1, 2, 1]); - + let mut pop = Population::new(5); loop { clear_background(BLACK); - if !world.over { - world.update(); - } - world.draw(); + pop.update(); + pop.draw(); next_frame().await } + // let mut world = World::new(); + // loop { + // clear_background(BLACK); + // if !world.over { + // world.update(); + // } + // world.draw(); + // next_frame().await + // } } diff --git a/src/nn.rs b/src/nn.rs index 9159e28..be59630 100644 --- a/src/nn.rs +++ b/src/nn.rs @@ -9,8 +9,10 @@ enum ActivationFunc { Tanh, ReLU, } + +#[derive(Clone)] pub struct NN { - config: Vec, + pub config: Vec, weights: Vec>, activ_func: ActivationFunc, mut_rate: f32, @@ -58,7 +60,7 @@ impl NN { } } - pub fn mutation(&mut self) { + pub fn mutate(&mut self) { for weight in &mut self.weights { for ele in weight { if r::random() { diff --git a/src/player.rs b/src/player.rs index 1f3c05f..9cf92ec 100644 --- a/src/player.rs +++ b/src/player.rs @@ -1,4 +1,4 @@ -use std::{f32::consts::PI, path::Iter}; +use std::{f32::consts::PI, f64::consts::TAU}; use macroquad::{prelude::*, rand::gen_range}; @@ -13,7 +13,11 @@ pub struct Player { bullets: Vec, last_shot: f32, shot_interval: f32, - brain: Option, + pub brain: Option, + search_radius: f32, + proximity_asteroids: Vec, + max_asteroids: usize, + debug: bool, alive: bool, } @@ -21,21 +25,33 @@ impl Player { pub fn new() -> Self { Self { dir: vec2(0., -1.), - rot: -PI / 2., + rot: 1.5 * PI, + + // Change scaling when passing inputs if this is changed drag: 0.001, shot_interval: 0.3, + search_radius: 300., alive: true, + debug: false, ..Default::default() } } - pub fn simulate(brain: NN) -> Self { + pub fn simulate(brain: NN, max_asteroids: usize) -> Self { + assert_eq!( + brain.config[0] - 1, + max_asteroids + 5, + "NN input size must match max_asteroids" + ); let mut p = Player::new(); p.brain = Some(brain); + p.max_asteroids = max_asteroids; p } pub fn check_player_collision(&mut self, asteroid: &mut Asteroid) -> bool { + self.proximity_asteroids + .extend([asteroid.pos.x, asteroid.pos.y, asteroid.radius]); if asteroid.check_collision(self.pos, 8.) { self.alive = false; return true; @@ -56,28 +72,37 @@ impl Player { pub fn update(&mut self) { let mut mag = 0.; - let mut keys = vec![false, false, false]; + let mut keys = vec![false, false, false, false]; + + self.proximity_asteroids.resize(self.max_asteroids, 0.); + let mut inputs = vec![ + self.pos.x / screen_width() + 0.5, + self.pos.y / screen_height() + 0.5, + self.vel.x / 11., + self.vel.y / 11., + self.rot / TAU as f32, + ]; + inputs.append(self.proximity_asteroids.as_mut()); if let Some(brain) = &self.brain { keys = brain - .feed_forward(vec![ - self.pos.x, self.pos.y, self.vel.x, self.vel.y, self.rot, - ]) + .feed_forward(inputs) .iter() .map(|&x| if x > 0. { true } else { false }) .collect(); } if is_key_down(KeyCode::Right) || keys[0] { - self.rot += 0.1; + self.rot = (self.rot + 0.1 + TAU as f32) % TAU as f32; self.dir = vec2(self.rot.cos(), self.rot.sin()); } if is_key_down(KeyCode::Left) || keys[1] { - self.rot -= 0.1; + self.rot = (self.rot - 0.1 + TAU as f32) % TAU as f32; self.dir = vec2(self.rot.cos(), self.rot.sin()); } if is_key_down(KeyCode::Up) || keys[2] { + // Change scaling when passing inputs if this is changed mag = 0.14; } - if is_key_down(KeyCode::Space) { + if is_key_down(KeyCode::Space) || keys[3] { if self.shot_interval + self.last_shot < get_time() as f32 { self.last_shot = get_time() as f32; self.bullets.push(Bullet { @@ -88,6 +113,10 @@ impl Player { } } + if is_key_pressed(KeyCode::D) { + self.debug = !self.debug; + } + self.vel += mag * self.dir - self.drag * self.vel.length() * self.vel; self.pos += self.vel; if self.pos.x.abs() > screen_width() / 2. + 10. { @@ -121,6 +150,13 @@ impl Player { draw_triangle_lines(p6, p7, p8, 2., WHITE); } + if self.debug { + for a in self.proximity_asteroids.chunks(3) { + draw_circle_lines(a[0], a[1], a[2], 1., GRAY); + draw_line(self.pos.x, self.pos.y, a[0], a[1], 1., GRAY) + } + } + for bullet in &self.bullets { bullet.draw(); } diff --git a/src/population.rs b/src/population.rs index a04425c..3187a2b 100644 --- a/src/population.rs +++ b/src/population.rs @@ -1,18 +1,78 @@ -use crate::world::World; +use macroquad::{prelude::*, rand::gen_range}; + +use crate::{nn::NN, world::World}; #[derive(Default)] -struct Population { - size: i32, +pub struct Population { + size: usize, gen: i32, - worlds: Vec, + pub worlds: Vec, } impl Population { - // pub fn new(size: i32) -> Self { - // Self { - // size: size, - // worlds: vec![World::new(); size], - // ..Default::default(), - // } - // } + pub fn new(size: usize) -> Self { + Self { + size, + worlds: (0..size) + .map(|_| World::simulate(NN::new(vec![33, 10, 4]))) + .collect(), + ..Default::default() + } + } + + pub fn update(&mut self) { + let mut alive = false; + for world in &mut self.worlds { + if !world.over { + alive = true; + world.update(); + } + } + if !alive { + self.gen += 1; + self.next_gen(); + } + } + + pub fn draw(&self) { + for world in &self.worlds { + if !world.over { + world.draw(); + draw_text( + &format!("Gen: {}", self.gen), + -100. + screen_width() * 0.5, + 30. - screen_height() * 0.5, + 32., + WHITE, + ); + } + } + } + + pub fn next_gen(&mut self) { + let total = self.worlds.iter().fold(0, |acc, x| acc + x.score); + self.worlds.sort_by(|a, b| b.score.cmp(&a.score)); + let mut new_worlds = (0..self.size / 10) + .map(|i| World::simulate(self.worlds[i].see_brain().to_owned())) + .collect::>(); + + while new_worlds.len() < self.size { + let rands = (gen_range(0, total + 1), gen_range(0, total + 1)); + let mut sum = 0; + let (mut a, mut b) = (None, None); + for world in &self.worlds { + sum += world.score; + if sum >= rands.0 { + a = Some(world.see_brain()); + } + if sum >= rands.1 { + b = Some(world.see_brain()); + } + } + let mut new_brain = NN::crossover(a.unwrap(), b.unwrap()); + new_brain.mutate(); + new_worlds.push(World::simulate(new_brain)); + } + self.worlds = new_worlds; + } } diff --git a/src/world.rs b/src/world.rs index 135aa29..7001192 100644 --- a/src/world.rs +++ b/src/world.rs @@ -9,25 +9,32 @@ use macroquad::{prelude::*, rand::gen_range}; pub struct World { player: Player, asteroids: Vec, - pub score: i32, + pub score: u32, pub over: bool, + max_asteroids: usize, } impl World { pub fn new() -> Self { Self { player: Player::new(), + max_asteroids: 28, ..Default::default() } } pub fn simulate(brain: NN) -> Self { Self { - player: Player::simulate(brain), + player: Player::simulate(brain, 28), + max_asteroids: 28, ..Default::default() } } + pub fn see_brain(&self) -> &NN { + self.player.brain.as_ref().unwrap() + } + pub fn update(&mut self) { self.player.update(); let mut to_add: Vec = Vec::new(); @@ -77,7 +84,7 @@ impl World { AsteroidSize::Medium => 2, AsteroidSize::Small => 1, } - }) < 20 + }) < self.max_asteroids { self.asteroids.push(Asteroid::new(AsteroidSize::Large)); } @@ -88,12 +95,12 @@ impl World { for asteroid in &self.asteroids { asteroid.draw(); } - draw_text( - &format!("Score {}", self.score), - 20. - screen_width() * 0.5, - 30. - screen_height() * 0.5, - 32., - WHITE, - ); + // draw_text( + // &format!("Score {}", self.score), + // 20. - screen_width() * 0.5, + // 30. - screen_height() * 0.5, + // 32., + // WHITE, + // ); } }