bug fixes
This commit is contained in:
parent
db2df770c6
commit
1c29b6d419
|
@ -89,7 +89,7 @@ impl Asteroid {
|
||||||
AsteroidSize::Medium => 1.2,
|
AsteroidSize::Medium => 1.2,
|
||||||
AsteroidSize::Small => 0.8,
|
AsteroidSize::Small => 0.8,
|
||||||
},
|
},
|
||||||
Color::new(1., 1., 1., 0.4),
|
Color::new(1., 1., 1., 1.),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,9 +16,9 @@ async fn main() {
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
set_camera(&cam);
|
set_camera(&cam);
|
||||||
let mut pop = Population::new(10);
|
let mut pop = Population::new(100);
|
||||||
let mut speedup = false;
|
let mut speedup = false;
|
||||||
// for _ in 0..100000 * 5 {
|
// for _ in 0..10000 * 10 {
|
||||||
// pop.update();
|
// pop.update();
|
||||||
// }
|
// }
|
||||||
loop {
|
loop {
|
||||||
|
|
|
@ -36,13 +36,16 @@ impl NN {
|
||||||
.iter()
|
.iter()
|
||||||
.zip(config.iter().skip(1))
|
.zip(config.iter().skip(1))
|
||||||
.map(|(&curr, &last)| {
|
.map(|(&curr, &last)| {
|
||||||
|
// let a = DMatrix::<f32>::new_random(last, curr + 1);
|
||||||
|
// println!("{}", a);
|
||||||
|
// a
|
||||||
DMatrix::<f32>::from_distribution(last, curr + 1, &StandardNormal, &mut rng)
|
DMatrix::<f32>::from_distribution(last, curr + 1, &StandardNormal, &mut rng)
|
||||||
* (2. / last as f32).sqrt()
|
* (2. / last as f32).sqrt()
|
||||||
})
|
})
|
||||||
.collect(),
|
.collect(),
|
||||||
|
|
||||||
activ_func: ActivationFunc::ReLU,
|
activ_func: ActivationFunc::ReLU,
|
||||||
mut_rate: 0.05,
|
mut_rate: 0.02,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -64,8 +67,10 @@ impl NN {
|
||||||
pub fn mutate(&mut self) {
|
pub fn mutate(&mut self) {
|
||||||
for weight in &mut self.weights {
|
for weight in &mut self.weights {
|
||||||
for ele in weight {
|
for ele in weight {
|
||||||
if gen_range(0., 1.) < 0.05 {
|
if gen_range(0., 1.) < self.mut_rate {
|
||||||
|
// *ele += gen_range(-1., 1.);
|
||||||
*ele = r::thread_rng().sample::<f32, StandardNormal>(StandardNormal);
|
*ele = r::thread_rng().sample::<f32, StandardNormal>(StandardNormal);
|
||||||
|
// *ele = r::thread_rng().sample::<f32, StandardNormal>(StandardNormal);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,19 +5,21 @@ use macroquad::{prelude::*, rand::gen_range};
|
||||||
use crate::{asteroids::Asteroid, nn::NN};
|
use crate::{asteroids::Asteroid, nn::NN};
|
||||||
#[derive(Default)]
|
#[derive(Default)]
|
||||||
pub struct Player {
|
pub struct Player {
|
||||||
pos: Vec2,
|
pub pos: Vec2,
|
||||||
vel: Vec2,
|
pub vel: Vec2,
|
||||||
|
acc: f32,
|
||||||
dir: Vec2,
|
dir: Vec2,
|
||||||
rot: f32,
|
rot: f32,
|
||||||
drag: f32,
|
drag: f32,
|
||||||
bullets: Vec<Bullet>,
|
bullets: Vec<Bullet>,
|
||||||
last_shot: f32,
|
last_shot: u8,
|
||||||
shot_interval: f32,
|
shot_interval: u8,
|
||||||
pub brain: Option<NN>,
|
pub brain: Option<NN>,
|
||||||
asteroids_data: Vec<f32>,
|
asteroids_data: Vec<f32>,
|
||||||
max_asteroids: usize,
|
max_asteroids: usize,
|
||||||
debug: bool,
|
debug: bool,
|
||||||
alive: bool,
|
alive: bool,
|
||||||
|
pub color: Option<Color>,
|
||||||
pub lifespan: u32,
|
pub lifespan: u32,
|
||||||
pub shots: u32,
|
pub shots: u32,
|
||||||
}
|
}
|
||||||
|
@ -30,21 +32,25 @@ impl Player {
|
||||||
|
|
||||||
// Change scaling when passing inputs if this is changed
|
// Change scaling when passing inputs if this is changed
|
||||||
drag: 0.001,
|
drag: 0.001,
|
||||||
shot_interval: 0.3,
|
shot_interval: 18,
|
||||||
alive: true,
|
alive: true,
|
||||||
debug: false,
|
debug: false,
|
||||||
..Default::default()
|
..Default::default()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn simulate(brain: NN, max_asteroids: usize) -> Self {
|
pub fn simulate(brain: Option<NN>, max_asteroids: usize) -> Self {
|
||||||
assert_eq!(
|
|
||||||
brain.config[0] - 1,
|
|
||||||
max_asteroids * 3 + 5,
|
|
||||||
"NN input size must match max_asteroids"
|
|
||||||
);
|
|
||||||
let mut p = Player::new();
|
let mut p = Player::new();
|
||||||
p.brain = Some(brain);
|
if let Some(brain) = brain {
|
||||||
|
assert_eq!(
|
||||||
|
brain.config[0] - 1,
|
||||||
|
max_asteroids * 3 + 5,
|
||||||
|
"NN input size must match max_asteroids"
|
||||||
|
);
|
||||||
|
p.brain = Some(brain);
|
||||||
|
} else {
|
||||||
|
p.brain = Some(NN::new(vec![max_asteroids * 3 + 5, 16, 4]));
|
||||||
|
}
|
||||||
p.max_asteroids = max_asteroids;
|
p.max_asteroids = max_asteroids;
|
||||||
p
|
p
|
||||||
}
|
}
|
||||||
|
@ -75,7 +81,8 @@ impl Player {
|
||||||
|
|
||||||
pub fn update(&mut self) {
|
pub fn update(&mut self) {
|
||||||
self.lifespan += 1;
|
self.lifespan += 1;
|
||||||
let mut mag = 0.;
|
self.last_shot += 1;
|
||||||
|
self.acc = 0.;
|
||||||
let mut keys = vec![false, false, false, false];
|
let mut keys = vec![false, false, false, false];
|
||||||
self.asteroids_data.resize(self.max_asteroids * 3, 0.);
|
self.asteroids_data.resize(self.max_asteroids * 3, 0.);
|
||||||
let mut inputs = vec![
|
let mut inputs = vec![
|
||||||
|
@ -87,11 +94,8 @@ impl Player {
|
||||||
];
|
];
|
||||||
inputs.append(self.asteroids_data.as_mut());
|
inputs.append(self.asteroids_data.as_mut());
|
||||||
if let Some(brain) = &self.brain {
|
if let Some(brain) = &self.brain {
|
||||||
keys = brain
|
// println!("{:?}", inputs);
|
||||||
.feed_forward(inputs)
|
keys = brain.feed_forward(inputs).iter().map(|&x| x > 0.).collect();
|
||||||
.iter()
|
|
||||||
.map(|&x| if x > 0. { true } else { false })
|
|
||||||
.collect();
|
|
||||||
}
|
}
|
||||||
if is_key_down(KeyCode::Right) || keys[0] {
|
if is_key_down(KeyCode::Right) || keys[0] {
|
||||||
self.rot = (self.rot + 0.1 + TAU as f32) % TAU as f32;
|
self.rot = (self.rot + 0.1 + TAU as f32) % TAU as f32;
|
||||||
|
@ -103,11 +107,11 @@ impl Player {
|
||||||
}
|
}
|
||||||
if is_key_down(KeyCode::Up) || keys[2] {
|
if is_key_down(KeyCode::Up) || keys[2] {
|
||||||
// Change scaling when passing inputs if this is changed
|
// Change scaling when passing inputs if this is changed
|
||||||
mag = 0.14;
|
self.acc = 0.14;
|
||||||
}
|
}
|
||||||
if is_key_down(KeyCode::Space) || keys[3] {
|
if is_key_down(KeyCode::Space) || keys[3] {
|
||||||
if self.shot_interval + self.last_shot < get_time() as f32 {
|
if self.last_shot > self.shot_interval {
|
||||||
self.last_shot = get_time() as f32;
|
self.last_shot = 0;
|
||||||
self.shots += 1;
|
self.shots += 1;
|
||||||
self.bullets.push(Bullet {
|
self.bullets.push(Bullet {
|
||||||
pos: self.pos + self.dir.rotate(vec2(20., 0.)),
|
pos: self.pos + self.dir.rotate(vec2(20., 0.)),
|
||||||
|
@ -121,7 +125,7 @@ impl Player {
|
||||||
self.debug = !self.debug;
|
self.debug = !self.debug;
|
||||||
}
|
}
|
||||||
|
|
||||||
self.vel += mag * self.dir - self.drag * self.vel.length() * self.vel;
|
self.vel += self.acc * self.dir - self.drag * self.vel.length() * self.vel;
|
||||||
self.pos += self.vel;
|
self.pos += self.vel;
|
||||||
if self.pos.x.abs() > screen_width() * 0.5 + 10. {
|
if self.pos.x.abs() > screen_width() * 0.5 + 10. {
|
||||||
self.pos.x *= -1.;
|
self.pos.x *= -1.;
|
||||||
|
@ -139,6 +143,10 @@ impl Player {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn draw(&self) {
|
pub fn draw(&self) {
|
||||||
|
let color = match self.color {
|
||||||
|
Some(c) => c,
|
||||||
|
None => Color::new(1., 1., 1., 0.1),
|
||||||
|
};
|
||||||
let p1 = self.pos + self.dir.rotate(vec2(20., 0.));
|
let p1 = self.pos + self.dir.rotate(vec2(20., 0.));
|
||||||
let p2 = self.pos + self.dir.rotate(vec2(-18., -12.667));
|
let p2 = self.pos + self.dir.rotate(vec2(-18., -12.667));
|
||||||
let p3 = self.pos + self.dir.rotate(vec2(-18., 12.667));
|
let p3 = self.pos + self.dir.rotate(vec2(-18., 12.667));
|
||||||
|
@ -147,11 +155,11 @@ impl Player {
|
||||||
let p6 = self.pos + self.dir.rotate(vec2(-25., 0.));
|
let p6 = self.pos + self.dir.rotate(vec2(-25., 0.));
|
||||||
let p7 = self.pos + self.dir.rotate(vec2(-10., -6.));
|
let p7 = self.pos + self.dir.rotate(vec2(-10., -6.));
|
||||||
let p8 = self.pos + self.dir.rotate(vec2(-10., 6.));
|
let p8 = self.pos + self.dir.rotate(vec2(-10., 6.));
|
||||||
draw_line(p1.x, p1.y, p2.x, p2.y, 2., WHITE);
|
draw_line(p1.x, p1.y, p2.x, p2.y, 2., color);
|
||||||
draw_line(p1.x, p1.y, p3.x, p3.y, 2., WHITE);
|
draw_line(p1.x, p1.y, p3.x, p3.y, 2., color);
|
||||||
draw_line(p4.x, p4.y, p5.x, p5.y, 2., WHITE);
|
draw_line(p4.x, p4.y, p5.x, p5.y, 2., color);
|
||||||
if is_key_down(KeyCode::Up) && gen_range(0., 1.) < 0.4 {
|
if self.acc > 0. && gen_range(0., 1.) < 0.4 {
|
||||||
draw_triangle_lines(p6, p7, p8, 2., WHITE);
|
draw_triangle_lines(p6, p7, p8, 2., color);
|
||||||
}
|
}
|
||||||
|
|
||||||
if self.debug {
|
if self.debug {
|
||||||
|
|
|
@ -13,9 +13,7 @@ impl Population {
|
||||||
pub fn new(size: usize) -> Self {
|
pub fn new(size: usize) -> Self {
|
||||||
Self {
|
Self {
|
||||||
size,
|
size,
|
||||||
worlds: (0..size)
|
worlds: (0..size).map(|_| World::simulate(None)).collect(),
|
||||||
.map(|_| World::simulate(NN::new(vec![89, 16, 4])))
|
|
||||||
.collect(),
|
|
||||||
..Default::default()
|
..Default::default()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -30,18 +28,17 @@ impl Population {
|
||||||
}
|
}
|
||||||
if !alive {
|
if !alive {
|
||||||
self.gen += 1;
|
self.gen += 1;
|
||||||
println!("{}", self.gen);
|
|
||||||
self.next_gen();
|
self.next_gen();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn draw(&self) {
|
pub fn draw(&self) {
|
||||||
for world in &self.worlds {
|
for world in self.worlds.iter().rev() {
|
||||||
if !world.over {
|
if !world.over {
|
||||||
world.draw();
|
world.draw();
|
||||||
draw_text(
|
draw_text(
|
||||||
&format!("Gen: {}", self.gen),
|
&format!("Gen: {}", self.gen),
|
||||||
-100. + screen_width() * 0.5,
|
-150. + screen_width() * 0.5,
|
||||||
30. - screen_height() * 0.5,
|
30. - screen_height() * 0.5,
|
||||||
32.,
|
32.,
|
||||||
WHITE,
|
WHITE,
|
||||||
|
@ -51,26 +48,33 @@ impl Population {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn next_gen(&mut self) {
|
pub fn next_gen(&mut self) {
|
||||||
let total = self.worlds.iter().fold(0., |acc, x| acc + x.fitness());
|
let total = self.worlds.iter().fold(0., |acc, x| acc + x.fitness);
|
||||||
self.worlds
|
self.worlds
|
||||||
.sort_by(|a, b| b.fitness().partial_cmp(&a.fitness()).unwrap());
|
.sort_by(|a, b| b.fitness.partial_cmp(&a.fitness).unwrap());
|
||||||
let mut new_worlds = (0..self.size / 10)
|
for i in &self.worlds {
|
||||||
.map(|i| World::simulate(self.worlds[i].see_brain().to_owned()))
|
println!("Fitness: {}", i.fitness);
|
||||||
|
}
|
||||||
|
println!("Gen: {}, Fitness: {}", self.gen, self.worlds[0].fitness);
|
||||||
|
// let mut new_worlds = vec![World::simulate(Some(self.worlds[0].see_brain().to_owned()))];
|
||||||
|
let mut new_worlds = (0..self.size / 20)
|
||||||
|
.map(|i| World::simulate(Some(self.worlds[i].see_brain().to_owned())))
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
|
// if is_key_down(KeyCode::K) {
|
||||||
|
new_worlds[0].set_best();
|
||||||
|
// }
|
||||||
// println!(
|
// println!(
|
||||||
// "Total fitness: {} {} {}",
|
// "Total fitness: {} {} {}",
|
||||||
// total,
|
// total,
|
||||||
// self.worlds[0].fitness(),
|
// self.worlds[0].fitness(),
|
||||||
// self.worlds[1].fitness()
|
// self.worlds[1].fitness()
|
||||||
// );
|
// );
|
||||||
|
|
||||||
while new_worlds.len() < self.size {
|
while new_worlds.len() < self.size {
|
||||||
let rands = (gen_range(0., total), gen_range(0., total));
|
let rands = (gen_range(0., total), gen_range(0., total));
|
||||||
// println!("rands: {} {}", rands.0, rands.1);
|
// println!("rands: {} {} {}", rands.0, rands.1, total);
|
||||||
let mut sum = 0.;
|
let mut sum = 0.;
|
||||||
let (mut a, mut b) = (None, None);
|
let (mut a, mut b) = (None, None);
|
||||||
for world in &self.worlds {
|
for world in &self.worlds {
|
||||||
sum += world.fitness();
|
sum += world.fitness;
|
||||||
if a.is_none() && sum >= rands.0 {
|
if a.is_none() && sum >= rands.0 {
|
||||||
a = Some(world.see_brain());
|
a = Some(world.see_brain());
|
||||||
}
|
}
|
||||||
|
@ -80,12 +84,18 @@ impl Population {
|
||||||
}
|
}
|
||||||
// println!("{}", &a.unwrap().weights[0]);
|
// println!("{}", &a.unwrap().weights[0]);
|
||||||
// println!("{}", &b.unwrap().weights[0]);
|
// println!("{}", &b.unwrap().weights[0]);
|
||||||
|
if a.is_none() {
|
||||||
|
a = Some(self.worlds.last().unwrap().see_brain());
|
||||||
|
}
|
||||||
|
if b.is_none() {
|
||||||
|
b = Some(self.worlds.last().unwrap().see_brain());
|
||||||
|
}
|
||||||
let mut new_brain = NN::crossover(a.unwrap(), b.unwrap());
|
let mut new_brain = NN::crossover(a.unwrap(), b.unwrap());
|
||||||
// println!("{}", &a.unwrap().weights[0]);
|
// println!("{}", &a.unwrap().weights[0]);
|
||||||
// println!("{}", &b.unwrap().weights[0]);
|
// println!("{}", &b.unwrap().weights[0]);
|
||||||
// println!("{}", &new_brain.weights[0]);
|
|
||||||
new_brain.mutate();
|
new_brain.mutate();
|
||||||
new_worlds.push(World::simulate(new_brain));
|
// println!("{}", &new_brain.weights[0]);
|
||||||
|
new_worlds.push(World::simulate(Some(new_brain)));
|
||||||
}
|
}
|
||||||
self.worlds = new_worlds;
|
self.worlds = new_worlds;
|
||||||
}
|
}
|
||||||
|
|
61
src/world.rs
61
src/world.rs
|
@ -12,6 +12,7 @@ pub struct World {
|
||||||
pub score: u32,
|
pub score: u32,
|
||||||
pub over: bool,
|
pub over: bool,
|
||||||
max_asteroids: usize,
|
max_asteroids: usize,
|
||||||
|
pub fitness: f32,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl World {
|
impl World {
|
||||||
|
@ -23,7 +24,7 @@ impl World {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn simulate(brain: NN) -> Self {
|
pub fn simulate(brain: Option<NN>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
player: Player::simulate(brain, 28),
|
player: Player::simulate(brain, 28),
|
||||||
max_asteroids: 28,
|
max_asteroids: 28,
|
||||||
|
@ -31,39 +32,55 @@ impl World {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn set_best(&mut self) {
|
||||||
|
self.player.color = Some(RED);
|
||||||
|
}
|
||||||
|
|
||||||
pub fn see_brain(&self) -> &NN {
|
pub fn see_brain(&self) -> &NN {
|
||||||
self.player.brain.as_ref().unwrap()
|
self.player.brain.as_ref().unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn fitness(&self) -> f32 {
|
// fn calc_fitness(&mut self) {
|
||||||
// println!(
|
// println!(
|
||||||
// "{} {} {}",
|
// "{} {} {}",
|
||||||
// self.score as f32,
|
// self.score as f32,
|
||||||
// self.player.lifespan as f32 * 0.001,
|
// self.player.lifespan as f32 * 0.001,
|
||||||
// if self.player.shots > 0 {
|
// if self.player.shots > 0 {
|
||||||
// self.score as f32 / self.player.shots as f32 * 5.
|
// self.score as f32 / self.player.shots as f32 * 5.
|
||||||
// } else {
|
// } else {
|
||||||
// 0.
|
// 0.
|
||||||
// }
|
// }
|
||||||
// );
|
// );
|
||||||
(self.score + 1) as f32
|
// }
|
||||||
* 10.
|
|
||||||
* self.player.lifespan as f32
|
|
||||||
* if self.player.shots > 0 {
|
|
||||||
(self.score as f32 / self.player.shots as f32)
|
|
||||||
* (self.score as f32 / self.player.shots as f32)
|
|
||||||
} else {
|
|
||||||
1.
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn update(&mut self) {
|
pub fn update(&mut self) {
|
||||||
self.player.update();
|
self.player.update();
|
||||||
|
// if self.player.lifespan > 150 {
|
||||||
|
// self.fitness = 1.
|
||||||
|
// / ((self.player.pos * vec2(2. / screen_width(), 2. / screen_height()))
|
||||||
|
// .distance_squared(vec2(0., -1.))
|
||||||
|
// + self.player.vel.length_squared()
|
||||||
|
// * self.player.vel.length_squared()
|
||||||
|
// * 0.00006830134554
|
||||||
|
// + 1.);
|
||||||
|
// self.over = true;
|
||||||
|
// }
|
||||||
let mut to_add: Vec<Asteroid> = Vec::new();
|
let mut to_add: Vec<Asteroid> = Vec::new();
|
||||||
for asteroid in &mut self.asteroids {
|
for asteroid in &mut self.asteroids {
|
||||||
asteroid.update();
|
asteroid.update();
|
||||||
if self.player.check_player_collision(asteroid) {
|
if self.player.check_player_collision(asteroid) {
|
||||||
self.over = true;
|
self.over = true;
|
||||||
|
self.fitness = (self.score as f32 + 1.)
|
||||||
|
* if self.player.shots > 0 {
|
||||||
|
(self.score as f32 / self.player.shots as f32)
|
||||||
|
* (self.score as f32 / self.player.shots as f32)
|
||||||
|
} else {
|
||||||
|
1.
|
||||||
|
}
|
||||||
|
* self.player.lifespan as f32;
|
||||||
|
// self.fitness = self.player.lifespan as f32 * self.player.lifespan as f32 * 0.001;
|
||||||
|
|
||||||
|
// println!("{} {} {}", self.score, self.player.lifespan, self.fitness);
|
||||||
}
|
}
|
||||||
if self.player.check_bullet_collisions(asteroid) {
|
if self.player.check_bullet_collisions(asteroid) {
|
||||||
self.score += 1;
|
self.score += 1;
|
||||||
|
|
Loading…
Reference in New Issue