bug fixes

This commit is contained in:
sparshg 2022-10-23 01:10:19 +05:30
parent db2df770c6
commit 1c29b6d419
6 changed files with 109 additions and 69 deletions

View File

@ -89,7 +89,7 @@ impl Asteroid {
AsteroidSize::Medium => 1.2, AsteroidSize::Medium => 1.2,
AsteroidSize::Small => 0.8, AsteroidSize::Small => 0.8,
}, },
Color::new(1., 1., 1., 0.4), Color::new(1., 1., 1., 1.),
); );
} }
} }

View File

@ -16,9 +16,9 @@ async fn main() {
..Default::default() ..Default::default()
}; };
set_camera(&cam); set_camera(&cam);
let mut pop = Population::new(10); let mut pop = Population::new(100);
let mut speedup = false; let mut speedup = false;
// for _ in 0..100000 * 5 { // for _ in 0..10000 * 10 {
// pop.update(); // pop.update();
// } // }
loop { loop {

View File

@ -36,13 +36,16 @@ impl NN {
.iter() .iter()
.zip(config.iter().skip(1)) .zip(config.iter().skip(1))
.map(|(&curr, &last)| { .map(|(&curr, &last)| {
// let a = DMatrix::<f32>::new_random(last, curr + 1);
// println!("{}", a);
// a
DMatrix::<f32>::from_distribution(last, curr + 1, &StandardNormal, &mut rng) DMatrix::<f32>::from_distribution(last, curr + 1, &StandardNormal, &mut rng)
* (2. / last as f32).sqrt() * (2. / last as f32).sqrt()
}) })
.collect(), .collect(),
activ_func: ActivationFunc::ReLU, activ_func: ActivationFunc::ReLU,
mut_rate: 0.05, mut_rate: 0.02,
} }
} }
@ -64,8 +67,10 @@ impl NN {
pub fn mutate(&mut self) { pub fn mutate(&mut self) {
for weight in &mut self.weights { for weight in &mut self.weights {
for ele in weight { for ele in weight {
if gen_range(0., 1.) < 0.05 { if gen_range(0., 1.) < self.mut_rate {
// *ele += gen_range(-1., 1.);
*ele = r::thread_rng().sample::<f32, StandardNormal>(StandardNormal); *ele = r::thread_rng().sample::<f32, StandardNormal>(StandardNormal);
// *ele = r::thread_rng().sample::<f32, StandardNormal>(StandardNormal);
} }
} }
} }

View File

@ -5,19 +5,21 @@ use macroquad::{prelude::*, rand::gen_range};
use crate::{asteroids::Asteroid, nn::NN}; use crate::{asteroids::Asteroid, nn::NN};
#[derive(Default)] #[derive(Default)]
pub struct Player { pub struct Player {
pos: Vec2, pub pos: Vec2,
vel: Vec2, pub vel: Vec2,
acc: f32,
dir: Vec2, dir: Vec2,
rot: f32, rot: f32,
drag: f32, drag: f32,
bullets: Vec<Bullet>, bullets: Vec<Bullet>,
last_shot: f32, last_shot: u8,
shot_interval: f32, shot_interval: u8,
pub brain: Option<NN>, pub brain: Option<NN>,
asteroids_data: Vec<f32>, asteroids_data: Vec<f32>,
max_asteroids: usize, max_asteroids: usize,
debug: bool, debug: bool,
alive: bool, alive: bool,
pub color: Option<Color>,
pub lifespan: u32, pub lifespan: u32,
pub shots: u32, pub shots: u32,
} }
@ -30,21 +32,25 @@ impl Player {
// Change scaling when passing inputs if this is changed // Change scaling when passing inputs if this is changed
drag: 0.001, drag: 0.001,
shot_interval: 0.3, shot_interval: 18,
alive: true, alive: true,
debug: false, debug: false,
..Default::default() ..Default::default()
} }
} }
pub fn simulate(brain: NN, max_asteroids: usize) -> Self { pub fn simulate(brain: Option<NN>, max_asteroids: usize) -> Self {
assert_eq!(
brain.config[0] - 1,
max_asteroids * 3 + 5,
"NN input size must match max_asteroids"
);
let mut p = Player::new(); let mut p = Player::new();
p.brain = Some(brain); if let Some(brain) = brain {
assert_eq!(
brain.config[0] - 1,
max_asteroids * 3 + 5,
"NN input size must match max_asteroids"
);
p.brain = Some(brain);
} else {
p.brain = Some(NN::new(vec![max_asteroids * 3 + 5, 16, 4]));
}
p.max_asteroids = max_asteroids; p.max_asteroids = max_asteroids;
p p
} }
@ -75,7 +81,8 @@ impl Player {
pub fn update(&mut self) { pub fn update(&mut self) {
self.lifespan += 1; self.lifespan += 1;
let mut mag = 0.; self.last_shot += 1;
self.acc = 0.;
let mut keys = vec![false, false, false, false]; let mut keys = vec![false, false, false, false];
self.asteroids_data.resize(self.max_asteroids * 3, 0.); self.asteroids_data.resize(self.max_asteroids * 3, 0.);
let mut inputs = vec![ let mut inputs = vec![
@ -87,11 +94,8 @@ impl Player {
]; ];
inputs.append(self.asteroids_data.as_mut()); inputs.append(self.asteroids_data.as_mut());
if let Some(brain) = &self.brain { if let Some(brain) = &self.brain {
keys = brain // println!("{:?}", inputs);
.feed_forward(inputs) keys = brain.feed_forward(inputs).iter().map(|&x| x > 0.).collect();
.iter()
.map(|&x| if x > 0. { true } else { false })
.collect();
} }
if is_key_down(KeyCode::Right) || keys[0] { if is_key_down(KeyCode::Right) || keys[0] {
self.rot = (self.rot + 0.1 + TAU as f32) % TAU as f32; self.rot = (self.rot + 0.1 + TAU as f32) % TAU as f32;
@ -103,11 +107,11 @@ impl Player {
} }
if is_key_down(KeyCode::Up) || keys[2] { if is_key_down(KeyCode::Up) || keys[2] {
// Change scaling when passing inputs if this is changed // Change scaling when passing inputs if this is changed
mag = 0.14; self.acc = 0.14;
} }
if is_key_down(KeyCode::Space) || keys[3] { if is_key_down(KeyCode::Space) || keys[3] {
if self.shot_interval + self.last_shot < get_time() as f32 { if self.last_shot > self.shot_interval {
self.last_shot = get_time() as f32; self.last_shot = 0;
self.shots += 1; self.shots += 1;
self.bullets.push(Bullet { self.bullets.push(Bullet {
pos: self.pos + self.dir.rotate(vec2(20., 0.)), pos: self.pos + self.dir.rotate(vec2(20., 0.)),
@ -121,7 +125,7 @@ impl Player {
self.debug = !self.debug; self.debug = !self.debug;
} }
self.vel += mag * self.dir - self.drag * self.vel.length() * self.vel; self.vel += self.acc * self.dir - self.drag * self.vel.length() * self.vel;
self.pos += self.vel; self.pos += self.vel;
if self.pos.x.abs() > screen_width() * 0.5 + 10. { if self.pos.x.abs() > screen_width() * 0.5 + 10. {
self.pos.x *= -1.; self.pos.x *= -1.;
@ -139,6 +143,10 @@ impl Player {
} }
pub fn draw(&self) { pub fn draw(&self) {
let color = match self.color {
Some(c) => c,
None => Color::new(1., 1., 1., 0.1),
};
let p1 = self.pos + self.dir.rotate(vec2(20., 0.)); let p1 = self.pos + self.dir.rotate(vec2(20., 0.));
let p2 = self.pos + self.dir.rotate(vec2(-18., -12.667)); let p2 = self.pos + self.dir.rotate(vec2(-18., -12.667));
let p3 = self.pos + self.dir.rotate(vec2(-18., 12.667)); let p3 = self.pos + self.dir.rotate(vec2(-18., 12.667));
@ -147,11 +155,11 @@ impl Player {
let p6 = self.pos + self.dir.rotate(vec2(-25., 0.)); let p6 = self.pos + self.dir.rotate(vec2(-25., 0.));
let p7 = self.pos + self.dir.rotate(vec2(-10., -6.)); let p7 = self.pos + self.dir.rotate(vec2(-10., -6.));
let p8 = self.pos + self.dir.rotate(vec2(-10., 6.)); let p8 = self.pos + self.dir.rotate(vec2(-10., 6.));
draw_line(p1.x, p1.y, p2.x, p2.y, 2., WHITE); draw_line(p1.x, p1.y, p2.x, p2.y, 2., color);
draw_line(p1.x, p1.y, p3.x, p3.y, 2., WHITE); draw_line(p1.x, p1.y, p3.x, p3.y, 2., color);
draw_line(p4.x, p4.y, p5.x, p5.y, 2., WHITE); draw_line(p4.x, p4.y, p5.x, p5.y, 2., color);
if is_key_down(KeyCode::Up) && gen_range(0., 1.) < 0.4 { if self.acc > 0. && gen_range(0., 1.) < 0.4 {
draw_triangle_lines(p6, p7, p8, 2., WHITE); draw_triangle_lines(p6, p7, p8, 2., color);
} }
if self.debug { if self.debug {

View File

@ -13,9 +13,7 @@ impl Population {
pub fn new(size: usize) -> Self { pub fn new(size: usize) -> Self {
Self { Self {
size, size,
worlds: (0..size) worlds: (0..size).map(|_| World::simulate(None)).collect(),
.map(|_| World::simulate(NN::new(vec![89, 16, 4])))
.collect(),
..Default::default() ..Default::default()
} }
} }
@ -30,18 +28,17 @@ impl Population {
} }
if !alive { if !alive {
self.gen += 1; self.gen += 1;
println!("{}", self.gen);
self.next_gen(); self.next_gen();
} }
} }
pub fn draw(&self) { pub fn draw(&self) {
for world in &self.worlds { for world in self.worlds.iter().rev() {
if !world.over { if !world.over {
world.draw(); world.draw();
draw_text( draw_text(
&format!("Gen: {}", self.gen), &format!("Gen: {}", self.gen),
-100. + screen_width() * 0.5, -150. + screen_width() * 0.5,
30. - screen_height() * 0.5, 30. - screen_height() * 0.5,
32., 32.,
WHITE, WHITE,
@ -51,26 +48,33 @@ impl Population {
} }
pub fn next_gen(&mut self) { pub fn next_gen(&mut self) {
let total = self.worlds.iter().fold(0., |acc, x| acc + x.fitness()); let total = self.worlds.iter().fold(0., |acc, x| acc + x.fitness);
self.worlds self.worlds
.sort_by(|a, b| b.fitness().partial_cmp(&a.fitness()).unwrap()); .sort_by(|a, b| b.fitness.partial_cmp(&a.fitness).unwrap());
let mut new_worlds = (0..self.size / 10) for i in &self.worlds {
.map(|i| World::simulate(self.worlds[i].see_brain().to_owned())) println!("Fitness: {}", i.fitness);
}
println!("Gen: {}, Fitness: {}", self.gen, self.worlds[0].fitness);
// let mut new_worlds = vec![World::simulate(Some(self.worlds[0].see_brain().to_owned()))];
let mut new_worlds = (0..self.size / 20)
.map(|i| World::simulate(Some(self.worlds[i].see_brain().to_owned())))
.collect::<Vec<_>>(); .collect::<Vec<_>>();
// if is_key_down(KeyCode::K) {
new_worlds[0].set_best();
// }
// println!( // println!(
// "Total fitness: {} {} {}", // "Total fitness: {} {} {}",
// total, // total,
// self.worlds[0].fitness(), // self.worlds[0].fitness(),
// self.worlds[1].fitness() // self.worlds[1].fitness()
// ); // );
while new_worlds.len() < self.size { while new_worlds.len() < self.size {
let rands = (gen_range(0., total), gen_range(0., total)); let rands = (gen_range(0., total), gen_range(0., total));
// println!("rands: {} {}", rands.0, rands.1); // println!("rands: {} {} {}", rands.0, rands.1, total);
let mut sum = 0.; let mut sum = 0.;
let (mut a, mut b) = (None, None); let (mut a, mut b) = (None, None);
for world in &self.worlds { for world in &self.worlds {
sum += world.fitness(); sum += world.fitness;
if a.is_none() && sum >= rands.0 { if a.is_none() && sum >= rands.0 {
a = Some(world.see_brain()); a = Some(world.see_brain());
} }
@ -80,12 +84,18 @@ impl Population {
} }
// println!("{}", &a.unwrap().weights[0]); // println!("{}", &a.unwrap().weights[0]);
// println!("{}", &b.unwrap().weights[0]); // println!("{}", &b.unwrap().weights[0]);
if a.is_none() {
a = Some(self.worlds.last().unwrap().see_brain());
}
if b.is_none() {
b = Some(self.worlds.last().unwrap().see_brain());
}
let mut new_brain = NN::crossover(a.unwrap(), b.unwrap()); let mut new_brain = NN::crossover(a.unwrap(), b.unwrap());
// println!("{}", &a.unwrap().weights[0]); // println!("{}", &a.unwrap().weights[0]);
// println!("{}", &b.unwrap().weights[0]); // println!("{}", &b.unwrap().weights[0]);
// println!("{}", &new_brain.weights[0]);
new_brain.mutate(); new_brain.mutate();
new_worlds.push(World::simulate(new_brain)); // println!("{}", &new_brain.weights[0]);
new_worlds.push(World::simulate(Some(new_brain)));
} }
self.worlds = new_worlds; self.worlds = new_worlds;
} }

View File

@ -12,6 +12,7 @@ pub struct World {
pub score: u32, pub score: u32,
pub over: bool, pub over: bool,
max_asteroids: usize, max_asteroids: usize,
pub fitness: f32,
} }
impl World { impl World {
@ -23,7 +24,7 @@ impl World {
} }
} }
pub fn simulate(brain: NN) -> Self { pub fn simulate(brain: Option<NN>) -> Self {
Self { Self {
player: Player::simulate(brain, 28), player: Player::simulate(brain, 28),
max_asteroids: 28, max_asteroids: 28,
@ -31,39 +32,55 @@ impl World {
} }
} }
pub fn set_best(&mut self) {
self.player.color = Some(RED);
}
pub fn see_brain(&self) -> &NN { pub fn see_brain(&self) -> &NN {
self.player.brain.as_ref().unwrap() self.player.brain.as_ref().unwrap()
} }
pub fn fitness(&self) -> f32 { // fn calc_fitness(&mut self) {
// println!( // println!(
// "{} {} {}", // "{} {} {}",
// self.score as f32, // self.score as f32,
// self.player.lifespan as f32 * 0.001, // self.player.lifespan as f32 * 0.001,
// if self.player.shots > 0 { // if self.player.shots > 0 {
// self.score as f32 / self.player.shots as f32 * 5. // self.score as f32 / self.player.shots as f32 * 5.
// } else { // } else {
// 0. // 0.
// } // }
// ); // );
(self.score + 1) as f32 // }
* 10.
* self.player.lifespan as f32
* if self.player.shots > 0 {
(self.score as f32 / self.player.shots as f32)
* (self.score as f32 / self.player.shots as f32)
} else {
1.
}
}
pub fn update(&mut self) { pub fn update(&mut self) {
self.player.update(); self.player.update();
// if self.player.lifespan > 150 {
// self.fitness = 1.
// / ((self.player.pos * vec2(2. / screen_width(), 2. / screen_height()))
// .distance_squared(vec2(0., -1.))
// + self.player.vel.length_squared()
// * self.player.vel.length_squared()
// * 0.00006830134554
// + 1.);
// self.over = true;
// }
let mut to_add: Vec<Asteroid> = Vec::new(); let mut to_add: Vec<Asteroid> = Vec::new();
for asteroid in &mut self.asteroids { for asteroid in &mut self.asteroids {
asteroid.update(); asteroid.update();
if self.player.check_player_collision(asteroid) { if self.player.check_player_collision(asteroid) {
self.over = true; self.over = true;
self.fitness = (self.score as f32 + 1.)
* if self.player.shots > 0 {
(self.score as f32 / self.player.shots as f32)
* (self.score as f32 / self.player.shots as f32)
} else {
1.
}
* self.player.lifespan as f32;
// self.fitness = self.player.lifespan as f32 * self.player.lifespan as f32 * 0.001;
// println!("{} {} {}", self.score, self.player.lifespan, self.fitness);
} }
if self.player.check_bullet_collisions(asteroid) { if self.player.check_bullet_collisions(asteroid) {
self.score += 1; self.score += 1;