bug fixes

This commit is contained in:
sparshg 2022-10-23 01:10:19 +05:30
parent db2df770c6
commit 1c29b6d419
6 changed files with 109 additions and 69 deletions

View File

@ -89,7 +89,7 @@ impl Asteroid {
AsteroidSize::Medium => 1.2,
AsteroidSize::Small => 0.8,
},
Color::new(1., 1., 1., 0.4),
Color::new(1., 1., 1., 1.),
);
}
}

View File

@ -16,9 +16,9 @@ async fn main() {
..Default::default()
};
set_camera(&cam);
let mut pop = Population::new(10);
let mut pop = Population::new(100);
let mut speedup = false;
// for _ in 0..100000 * 5 {
// for _ in 0..10000 * 10 {
// pop.update();
// }
loop {

View File

@ -36,13 +36,16 @@ impl NN {
.iter()
.zip(config.iter().skip(1))
.map(|(&curr, &last)| {
// let a = DMatrix::<f32>::new_random(last, curr + 1);
// println!("{}", a);
// a
DMatrix::<f32>::from_distribution(last, curr + 1, &StandardNormal, &mut rng)
* (2. / last as f32).sqrt()
})
.collect(),
activ_func: ActivationFunc::ReLU,
mut_rate: 0.05,
mut_rate: 0.02,
}
}
@ -64,8 +67,10 @@ impl NN {
pub fn mutate(&mut self) {
for weight in &mut self.weights {
for ele in weight {
if gen_range(0., 1.) < 0.05 {
if gen_range(0., 1.) < self.mut_rate {
// *ele += gen_range(-1., 1.);
*ele = r::thread_rng().sample::<f32, StandardNormal>(StandardNormal);
// *ele = r::thread_rng().sample::<f32, StandardNormal>(StandardNormal);
}
}
}

View File

@ -5,19 +5,21 @@ use macroquad::{prelude::*, rand::gen_range};
use crate::{asteroids::Asteroid, nn::NN};
#[derive(Default)]
pub struct Player {
pos: Vec2,
vel: Vec2,
pub pos: Vec2,
pub vel: Vec2,
acc: f32,
dir: Vec2,
rot: f32,
drag: f32,
bullets: Vec<Bullet>,
last_shot: f32,
shot_interval: f32,
last_shot: u8,
shot_interval: u8,
pub brain: Option<NN>,
asteroids_data: Vec<f32>,
max_asteroids: usize,
debug: bool,
alive: bool,
pub color: Option<Color>,
pub lifespan: u32,
pub shots: u32,
}
@ -30,21 +32,25 @@ impl Player {
// Change scaling when passing inputs if this is changed
drag: 0.001,
shot_interval: 0.3,
shot_interval: 18,
alive: true,
debug: false,
..Default::default()
}
}
pub fn simulate(brain: NN, max_asteroids: usize) -> Self {
assert_eq!(
brain.config[0] - 1,
max_asteroids * 3 + 5,
"NN input size must match max_asteroids"
);
pub fn simulate(brain: Option<NN>, max_asteroids: usize) -> Self {
let mut p = Player::new();
p.brain = Some(brain);
if let Some(brain) = brain {
assert_eq!(
brain.config[0] - 1,
max_asteroids * 3 + 5,
"NN input size must match max_asteroids"
);
p.brain = Some(brain);
} else {
p.brain = Some(NN::new(vec![max_asteroids * 3 + 5, 16, 4]));
}
p.max_asteroids = max_asteroids;
p
}
@ -75,7 +81,8 @@ impl Player {
pub fn update(&mut self) {
self.lifespan += 1;
let mut mag = 0.;
self.last_shot += 1;
self.acc = 0.;
let mut keys = vec![false, false, false, false];
self.asteroids_data.resize(self.max_asteroids * 3, 0.);
let mut inputs = vec![
@ -87,11 +94,8 @@ impl Player {
];
inputs.append(self.asteroids_data.as_mut());
if let Some(brain) = &self.brain {
keys = brain
.feed_forward(inputs)
.iter()
.map(|&x| if x > 0. { true } else { false })
.collect();
// println!("{:?}", inputs);
keys = brain.feed_forward(inputs).iter().map(|&x| x > 0.).collect();
}
if is_key_down(KeyCode::Right) || keys[0] {
self.rot = (self.rot + 0.1 + TAU as f32) % TAU as f32;
@ -103,11 +107,11 @@ impl Player {
}
if is_key_down(KeyCode::Up) || keys[2] {
// Change scaling when passing inputs if this is changed
mag = 0.14;
self.acc = 0.14;
}
if is_key_down(KeyCode::Space) || keys[3] {
if self.shot_interval + self.last_shot < get_time() as f32 {
self.last_shot = get_time() as f32;
if self.last_shot > self.shot_interval {
self.last_shot = 0;
self.shots += 1;
self.bullets.push(Bullet {
pos: self.pos + self.dir.rotate(vec2(20., 0.)),
@ -121,7 +125,7 @@ impl Player {
self.debug = !self.debug;
}
self.vel += mag * self.dir - self.drag * self.vel.length() * self.vel;
self.vel += self.acc * self.dir - self.drag * self.vel.length() * self.vel;
self.pos += self.vel;
if self.pos.x.abs() > screen_width() * 0.5 + 10. {
self.pos.x *= -1.;
@ -139,6 +143,10 @@ impl Player {
}
pub fn draw(&self) {
let color = match self.color {
Some(c) => c,
None => Color::new(1., 1., 1., 0.1),
};
let p1 = self.pos + self.dir.rotate(vec2(20., 0.));
let p2 = self.pos + self.dir.rotate(vec2(-18., -12.667));
let p3 = self.pos + self.dir.rotate(vec2(-18., 12.667));
@ -147,11 +155,11 @@ impl Player {
let p6 = self.pos + self.dir.rotate(vec2(-25., 0.));
let p7 = self.pos + self.dir.rotate(vec2(-10., -6.));
let p8 = self.pos + self.dir.rotate(vec2(-10., 6.));
draw_line(p1.x, p1.y, p2.x, p2.y, 2., WHITE);
draw_line(p1.x, p1.y, p3.x, p3.y, 2., WHITE);
draw_line(p4.x, p4.y, p5.x, p5.y, 2., WHITE);
if is_key_down(KeyCode::Up) && gen_range(0., 1.) < 0.4 {
draw_triangle_lines(p6, p7, p8, 2., WHITE);
draw_line(p1.x, p1.y, p2.x, p2.y, 2., color);
draw_line(p1.x, p1.y, p3.x, p3.y, 2., color);
draw_line(p4.x, p4.y, p5.x, p5.y, 2., color);
if self.acc > 0. && gen_range(0., 1.) < 0.4 {
draw_triangle_lines(p6, p7, p8, 2., color);
}
if self.debug {

View File

@ -13,9 +13,7 @@ impl Population {
pub fn new(size: usize) -> Self {
Self {
size,
worlds: (0..size)
.map(|_| World::simulate(NN::new(vec![89, 16, 4])))
.collect(),
worlds: (0..size).map(|_| World::simulate(None)).collect(),
..Default::default()
}
}
@ -30,18 +28,17 @@ impl Population {
}
if !alive {
self.gen += 1;
println!("{}", self.gen);
self.next_gen();
}
}
pub fn draw(&self) {
for world in &self.worlds {
for world in self.worlds.iter().rev() {
if !world.over {
world.draw();
draw_text(
&format!("Gen: {}", self.gen),
-100. + screen_width() * 0.5,
-150. + screen_width() * 0.5,
30. - screen_height() * 0.5,
32.,
WHITE,
@ -51,26 +48,33 @@ impl Population {
}
pub fn next_gen(&mut self) {
let total = self.worlds.iter().fold(0., |acc, x| acc + x.fitness());
let total = self.worlds.iter().fold(0., |acc, x| acc + x.fitness);
self.worlds
.sort_by(|a, b| b.fitness().partial_cmp(&a.fitness()).unwrap());
let mut new_worlds = (0..self.size / 10)
.map(|i| World::simulate(self.worlds[i].see_brain().to_owned()))
.sort_by(|a, b| b.fitness.partial_cmp(&a.fitness).unwrap());
for i in &self.worlds {
println!("Fitness: {}", i.fitness);
}
println!("Gen: {}, Fitness: {}", self.gen, self.worlds[0].fitness);
// let mut new_worlds = vec![World::simulate(Some(self.worlds[0].see_brain().to_owned()))];
let mut new_worlds = (0..self.size / 20)
.map(|i| World::simulate(Some(self.worlds[i].see_brain().to_owned())))
.collect::<Vec<_>>();
// if is_key_down(KeyCode::K) {
new_worlds[0].set_best();
// }
// println!(
// "Total fitness: {} {} {}",
// total,
// self.worlds[0].fitness(),
// self.worlds[1].fitness()
// );
while new_worlds.len() < self.size {
let rands = (gen_range(0., total), gen_range(0., total));
// println!("rands: {} {}", rands.0, rands.1);
// println!("rands: {} {} {}", rands.0, rands.1, total);
let mut sum = 0.;
let (mut a, mut b) = (None, None);
for world in &self.worlds {
sum += world.fitness();
sum += world.fitness;
if a.is_none() && sum >= rands.0 {
a = Some(world.see_brain());
}
@ -80,12 +84,18 @@ impl Population {
}
// println!("{}", &a.unwrap().weights[0]);
// println!("{}", &b.unwrap().weights[0]);
if a.is_none() {
a = Some(self.worlds.last().unwrap().see_brain());
}
if b.is_none() {
b = Some(self.worlds.last().unwrap().see_brain());
}
let mut new_brain = NN::crossover(a.unwrap(), b.unwrap());
// println!("{}", &a.unwrap().weights[0]);
// println!("{}", &b.unwrap().weights[0]);
// println!("{}", &new_brain.weights[0]);
new_brain.mutate();
new_worlds.push(World::simulate(new_brain));
// println!("{}", &new_brain.weights[0]);
new_worlds.push(World::simulate(Some(new_brain)));
}
self.worlds = new_worlds;
}

View File

@ -12,6 +12,7 @@ pub struct World {
pub score: u32,
pub over: bool,
max_asteroids: usize,
pub fitness: f32,
}
impl World {
@ -23,7 +24,7 @@ impl World {
}
}
pub fn simulate(brain: NN) -> Self {
pub fn simulate(brain: Option<NN>) -> Self {
Self {
player: Player::simulate(brain, 28),
max_asteroids: 28,
@ -31,39 +32,55 @@ impl World {
}
}
pub fn set_best(&mut self) {
self.player.color = Some(RED);
}
pub fn see_brain(&self) -> &NN {
self.player.brain.as_ref().unwrap()
}
pub fn fitness(&self) -> f32 {
// println!(
// "{} {} {}",
// self.score as f32,
// self.player.lifespan as f32 * 0.001,
// if self.player.shots > 0 {
// self.score as f32 / self.player.shots as f32 * 5.
// } else {
// 0.
// }
// );
(self.score + 1) as f32
* 10.
* self.player.lifespan as f32
* if self.player.shots > 0 {
(self.score as f32 / self.player.shots as f32)
* (self.score as f32 / self.player.shots as f32)
} else {
1.
}
}
// fn calc_fitness(&mut self) {
// println!(
// "{} {} {}",
// self.score as f32,
// self.player.lifespan as f32 * 0.001,
// if self.player.shots > 0 {
// self.score as f32 / self.player.shots as f32 * 5.
// } else {
// 0.
// }
// );
// }
pub fn update(&mut self) {
self.player.update();
// if self.player.lifespan > 150 {
// self.fitness = 1.
// / ((self.player.pos * vec2(2. / screen_width(), 2. / screen_height()))
// .distance_squared(vec2(0., -1.))
// + self.player.vel.length_squared()
// * self.player.vel.length_squared()
// * 0.00006830134554
// + 1.);
// self.over = true;
// }
let mut to_add: Vec<Asteroid> = Vec::new();
for asteroid in &mut self.asteroids {
asteroid.update();
if self.player.check_player_collision(asteroid) {
self.over = true;
self.fitness = (self.score as f32 + 1.)
* if self.player.shots > 0 {
(self.score as f32 / self.player.shots as f32)
* (self.score as f32 / self.player.shots as f32)
} else {
1.
}
* self.player.lifespan as f32;
// self.fitness = self.player.lifespan as f32 * self.player.lifespan as f32 * 0.001;
// println!("{} {} {}", self.score, self.player.lifespan, self.fitness);
}
if self.player.check_bullet_collisions(asteroid) {
self.score += 1;