genetic algorithm

This commit is contained in:
sparshg 2022-10-11 00:06:14 +05:30
parent e66d8ff9d4
commit 1f0b6e06e0
6 changed files with 153 additions and 43 deletions

View File

@ -10,7 +10,7 @@ pub struct Asteroid {
pub vel: Vec2, pub vel: Vec2,
pub size: AsteroidSize, pub size: AsteroidSize,
sides: u8, sides: u8,
radius: f32, pub radius: f32,
rot: f32, rot: f32,
omega: f32, omega: f32,
pub alive: bool, pub alive: bool,

View File

@ -5,7 +5,7 @@ mod population;
mod world; mod world;
use macroquad::prelude::*; use macroquad::prelude::*;
use nn::NN; use population::Population;
use world::World; use world::World;
#[macroquad::main("Camera")] #[macroquad::main("Camera")]
@ -16,15 +16,20 @@ async fn main() {
..Default::default() ..Default::default()
}; };
set_camera(&cam); set_camera(&cam);
let mut world = World::new(); let mut pop = Population::new(5);
// let mut nn = NN::new(vec![1, 2, 1]);
loop { loop {
clear_background(BLACK); clear_background(BLACK);
if !world.over { pop.update();
world.update(); pop.draw();
}
world.draw();
next_frame().await next_frame().await
} }
// let mut world = World::new();
// loop {
// clear_background(BLACK);
// if !world.over {
// world.update();
// }
// world.draw();
// next_frame().await
// }
} }

View File

@ -9,8 +9,10 @@ enum ActivationFunc {
Tanh, Tanh,
ReLU, ReLU,
} }
#[derive(Clone)]
pub struct NN { pub struct NN {
config: Vec<usize>, pub config: Vec<usize>,
weights: Vec<DMatrix<f32>>, weights: Vec<DMatrix<f32>>,
activ_func: ActivationFunc, activ_func: ActivationFunc,
mut_rate: f32, mut_rate: f32,
@ -58,7 +60,7 @@ impl NN {
} }
} }
pub fn mutation(&mut self) { pub fn mutate(&mut self) {
for weight in &mut self.weights { for weight in &mut self.weights {
for ele in weight { for ele in weight {
if r::random() { if r::random() {

View File

@ -1,4 +1,4 @@
use std::{f32::consts::PI, path::Iter}; use std::{f32::consts::PI, f64::consts::TAU};
use macroquad::{prelude::*, rand::gen_range}; use macroquad::{prelude::*, rand::gen_range};
@ -13,7 +13,11 @@ pub struct Player {
bullets: Vec<Bullet>, bullets: Vec<Bullet>,
last_shot: f32, last_shot: f32,
shot_interval: f32, shot_interval: f32,
brain: Option<NN>, pub brain: Option<NN>,
search_radius: f32,
proximity_asteroids: Vec<f32>,
max_asteroids: usize,
debug: bool,
alive: bool, alive: bool,
} }
@ -21,21 +25,33 @@ impl Player {
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
dir: vec2(0., -1.), dir: vec2(0., -1.),
rot: -PI / 2., rot: 1.5 * PI,
// Change scaling when passing inputs if this is changed
drag: 0.001, drag: 0.001,
shot_interval: 0.3, shot_interval: 0.3,
search_radius: 300.,
alive: true, alive: true,
debug: false,
..Default::default() ..Default::default()
} }
} }
pub fn simulate(brain: NN) -> Self { pub fn simulate(brain: NN, max_asteroids: usize) -> Self {
assert_eq!(
brain.config[0] - 1,
max_asteroids + 5,
"NN input size must match max_asteroids"
);
let mut p = Player::new(); let mut p = Player::new();
p.brain = Some(brain); p.brain = Some(brain);
p.max_asteroids = max_asteroids;
p p
} }
pub fn check_player_collision(&mut self, asteroid: &mut Asteroid) -> bool { pub fn check_player_collision(&mut self, asteroid: &mut Asteroid) -> bool {
self.proximity_asteroids
.extend([asteroid.pos.x, asteroid.pos.y, asteroid.radius]);
if asteroid.check_collision(self.pos, 8.) { if asteroid.check_collision(self.pos, 8.) {
self.alive = false; self.alive = false;
return true; return true;
@ -56,28 +72,37 @@ impl Player {
pub fn update(&mut self) { pub fn update(&mut self) {
let mut mag = 0.; let mut mag = 0.;
let mut keys = vec![false, false, false]; let mut keys = vec![false, false, false, false];
self.proximity_asteroids.resize(self.max_asteroids, 0.);
let mut inputs = vec![
self.pos.x / screen_width() + 0.5,
self.pos.y / screen_height() + 0.5,
self.vel.x / 11.,
self.vel.y / 11.,
self.rot / TAU as f32,
];
inputs.append(self.proximity_asteroids.as_mut());
if let Some(brain) = &self.brain { if let Some(brain) = &self.brain {
keys = brain keys = brain
.feed_forward(vec![ .feed_forward(inputs)
self.pos.x, self.pos.y, self.vel.x, self.vel.y, self.rot,
])
.iter() .iter()
.map(|&x| if x > 0. { true } else { false }) .map(|&x| if x > 0. { true } else { false })
.collect(); .collect();
} }
if is_key_down(KeyCode::Right) || keys[0] { if is_key_down(KeyCode::Right) || keys[0] {
self.rot += 0.1; self.rot = (self.rot + 0.1 + TAU as f32) % TAU as f32;
self.dir = vec2(self.rot.cos(), self.rot.sin()); self.dir = vec2(self.rot.cos(), self.rot.sin());
} }
if is_key_down(KeyCode::Left) || keys[1] { if is_key_down(KeyCode::Left) || keys[1] {
self.rot -= 0.1; self.rot = (self.rot - 0.1 + TAU as f32) % TAU as f32;
self.dir = vec2(self.rot.cos(), self.rot.sin()); self.dir = vec2(self.rot.cos(), self.rot.sin());
} }
if is_key_down(KeyCode::Up) || keys[2] { if is_key_down(KeyCode::Up) || keys[2] {
// Change scaling when passing inputs if this is changed
mag = 0.14; mag = 0.14;
} }
if is_key_down(KeyCode::Space) { if is_key_down(KeyCode::Space) || keys[3] {
if self.shot_interval + self.last_shot < get_time() as f32 { if self.shot_interval + self.last_shot < get_time() as f32 {
self.last_shot = get_time() as f32; self.last_shot = get_time() as f32;
self.bullets.push(Bullet { self.bullets.push(Bullet {
@ -88,6 +113,10 @@ impl Player {
} }
} }
if is_key_pressed(KeyCode::D) {
self.debug = !self.debug;
}
self.vel += mag * self.dir - self.drag * self.vel.length() * self.vel; self.vel += mag * self.dir - self.drag * self.vel.length() * self.vel;
self.pos += self.vel; self.pos += self.vel;
if self.pos.x.abs() > screen_width() / 2. + 10. { if self.pos.x.abs() > screen_width() / 2. + 10. {
@ -121,6 +150,13 @@ impl Player {
draw_triangle_lines(p6, p7, p8, 2., WHITE); draw_triangle_lines(p6, p7, p8, 2., WHITE);
} }
if self.debug {
for a in self.proximity_asteroids.chunks(3) {
draw_circle_lines(a[0], a[1], a[2], 1., GRAY);
draw_line(self.pos.x, self.pos.y, a[0], a[1], 1., GRAY)
}
}
for bullet in &self.bullets { for bullet in &self.bullets {
bullet.draw(); bullet.draw();
} }

View File

@ -1,18 +1,78 @@
use crate::world::World; use macroquad::{prelude::*, rand::gen_range};
use crate::{nn::NN, world::World};
#[derive(Default)] #[derive(Default)]
struct Population { pub struct Population {
size: i32, size: usize,
gen: i32, gen: i32,
worlds: Vec<World>, pub worlds: Vec<World>,
} }
impl Population { impl Population {
// pub fn new(size: i32) -> Self { pub fn new(size: usize) -> Self {
// Self { Self {
// size: size, size,
// worlds: vec![World::new(); size], worlds: (0..size)
// ..Default::default(), .map(|_| World::simulate(NN::new(vec![33, 10, 4])))
// } .collect(),
// } ..Default::default()
}
}
pub fn update(&mut self) {
let mut alive = false;
for world in &mut self.worlds {
if !world.over {
alive = true;
world.update();
}
}
if !alive {
self.gen += 1;
self.next_gen();
}
}
pub fn draw(&self) {
for world in &self.worlds {
if !world.over {
world.draw();
draw_text(
&format!("Gen: {}", self.gen),
-100. + screen_width() * 0.5,
30. - screen_height() * 0.5,
32.,
WHITE,
);
}
}
}
pub fn next_gen(&mut self) {
let total = self.worlds.iter().fold(0, |acc, x| acc + x.score);
self.worlds.sort_by(|a, b| b.score.cmp(&a.score));
let mut new_worlds = (0..self.size / 10)
.map(|i| World::simulate(self.worlds[i].see_brain().to_owned()))
.collect::<Vec<_>>();
while new_worlds.len() < self.size {
let rands = (gen_range(0, total + 1), gen_range(0, total + 1));
let mut sum = 0;
let (mut a, mut b) = (None, None);
for world in &self.worlds {
sum += world.score;
if sum >= rands.0 {
a = Some(world.see_brain());
}
if sum >= rands.1 {
b = Some(world.see_brain());
}
}
let mut new_brain = NN::crossover(a.unwrap(), b.unwrap());
new_brain.mutate();
new_worlds.push(World::simulate(new_brain));
}
self.worlds = new_worlds;
}
} }

View File

@ -9,25 +9,32 @@ use macroquad::{prelude::*, rand::gen_range};
pub struct World { pub struct World {
player: Player, player: Player,
asteroids: Vec<Asteroid>, asteroids: Vec<Asteroid>,
pub score: i32, pub score: u32,
pub over: bool, pub over: bool,
max_asteroids: usize,
} }
impl World { impl World {
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
player: Player::new(), player: Player::new(),
max_asteroids: 28,
..Default::default() ..Default::default()
} }
} }
pub fn simulate(brain: NN) -> Self { pub fn simulate(brain: NN) -> Self {
Self { Self {
player: Player::simulate(brain), player: Player::simulate(brain, 28),
max_asteroids: 28,
..Default::default() ..Default::default()
} }
} }
pub fn see_brain(&self) -> &NN {
self.player.brain.as_ref().unwrap()
}
pub fn update(&mut self) { pub fn update(&mut self) {
self.player.update(); self.player.update();
let mut to_add: Vec<Asteroid> = Vec::new(); let mut to_add: Vec<Asteroid> = Vec::new();
@ -77,7 +84,7 @@ impl World {
AsteroidSize::Medium => 2, AsteroidSize::Medium => 2,
AsteroidSize::Small => 1, AsteroidSize::Small => 1,
} }
}) < 20 }) < self.max_asteroids
{ {
self.asteroids.push(Asteroid::new(AsteroidSize::Large)); self.asteroids.push(Asteroid::new(AsteroidSize::Large));
} }
@ -88,12 +95,12 @@ impl World {
for asteroid in &self.asteroids { for asteroid in &self.asteroids {
asteroid.draw(); asteroid.draw();
} }
draw_text( // draw_text(
&format!("Score {}", self.score), // &format!("Score {}", self.score),
20. - screen_width() * 0.5, // 20. - screen_width() * 0.5,
30. - screen_height() * 0.5, // 30. - screen_height() * 0.5,
32., // 32.,
WHITE, // WHITE,
); // );
} }
} }