finally works T_T im ded

This commit is contained in:
sparshg 2023-01-05 01:27:59 +05:30
parent 32df37e391
commit 8196383515
4 changed files with 145 additions and 86 deletions

View File

@ -4,14 +4,16 @@ use r::Rng;
use rand_distr::StandardNormal; use rand_distr::StandardNormal;
extern crate rand as r; extern crate rand as r;
#[derive(PartialEq, Debug, Clone, Copy)] #[derive(PartialEq, Debug, Clone, Copy, Default)]
enum ActivationFunc { enum ActivationFunc {
Sigmoid, Sigmoid,
Tanh, Tanh,
#[default]
ReLU, ReLU,
} }
#[derive(Clone, Debug)] #[derive(Clone, Debug, Default)]
pub struct NN { pub struct NN {
pub config: Vec<usize>, pub config: Vec<usize>,
pub weights: Vec<DMatrix<f32>>, pub weights: Vec<DMatrix<f32>>,
@ -36,7 +38,8 @@ impl NN {
.iter() .iter()
.zip(config.iter().skip(1)) .zip(config.iter().skip(1))
.map(|(&curr, &last)| { .map(|(&curr, &last)| {
// let a = DMatrix::<f32>::new_random(last, curr + 1); // DMatrix::from_fn(last, curr + 1, |_, _| gen_range(-1., 1.))
// DMatrix::<f32>::new_random(last, curr + 1)
// println!("{}", a); // println!("{}", a);
// a // a
DMatrix::<f32>::from_distribution(last, curr + 1, &StandardNormal, &mut rng) DMatrix::<f32>::from_distribution(last, curr + 1, &StandardNormal, &mut rng)
@ -44,8 +47,8 @@ impl NN {
}) })
.collect(), .collect(),
activ_func: ActivationFunc::ReLU, mut_rate: 0.04,
mut_rate: 0.02, ..Default::default()
} }
} }
@ -61,6 +64,7 @@ impl NN {
.zip(b.weights.iter()) .zip(b.weights.iter())
.map(|(m1, m2)| m1.zip_map(m2, |ele1, ele2| if r::random() { ele1 } else { ele2 })) .map(|(m1, m2)| m1.zip_map(m2, |ele1, ele2| if r::random() { ele1 } else { ele2 }))
.collect(), .collect(),
..Default::default()
} }
} }
@ -69,7 +73,8 @@ impl NN {
for ele in weight { for ele in weight {
if gen_range(0., 1.) < self.mut_rate { if gen_range(0., 1.) < self.mut_rate {
// *ele += gen_range(-1., 1.); // *ele += gen_range(-1., 1.);
*ele = r::thread_rng().sample::<f32, StandardNormal>(StandardNormal); *ele = gen_range(-1., 1.);
// *ele = r::thread_rng().sample::<f32, StandardNormal>(StandardNormal);
// *ele = r::thread_rng().sample::<f32, StandardNormal>(StandardNormal); // *ele = r::thread_rng().sample::<f32, StandardNormal>(StandardNormal);
} }
} }
@ -77,6 +82,7 @@ impl NN {
} }
pub fn feed_forward(&self, inputs: Vec<f32>) -> Vec<f32> { pub fn feed_forward(&self, inputs: Vec<f32>) -> Vec<f32> {
// println!("inputs: {:?}", inputs);
let mut y = DMatrix::from_vec(inputs.len(), 1, inputs); let mut y = DMatrix::from_vec(inputs.len(), 1, inputs);
for i in 0..self.config.len() - 1 { for i in 0..self.config.len() - 1 {
y = (&self.weights[i] * y.insert_row(self.config[i] - 1, 1.)).map(|x| { y = (&self.weights[i] * y.insert_row(self.config[i] - 1, 1.)).map(|x| {
@ -86,6 +92,8 @@ impl NN {
ActivationFunc::Tanh => x.tanh(), ActivationFunc::Tanh => x.tanh(),
} }
}); });
// println!("w{}: {}", i, self.weights[i]);
// println!("y: {}", y);
} }
y.column(0).data.into_slice().to_vec() y.column(0).data.into_slice().to_vec()
} }

View File

@ -1,7 +1,7 @@
use std::{f32::consts::PI, f64::consts::TAU}; use std::{f32::consts::PI, f64::consts::TAU};
use macroquad::{prelude::*, rand::gen_range}; use macroquad::{prelude::*, rand::gen_range};
use nalgebra::{max, partial_max}; use nalgebra::{max, partial_max, partial_min};
use crate::{asteroids::Asteroid, nn::NN}; use crate::{asteroids::Asteroid, nn::NN};
#[derive(Default)] #[derive(Default)]
@ -9,7 +9,7 @@ pub struct Player {
pub pos: Vec2, pub pos: Vec2,
pub vel: Vec2, pub vel: Vec2,
acc: f32, acc: f32,
dir: Vec2, pub dir: Vec2,
rot: f32, rot: f32,
drag: f32, drag: f32,
bullets: Vec<Bullet>, bullets: Vec<Bullet>,
@ -36,7 +36,7 @@ impl Player {
alive: true, alive: true,
debug: false, debug: false,
shots: 4, shots: 4,
raycasts: vec![0.; 8], raycasts: vec![f32::MAX; 3],
..Default::default() ..Default::default()
} }
} }
@ -44,39 +44,52 @@ impl Player {
pub fn simulate(brain: Option<NN>) -> Self { pub fn simulate(brain: Option<NN>) -> Self {
let mut p = Player::new(); let mut p = Player::new();
if let Some(brain) = brain { if let Some(brain) = brain {
assert_eq!( // assert_eq!(
brain.config[0] - 1, // brain.config[0] - 1,
8 + 0, // 8 + 5,
"NN input size must match max_asteroids" // "NN input size must match max_asteroids"
); // );
p.brain = Some(brain); p.brain = Some(brain);
} else { } else {
p.brain = Some(NN::new(vec![8 + 0, 16, 4])); p.brain = Some(NN::new(vec![3, 8, 8, 4]));
} }
p p
} }
pub fn check_player_collision(&mut self, asteroid: &mut Asteroid) -> bool { pub fn check_player_collision(&mut self, asteroid: &mut Asteroid) -> bool {
// self.asteroids_data.extend([ // self.raycasts.extend([
// asteroid.pos.x / screen_width() + 0.5, if (asteroid.pos).distance_squared(self.pos)
// asteroid.pos.y / screen_height() + 0.5, < vec2(
// asteroid.radius / 50., self.raycasts[0] * screen_width(),
// ]); self.raycasts[1] * screen_height(),
let v = asteroid.pos - self.pos; )
for i in 0..4 { .distance_squared(self.pos)
let dir = Vec2::from_angle(PI / 4. * i as f32).rotate(self.dir); {
let cross = v.perp_dot(dir); self.raycasts[0] = asteroid.pos.x / screen_width();
let dot = v.dot(dir); self.raycasts[1] = asteroid.pos.y / screen_height();
if cross.abs() <= asteroid.radius { self.raycasts[2] = asteroid.radius / 50.;
self.raycasts[if dot >= 0. { i } else { i + 4 }] = *partial_max(
&self.raycasts[if dot >= 0. { i } else { i + 4 }],
&(1. / (dot.abs()
- (asteroid.radius * asteroid.radius - cross * cross).sqrt())),
)
.unwrap();
}
} }
if asteroid.check_collision(self.pos, 8.) { // ]);
// if self.raycasts[0] > (asteroid.pos - self.pos).length_squared() {
// self.raycasts[0] = (asteroid.pos - self.pos).length_squared();
// self.raycasts[1] = Vec2::angle_between(asteroid.pos - self.pos, self.dir).sin();
// self.raycasts[2] = Vec2::angle_between(asteroid.pos - self.pos, self.dir).cos();
// }
// let v = asteroid.pos - self.pos;
// for i in 0..4 {
// let dir = Vec2::from_angle(PI / 4. * i as f32).rotate(self.dir);
// let cross = v.perp_dot(dir);
// let dot = v.dot(dir);
// if cross.abs() <= asteroid.radius {
// self.raycasts[if dot >= 0. { i } else { i + 4 }] = *partial_max(
// &self.raycasts[if dot >= 0. { i } else { i + 4 }],
// &(1. / (dot.abs()
// - (asteroid.radius * asteroid.radius - cross * cross).sqrt())),
// )
// .unwrap();
// }
// }
if asteroid.check_collision(self.pos, 8.) || self.lifespan > 2000 {
self.alive = false; self.alive = false;
return true; return true;
} }
@ -99,22 +112,39 @@ impl Player {
self.last_shot += 1; self.last_shot += 1;
self.acc = 0.; self.acc = 0.;
let mut keys = vec![false, false, false, false]; let mut keys = vec![false, false, false, false];
// self.asteroids_data.resize(self.max_asteroids * 3, 0.); let mut inputs = vec![
// let mut inputs = vec![ (vec2(
// self.pos.x / screen_width() + 0.5, self.raycasts[0] * screen_width(),
// self.pos.y / screen_height() + 0.5, self.raycasts[1] * screen_height(),
// self.vel.x / 11., ) - self.pos)
// self.vel.y / 11., .length()
// self.rot / TAU as f32, * 0.707
// ]; / screen_width(),
// inputs.append(self.raycasts.as_mut()); // self.raycasts[0] - self.pos.x / screen_width(),
let inputs = self.raycasts.clone(); // self.raycasts[1] - self.pos.y / screen_height(),
self.dir.angle_between(
vec2(
self.raycasts[0] * screen_width(),
self.raycasts[1] * screen_height(),
) - self.pos,
),
// self.vel.x / 11.,
// self.vel.y / 11.,
self.rot, // self.rot.sin(),
// self.rot.cos(),
];
// self.raycasts.resize(3, 0.);
// inputs.append(self.raycasts.clone().as_mut());
// println!("inputs: {:?}", inputs);
// let inputs = self.raycasts.clone();
// inputs.append(self.asteroids_data.as_mut()); // inputs.append(self.asteroids_data.as_mut());
if let Some(brain) = &self.brain { if let Some(brain) = &self.brain {
// println!("{:?}", inputs); // println!("{:?}", brain.feed_forward(inputs.clone()));
keys = brain.feed_forward(inputs).iter().map(|&x| x > 0.).collect(); keys = brain.feed_forward(inputs).iter().map(|&x| x > 0.).collect();
} }
self.raycasts = vec![0.; 8];
if is_key_down(KeyCode::Right) && self.debug || keys[0] { if is_key_down(KeyCode::Right) && self.debug || keys[0] {
self.rot = (self.rot + 0.1 + TAU as f32) % TAU as f32; self.rot = (self.rot + 0.1 + TAU as f32) % TAU as f32;
self.dir = vec2(self.rot.cos(), self.rot.sin()); self.dir = vec2(self.rot.cos(), self.rot.sin());
@ -132,8 +162,8 @@ impl Player {
self.last_shot = 0; self.last_shot = 0;
self.shots += 1; self.shots += 1;
self.bullets.push(Bullet { self.bullets.push(Bullet {
pos: self.pos + self.dir.rotate(vec2(20., 0.)), pos: self.pos + self.dir * 20.,
vel: self.dir.rotate(vec2(8.5, 0.)) + self.vel, vel: self.dir * 8.5 + self.vel,
alive: true, alive: true,
}); });
} }
@ -158,6 +188,7 @@ impl Player {
self.bullets.retain(|b| { self.bullets.retain(|b| {
b.alive && b.pos.x.abs() * 2. < screen_width() && b.pos.y.abs() * 2. < screen_height() b.alive && b.pos.x.abs() * 2. < screen_width() && b.pos.y.abs() * 2. < screen_height()
}); });
self.raycasts = vec![100.; 3];
} }
pub fn draw(&self) { pub fn draw(&self) {
@ -181,21 +212,34 @@ impl Player {
} }
if self.debug { if self.debug {
// for a in self.asteroids_data.chunks(3) { for a in self.raycasts.chunks(3) {
// draw_circle_lines(a[0], a[1], a[2], 1., GRAY); draw_circle_lines(
// draw_line(self.pos.x, self.pos.y, a[0], a[1], 1., GRAY) a[0] * screen_width(),
// } a[1] * screen_height(),
for (i, r) in self.raycasts.iter().enumerate() { a[2] * 50.,
let dir = Vec2::from_angle(PI / 4. * i as f32).rotate(self.dir);
draw_line(
self.pos.x,
self.pos.y,
self.pos.x + dir.x / r,
self.pos.y + dir.y / r,
1., 1.,
GRAY, GRAY,
); );
draw_line(
self.pos.x,
self.pos.y,
a[0] * screen_width(),
a[1] * screen_height(),
1.,
GRAY,
)
} }
// for (i, r) in self.raycasts.iter().enumerate() {
// let dir = Vec2::from_angle(PI / 4. * i as f32).rotate(self.dir);
// draw_line(
// self.pos.x,
// self.pos.y,
// self.pos.x + dir.x / r,
// self.pos.y + dir.y / r,
// 1.,
// GRAY,
// );
// }
} }
for bullet in &self.bullets { for bullet in &self.bullets {

View File

@ -6,6 +6,7 @@ use crate::{nn::NN, world::World};
pub struct Population { pub struct Population {
size: usize, size: usize,
gen: i32, gen: i32,
best: bool,
pub worlds: Vec<World>, pub worlds: Vec<World>,
} }
@ -30,21 +31,28 @@ impl Population {
self.gen += 1; self.gen += 1;
self.next_gen(); self.next_gen();
} }
if is_key_pressed(KeyCode::Z) {
self.best = !self.best;
}
} }
pub fn draw(&self) { pub fn draw(&self) {
for world in self.worlds.iter().rev() { for world in self.worlds.iter().rev() {
if !world.over { if self.best {
if world.player.color.is_some() {
world.draw();
}
} else if !world.over {
world.draw(); world.draw();
draw_text(
&format!("Gen: {}", self.gen),
-150. + screen_width() * 0.5,
30. - screen_height() * 0.5,
32.,
WHITE,
);
} }
} }
draw_text(
&format!("Gen: {}", self.gen),
-150. + screen_width() * 0.5,
30. - screen_height() * 0.5,
32.,
WHITE,
);
} }
pub fn next_gen(&mut self) { pub fn next_gen(&mut self) {

View File

@ -7,9 +7,9 @@ use macroquad::{prelude::*, rand::gen_range};
#[derive(Default)] #[derive(Default)]
pub struct World { pub struct World {
player: Player, pub player: Player,
asteroids: Vec<Asteroid>, asteroids: Vec<Asteroid>,
pub score: u32, pub score: f32,
pub over: bool, pub over: bool,
max_asteroids: usize, max_asteroids: usize,
pub fitness: f32, pub fitness: f32,
@ -20,7 +20,7 @@ impl World {
Self { Self {
player: Player::new(), player: Player::new(),
max_asteroids: 28, max_asteroids: 28,
score: 1, score: 0.,
..Default::default() ..Default::default()
} }
} }
@ -29,7 +29,7 @@ impl World {
Self { Self {
player: Player::simulate(brain), player: Player::simulate(brain),
max_asteroids: 28, max_asteroids: 28,
score: 1, score: 0.,
asteroids: vec![ asteroids: vec![
Asteroid::new_to(vec2(0., 0.), 1.5, AsteroidSize::Large), Asteroid::new_to(vec2(0., 0.), 1.5, AsteroidSize::Large),
Asteroid::new(AsteroidSize::Large), Asteroid::new(AsteroidSize::Large),
@ -79,17 +79,13 @@ impl World {
asteroid.update(); asteroid.update();
if self.player.check_player_collision(asteroid) { if self.player.check_player_collision(asteroid) {
self.over = true; self.over = true;
self.fitness = (self.score as f32 + 1.) self.fitness =
* (self.score as f32 / self.player.shots as f32) (self.score / self.player.shots as f32).powi(2) * self.player.lifespan as f32;
* (self.score as f32 / self.player.shots as f32)
* self.player.lifespan as f32
* 0.01;
// self.fitness = self.player.lifespan as f32 * self.player.lifespan as f32 * 0.001;
// println!("{} {} {}", self.score, self.player.lifespan, self.fitness); // println!("{} {} {}", self.score, self.player.lifespan, self.fitness);
} }
if self.player.check_bullet_collisions(asteroid) { if self.player.check_bullet_collisions(asteroid) {
self.score += 1; self.score += 1.;
match asteroid.size { match asteroid.size {
AsteroidSize::Large => { AsteroidSize::Large => {
let rand = vec2(gen_range(-0.8, 0.8), gen_range(-0.8, 0.8)); let rand = vec2(gen_range(-0.8, 0.8), gen_range(-0.8, 0.8));
@ -130,7 +126,8 @@ impl World {
// AsteroidSize::Small => 1, // AsteroidSize::Small => 1,
// } // }
// }) < self.max_asteroids // }) < self.max_asteroids
if self.player.lifespan % 400 == 0 { // {
if self.player.lifespan % 200 == 0 {
self.asteroids self.asteroids
.push(Asteroid::new_to(self.player.pos, 1.5, AsteroidSize::Large)); .push(Asteroid::new_to(self.player.pos, 1.5, AsteroidSize::Large));
} }
@ -141,12 +138,14 @@ impl World {
for asteroid in &self.asteroids { for asteroid in &self.asteroids {
asteroid.draw(); asteroid.draw();
} }
// draw_text( draw_text(
// &format!("Score {}", self.score), &format!("{}", (self.score / self.player.shots as f32).powi(2) as f32),
// 20. - screen_width() * 0.5, // 20. - screen_width() * 0.5,
// 30. - screen_height() * 0.5, // 30. - screen_height() * 0.5,
// 32., self.player.pos.x - 20.,
// WHITE, self.player.pos.y - 20.,
// ); 12.,
WHITE,
);
} }
} }