finally works T_T im ded

This commit is contained in:
sparshg 2023-01-05 01:27:59 +05:30
parent 32df37e391
commit 8196383515
4 changed files with 145 additions and 86 deletions

View File

@ -4,14 +4,16 @@ use r::Rng;
use rand_distr::StandardNormal;
extern crate rand as r;
#[derive(PartialEq, Debug, Clone, Copy)]
#[derive(PartialEq, Debug, Clone, Copy, Default)]
enum ActivationFunc {
Sigmoid,
Tanh,
#[default]
ReLU,
}
#[derive(Clone, Debug)]
#[derive(Clone, Debug, Default)]
pub struct NN {
pub config: Vec<usize>,
pub weights: Vec<DMatrix<f32>>,
@ -36,7 +38,8 @@ impl NN {
.iter()
.zip(config.iter().skip(1))
.map(|(&curr, &last)| {
// let a = DMatrix::<f32>::new_random(last, curr + 1);
// DMatrix::from_fn(last, curr + 1, |_, _| gen_range(-1., 1.))
// DMatrix::<f32>::new_random(last, curr + 1)
// println!("{}", a);
// a
DMatrix::<f32>::from_distribution(last, curr + 1, &StandardNormal, &mut rng)
@ -44,8 +47,8 @@ impl NN {
})
.collect(),
activ_func: ActivationFunc::ReLU,
mut_rate: 0.02,
mut_rate: 0.04,
..Default::default()
}
}
@ -61,6 +64,7 @@ impl NN {
.zip(b.weights.iter())
.map(|(m1, m2)| m1.zip_map(m2, |ele1, ele2| if r::random() { ele1 } else { ele2 }))
.collect(),
..Default::default()
}
}
@ -69,7 +73,8 @@ impl NN {
for ele in weight {
if gen_range(0., 1.) < self.mut_rate {
// *ele += gen_range(-1., 1.);
*ele = r::thread_rng().sample::<f32, StandardNormal>(StandardNormal);
*ele = gen_range(-1., 1.);
// *ele = r::thread_rng().sample::<f32, StandardNormal>(StandardNormal);
// *ele = r::thread_rng().sample::<f32, StandardNormal>(StandardNormal);
}
}
@ -77,6 +82,7 @@ impl NN {
}
pub fn feed_forward(&self, inputs: Vec<f32>) -> Vec<f32> {
// println!("inputs: {:?}", inputs);
let mut y = DMatrix::from_vec(inputs.len(), 1, inputs);
for i in 0..self.config.len() - 1 {
y = (&self.weights[i] * y.insert_row(self.config[i] - 1, 1.)).map(|x| {
@ -86,6 +92,8 @@ impl NN {
ActivationFunc::Tanh => x.tanh(),
}
});
// println!("w{}: {}", i, self.weights[i]);
// println!("y: {}", y);
}
y.column(0).data.into_slice().to_vec()
}

View File

@ -1,7 +1,7 @@
use std::{f32::consts::PI, f64::consts::TAU};
use macroquad::{prelude::*, rand::gen_range};
use nalgebra::{max, partial_max};
use nalgebra::{max, partial_max, partial_min};
use crate::{asteroids::Asteroid, nn::NN};
#[derive(Default)]
@ -9,7 +9,7 @@ pub struct Player {
pub pos: Vec2,
pub vel: Vec2,
acc: f32,
dir: Vec2,
pub dir: Vec2,
rot: f32,
drag: f32,
bullets: Vec<Bullet>,
@ -36,7 +36,7 @@ impl Player {
alive: true,
debug: false,
shots: 4,
raycasts: vec![0.; 8],
raycasts: vec![f32::MAX; 3],
..Default::default()
}
}
@ -44,39 +44,52 @@ impl Player {
pub fn simulate(brain: Option<NN>) -> Self {
let mut p = Player::new();
if let Some(brain) = brain {
assert_eq!(
brain.config[0] - 1,
8 + 0,
"NN input size must match max_asteroids"
);
// assert_eq!(
// brain.config[0] - 1,
// 8 + 5,
// "NN input size must match max_asteroids"
// );
p.brain = Some(brain);
} else {
p.brain = Some(NN::new(vec![8 + 0, 16, 4]));
p.brain = Some(NN::new(vec![3, 8, 8, 4]));
}
p
}
pub fn check_player_collision(&mut self, asteroid: &mut Asteroid) -> bool {
// self.asteroids_data.extend([
// asteroid.pos.x / screen_width() + 0.5,
// asteroid.pos.y / screen_height() + 0.5,
// asteroid.radius / 50.,
// ]);
let v = asteroid.pos - self.pos;
for i in 0..4 {
let dir = Vec2::from_angle(PI / 4. * i as f32).rotate(self.dir);
let cross = v.perp_dot(dir);
let dot = v.dot(dir);
if cross.abs() <= asteroid.radius {
self.raycasts[if dot >= 0. { i } else { i + 4 }] = *partial_max(
&self.raycasts[if dot >= 0. { i } else { i + 4 }],
&(1. / (dot.abs()
- (asteroid.radius * asteroid.radius - cross * cross).sqrt())),
// self.raycasts.extend([
if (asteroid.pos).distance_squared(self.pos)
< vec2(
self.raycasts[0] * screen_width(),
self.raycasts[1] * screen_height(),
)
.unwrap();
.distance_squared(self.pos)
{
self.raycasts[0] = asteroid.pos.x / screen_width();
self.raycasts[1] = asteroid.pos.y / screen_height();
self.raycasts[2] = asteroid.radius / 50.;
}
}
if asteroid.check_collision(self.pos, 8.) {
// ]);
// if self.raycasts[0] > (asteroid.pos - self.pos).length_squared() {
// self.raycasts[0] = (asteroid.pos - self.pos).length_squared();
// self.raycasts[1] = Vec2::angle_between(asteroid.pos - self.pos, self.dir).sin();
// self.raycasts[2] = Vec2::angle_between(asteroid.pos - self.pos, self.dir).cos();
// }
// let v = asteroid.pos - self.pos;
// for i in 0..4 {
// let dir = Vec2::from_angle(PI / 4. * i as f32).rotate(self.dir);
// let cross = v.perp_dot(dir);
// let dot = v.dot(dir);
// if cross.abs() <= asteroid.radius {
// self.raycasts[if dot >= 0. { i } else { i + 4 }] = *partial_max(
// &self.raycasts[if dot >= 0. { i } else { i + 4 }],
// &(1. / (dot.abs()
// - (asteroid.radius * asteroid.radius - cross * cross).sqrt())),
// )
// .unwrap();
// }
// }
if asteroid.check_collision(self.pos, 8.) || self.lifespan > 2000 {
self.alive = false;
return true;
}
@ -99,22 +112,39 @@ impl Player {
self.last_shot += 1;
self.acc = 0.;
let mut keys = vec![false, false, false, false];
// self.asteroids_data.resize(self.max_asteroids * 3, 0.);
// let mut inputs = vec![
// self.pos.x / screen_width() + 0.5,
// self.pos.y / screen_height() + 0.5,
let mut inputs = vec![
(vec2(
self.raycasts[0] * screen_width(),
self.raycasts[1] * screen_height(),
) - self.pos)
.length()
* 0.707
/ screen_width(),
// self.raycasts[0] - self.pos.x / screen_width(),
// self.raycasts[1] - self.pos.y / screen_height(),
self.dir.angle_between(
vec2(
self.raycasts[0] * screen_width(),
self.raycasts[1] * screen_height(),
) - self.pos,
),
// self.vel.x / 11.,
// self.vel.y / 11.,
// self.rot / TAU as f32,
// ];
// inputs.append(self.raycasts.as_mut());
let inputs = self.raycasts.clone();
self.rot, // self.rot.sin(),
// self.rot.cos(),
];
// self.raycasts.resize(3, 0.);
// inputs.append(self.raycasts.clone().as_mut());
// println!("inputs: {:?}", inputs);
// let inputs = self.raycasts.clone();
// inputs.append(self.asteroids_data.as_mut());
if let Some(brain) = &self.brain {
// println!("{:?}", inputs);
// println!("{:?}", brain.feed_forward(inputs.clone()));
keys = brain.feed_forward(inputs).iter().map(|&x| x > 0.).collect();
}
self.raycasts = vec![0.; 8];
if is_key_down(KeyCode::Right) && self.debug || keys[0] {
self.rot = (self.rot + 0.1 + TAU as f32) % TAU as f32;
self.dir = vec2(self.rot.cos(), self.rot.sin());
@ -132,8 +162,8 @@ impl Player {
self.last_shot = 0;
self.shots += 1;
self.bullets.push(Bullet {
pos: self.pos + self.dir.rotate(vec2(20., 0.)),
vel: self.dir.rotate(vec2(8.5, 0.)) + self.vel,
pos: self.pos + self.dir * 20.,
vel: self.dir * 8.5 + self.vel,
alive: true,
});
}
@ -158,6 +188,7 @@ impl Player {
self.bullets.retain(|b| {
b.alive && b.pos.x.abs() * 2. < screen_width() && b.pos.y.abs() * 2. < screen_height()
});
self.raycasts = vec![100.; 3];
}
pub fn draw(&self) {
@ -181,21 +212,34 @@ impl Player {
}
if self.debug {
// for a in self.asteroids_data.chunks(3) {
// draw_circle_lines(a[0], a[1], a[2], 1., GRAY);
// draw_line(self.pos.x, self.pos.y, a[0], a[1], 1., GRAY)
// }
for (i, r) in self.raycasts.iter().enumerate() {
let dir = Vec2::from_angle(PI / 4. * i as f32).rotate(self.dir);
draw_line(
self.pos.x,
self.pos.y,
self.pos.x + dir.x / r,
self.pos.y + dir.y / r,
for a in self.raycasts.chunks(3) {
draw_circle_lines(
a[0] * screen_width(),
a[1] * screen_height(),
a[2] * 50.,
1.,
GRAY,
);
draw_line(
self.pos.x,
self.pos.y,
a[0] * screen_width(),
a[1] * screen_height(),
1.,
GRAY,
)
}
// for (i, r) in self.raycasts.iter().enumerate() {
// let dir = Vec2::from_angle(PI / 4. * i as f32).rotate(self.dir);
// draw_line(
// self.pos.x,
// self.pos.y,
// self.pos.x + dir.x / r,
// self.pos.y + dir.y / r,
// 1.,
// GRAY,
// );
// }
}
for bullet in &self.bullets {

View File

@ -6,6 +6,7 @@ use crate::{nn::NN, world::World};
pub struct Population {
size: usize,
gen: i32,
best: bool,
pub worlds: Vec<World>,
}
@ -30,12 +31,21 @@ impl Population {
self.gen += 1;
self.next_gen();
}
if is_key_pressed(KeyCode::Z) {
self.best = !self.best;
}
}
pub fn draw(&self) {
for world in self.worlds.iter().rev() {
if !world.over {
if self.best {
if world.player.color.is_some() {
world.draw();
}
} else if !world.over {
world.draw();
}
}
draw_text(
&format!("Gen: {}", self.gen),
-150. + screen_width() * 0.5,
@ -44,8 +54,6 @@ impl Population {
WHITE,
);
}
}
}
pub fn next_gen(&mut self) {
let total = self.worlds.iter().fold(0., |acc, x| acc + x.fitness);

View File

@ -7,9 +7,9 @@ use macroquad::{prelude::*, rand::gen_range};
#[derive(Default)]
pub struct World {
player: Player,
pub player: Player,
asteroids: Vec<Asteroid>,
pub score: u32,
pub score: f32,
pub over: bool,
max_asteroids: usize,
pub fitness: f32,
@ -20,7 +20,7 @@ impl World {
Self {
player: Player::new(),
max_asteroids: 28,
score: 1,
score: 0.,
..Default::default()
}
}
@ -29,7 +29,7 @@ impl World {
Self {
player: Player::simulate(brain),
max_asteroids: 28,
score: 1,
score: 0.,
asteroids: vec![
Asteroid::new_to(vec2(0., 0.), 1.5, AsteroidSize::Large),
Asteroid::new(AsteroidSize::Large),
@ -79,17 +79,13 @@ impl World {
asteroid.update();
if self.player.check_player_collision(asteroid) {
self.over = true;
self.fitness = (self.score as f32 + 1.)
* (self.score as f32 / self.player.shots as f32)
* (self.score as f32 / self.player.shots as f32)
* self.player.lifespan as f32
* 0.01;
// self.fitness = self.player.lifespan as f32 * self.player.lifespan as f32 * 0.001;
self.fitness =
(self.score / self.player.shots as f32).powi(2) * self.player.lifespan as f32;
// println!("{} {} {}", self.score, self.player.lifespan, self.fitness);
}
if self.player.check_bullet_collisions(asteroid) {
self.score += 1;
self.score += 1.;
match asteroid.size {
AsteroidSize::Large => {
let rand = vec2(gen_range(-0.8, 0.8), gen_range(-0.8, 0.8));
@ -130,7 +126,8 @@ impl World {
// AsteroidSize::Small => 1,
// }
// }) < self.max_asteroids
if self.player.lifespan % 400 == 0 {
// {
if self.player.lifespan % 200 == 0 {
self.asteroids
.push(Asteroid::new_to(self.player.pos, 1.5, AsteroidSize::Large));
}
@ -141,12 +138,14 @@ impl World {
for asteroid in &self.asteroids {
asteroid.draw();
}
// draw_text(
// &format!("Score {}", self.score),
draw_text(
&format!("{}", (self.score / self.player.shots as f32).powi(2) as f32),
// 20. - screen_width() * 0.5,
// 30. - screen_height() * 0.5,
// 32.,
// WHITE,
// );
self.player.pos.x - 20.,
self.player.pos.y - 20.,
12.,
WHITE,
);
}
}