finally works T_T im ded
This commit is contained in:
parent
32df37e391
commit
8196383515
20
src/nn.rs
20
src/nn.rs
|
@ -4,14 +4,16 @@ use r::Rng;
|
|||
use rand_distr::StandardNormal;
|
||||
extern crate rand as r;
|
||||
|
||||
#[derive(PartialEq, Debug, Clone, Copy)]
|
||||
#[derive(PartialEq, Debug, Clone, Copy, Default)]
|
||||
|
||||
enum ActivationFunc {
|
||||
Sigmoid,
|
||||
Tanh,
|
||||
#[default]
|
||||
ReLU,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct NN {
|
||||
pub config: Vec<usize>,
|
||||
pub weights: Vec<DMatrix<f32>>,
|
||||
|
@ -36,7 +38,8 @@ impl NN {
|
|||
.iter()
|
||||
.zip(config.iter().skip(1))
|
||||
.map(|(&curr, &last)| {
|
||||
// let a = DMatrix::<f32>::new_random(last, curr + 1);
|
||||
// DMatrix::from_fn(last, curr + 1, |_, _| gen_range(-1., 1.))
|
||||
// DMatrix::<f32>::new_random(last, curr + 1)
|
||||
// println!("{}", a);
|
||||
// a
|
||||
DMatrix::<f32>::from_distribution(last, curr + 1, &StandardNormal, &mut rng)
|
||||
|
@ -44,8 +47,8 @@ impl NN {
|
|||
})
|
||||
.collect(),
|
||||
|
||||
activ_func: ActivationFunc::ReLU,
|
||||
mut_rate: 0.02,
|
||||
mut_rate: 0.04,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -61,6 +64,7 @@ impl NN {
|
|||
.zip(b.weights.iter())
|
||||
.map(|(m1, m2)| m1.zip_map(m2, |ele1, ele2| if r::random() { ele1 } else { ele2 }))
|
||||
.collect(),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -69,7 +73,8 @@ impl NN {
|
|||
for ele in weight {
|
||||
if gen_range(0., 1.) < self.mut_rate {
|
||||
// *ele += gen_range(-1., 1.);
|
||||
*ele = r::thread_rng().sample::<f32, StandardNormal>(StandardNormal);
|
||||
*ele = gen_range(-1., 1.);
|
||||
// *ele = r::thread_rng().sample::<f32, StandardNormal>(StandardNormal);
|
||||
// *ele = r::thread_rng().sample::<f32, StandardNormal>(StandardNormal);
|
||||
}
|
||||
}
|
||||
|
@ -77,6 +82,7 @@ impl NN {
|
|||
}
|
||||
|
||||
pub fn feed_forward(&self, inputs: Vec<f32>) -> Vec<f32> {
|
||||
// println!("inputs: {:?}", inputs);
|
||||
let mut y = DMatrix::from_vec(inputs.len(), 1, inputs);
|
||||
for i in 0..self.config.len() - 1 {
|
||||
y = (&self.weights[i] * y.insert_row(self.config[i] - 1, 1.)).map(|x| {
|
||||
|
@ -86,6 +92,8 @@ impl NN {
|
|||
ActivationFunc::Tanh => x.tanh(),
|
||||
}
|
||||
});
|
||||
// println!("w{}: {}", i, self.weights[i]);
|
||||
// println!("y: {}", y);
|
||||
}
|
||||
y.column(0).data.into_slice().to_vec()
|
||||
}
|
||||
|
|
144
src/player.rs
144
src/player.rs
|
@ -1,7 +1,7 @@
|
|||
use std::{f32::consts::PI, f64::consts::TAU};
|
||||
|
||||
use macroquad::{prelude::*, rand::gen_range};
|
||||
use nalgebra::{max, partial_max};
|
||||
use nalgebra::{max, partial_max, partial_min};
|
||||
|
||||
use crate::{asteroids::Asteroid, nn::NN};
|
||||
#[derive(Default)]
|
||||
|
@ -9,7 +9,7 @@ pub struct Player {
|
|||
pub pos: Vec2,
|
||||
pub vel: Vec2,
|
||||
acc: f32,
|
||||
dir: Vec2,
|
||||
pub dir: Vec2,
|
||||
rot: f32,
|
||||
drag: f32,
|
||||
bullets: Vec<Bullet>,
|
||||
|
@ -36,7 +36,7 @@ impl Player {
|
|||
alive: true,
|
||||
debug: false,
|
||||
shots: 4,
|
||||
raycasts: vec![0.; 8],
|
||||
raycasts: vec![f32::MAX; 3],
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
@ -44,39 +44,52 @@ impl Player {
|
|||
pub fn simulate(brain: Option<NN>) -> Self {
|
||||
let mut p = Player::new();
|
||||
if let Some(brain) = brain {
|
||||
assert_eq!(
|
||||
brain.config[0] - 1,
|
||||
8 + 0,
|
||||
"NN input size must match max_asteroids"
|
||||
);
|
||||
// assert_eq!(
|
||||
// brain.config[0] - 1,
|
||||
// 8 + 5,
|
||||
// "NN input size must match max_asteroids"
|
||||
// );
|
||||
p.brain = Some(brain);
|
||||
} else {
|
||||
p.brain = Some(NN::new(vec![8 + 0, 16, 4]));
|
||||
p.brain = Some(NN::new(vec![3, 8, 8, 4]));
|
||||
}
|
||||
p
|
||||
}
|
||||
|
||||
pub fn check_player_collision(&mut self, asteroid: &mut Asteroid) -> bool {
|
||||
// self.asteroids_data.extend([
|
||||
// asteroid.pos.x / screen_width() + 0.5,
|
||||
// asteroid.pos.y / screen_height() + 0.5,
|
||||
// asteroid.radius / 50.,
|
||||
// ]);
|
||||
let v = asteroid.pos - self.pos;
|
||||
for i in 0..4 {
|
||||
let dir = Vec2::from_angle(PI / 4. * i as f32).rotate(self.dir);
|
||||
let cross = v.perp_dot(dir);
|
||||
let dot = v.dot(dir);
|
||||
if cross.abs() <= asteroid.radius {
|
||||
self.raycasts[if dot >= 0. { i } else { i + 4 }] = *partial_max(
|
||||
&self.raycasts[if dot >= 0. { i } else { i + 4 }],
|
||||
&(1. / (dot.abs()
|
||||
- (asteroid.radius * asteroid.radius - cross * cross).sqrt())),
|
||||
// self.raycasts.extend([
|
||||
if (asteroid.pos).distance_squared(self.pos)
|
||||
< vec2(
|
||||
self.raycasts[0] * screen_width(),
|
||||
self.raycasts[1] * screen_height(),
|
||||
)
|
||||
.unwrap();
|
||||
.distance_squared(self.pos)
|
||||
{
|
||||
self.raycasts[0] = asteroid.pos.x / screen_width();
|
||||
self.raycasts[1] = asteroid.pos.y / screen_height();
|
||||
self.raycasts[2] = asteroid.radius / 50.;
|
||||
}
|
||||
}
|
||||
if asteroid.check_collision(self.pos, 8.) {
|
||||
// ]);
|
||||
// if self.raycasts[0] > (asteroid.pos - self.pos).length_squared() {
|
||||
// self.raycasts[0] = (asteroid.pos - self.pos).length_squared();
|
||||
// self.raycasts[1] = Vec2::angle_between(asteroid.pos - self.pos, self.dir).sin();
|
||||
// self.raycasts[2] = Vec2::angle_between(asteroid.pos - self.pos, self.dir).cos();
|
||||
// }
|
||||
// let v = asteroid.pos - self.pos;
|
||||
// for i in 0..4 {
|
||||
// let dir = Vec2::from_angle(PI / 4. * i as f32).rotate(self.dir);
|
||||
// let cross = v.perp_dot(dir);
|
||||
// let dot = v.dot(dir);
|
||||
// if cross.abs() <= asteroid.radius {
|
||||
// self.raycasts[if dot >= 0. { i } else { i + 4 }] = *partial_max(
|
||||
// &self.raycasts[if dot >= 0. { i } else { i + 4 }],
|
||||
// &(1. / (dot.abs()
|
||||
// - (asteroid.radius * asteroid.radius - cross * cross).sqrt())),
|
||||
// )
|
||||
// .unwrap();
|
||||
// }
|
||||
// }
|
||||
if asteroid.check_collision(self.pos, 8.) || self.lifespan > 2000 {
|
||||
self.alive = false;
|
||||
return true;
|
||||
}
|
||||
|
@ -99,22 +112,39 @@ impl Player {
|
|||
self.last_shot += 1;
|
||||
self.acc = 0.;
|
||||
let mut keys = vec![false, false, false, false];
|
||||
// self.asteroids_data.resize(self.max_asteroids * 3, 0.);
|
||||
// let mut inputs = vec![
|
||||
// self.pos.x / screen_width() + 0.5,
|
||||
// self.pos.y / screen_height() + 0.5,
|
||||
let mut inputs = vec![
|
||||
(vec2(
|
||||
self.raycasts[0] * screen_width(),
|
||||
self.raycasts[1] * screen_height(),
|
||||
) - self.pos)
|
||||
.length()
|
||||
* 0.707
|
||||
/ screen_width(),
|
||||
// self.raycasts[0] - self.pos.x / screen_width(),
|
||||
// self.raycasts[1] - self.pos.y / screen_height(),
|
||||
self.dir.angle_between(
|
||||
vec2(
|
||||
self.raycasts[0] * screen_width(),
|
||||
self.raycasts[1] * screen_height(),
|
||||
) - self.pos,
|
||||
),
|
||||
// self.vel.x / 11.,
|
||||
// self.vel.y / 11.,
|
||||
// self.rot / TAU as f32,
|
||||
// ];
|
||||
// inputs.append(self.raycasts.as_mut());
|
||||
let inputs = self.raycasts.clone();
|
||||
self.rot, // self.rot.sin(),
|
||||
// self.rot.cos(),
|
||||
];
|
||||
|
||||
// self.raycasts.resize(3, 0.);
|
||||
// inputs.append(self.raycasts.clone().as_mut());
|
||||
// println!("inputs: {:?}", inputs);
|
||||
|
||||
// let inputs = self.raycasts.clone();
|
||||
// inputs.append(self.asteroids_data.as_mut());
|
||||
if let Some(brain) = &self.brain {
|
||||
// println!("{:?}", inputs);
|
||||
// println!("{:?}", brain.feed_forward(inputs.clone()));
|
||||
|
||||
keys = brain.feed_forward(inputs).iter().map(|&x| x > 0.).collect();
|
||||
}
|
||||
self.raycasts = vec![0.; 8];
|
||||
if is_key_down(KeyCode::Right) && self.debug || keys[0] {
|
||||
self.rot = (self.rot + 0.1 + TAU as f32) % TAU as f32;
|
||||
self.dir = vec2(self.rot.cos(), self.rot.sin());
|
||||
|
@ -132,8 +162,8 @@ impl Player {
|
|||
self.last_shot = 0;
|
||||
self.shots += 1;
|
||||
self.bullets.push(Bullet {
|
||||
pos: self.pos + self.dir.rotate(vec2(20., 0.)),
|
||||
vel: self.dir.rotate(vec2(8.5, 0.)) + self.vel,
|
||||
pos: self.pos + self.dir * 20.,
|
||||
vel: self.dir * 8.5 + self.vel,
|
||||
alive: true,
|
||||
});
|
||||
}
|
||||
|
@ -158,6 +188,7 @@ impl Player {
|
|||
self.bullets.retain(|b| {
|
||||
b.alive && b.pos.x.abs() * 2. < screen_width() && b.pos.y.abs() * 2. < screen_height()
|
||||
});
|
||||
self.raycasts = vec![100.; 3];
|
||||
}
|
||||
|
||||
pub fn draw(&self) {
|
||||
|
@ -181,21 +212,34 @@ impl Player {
|
|||
}
|
||||
|
||||
if self.debug {
|
||||
// for a in self.asteroids_data.chunks(3) {
|
||||
// draw_circle_lines(a[0], a[1], a[2], 1., GRAY);
|
||||
// draw_line(self.pos.x, self.pos.y, a[0], a[1], 1., GRAY)
|
||||
// }
|
||||
for (i, r) in self.raycasts.iter().enumerate() {
|
||||
let dir = Vec2::from_angle(PI / 4. * i as f32).rotate(self.dir);
|
||||
draw_line(
|
||||
self.pos.x,
|
||||
self.pos.y,
|
||||
self.pos.x + dir.x / r,
|
||||
self.pos.y + dir.y / r,
|
||||
for a in self.raycasts.chunks(3) {
|
||||
draw_circle_lines(
|
||||
a[0] * screen_width(),
|
||||
a[1] * screen_height(),
|
||||
a[2] * 50.,
|
||||
1.,
|
||||
GRAY,
|
||||
);
|
||||
draw_line(
|
||||
self.pos.x,
|
||||
self.pos.y,
|
||||
a[0] * screen_width(),
|
||||
a[1] * screen_height(),
|
||||
1.,
|
||||
GRAY,
|
||||
)
|
||||
}
|
||||
// for (i, r) in self.raycasts.iter().enumerate() {
|
||||
// let dir = Vec2::from_angle(PI / 4. * i as f32).rotate(self.dir);
|
||||
// draw_line(
|
||||
// self.pos.x,
|
||||
// self.pos.y,
|
||||
// self.pos.x + dir.x / r,
|
||||
// self.pos.y + dir.y / r,
|
||||
// 1.,
|
||||
// GRAY,
|
||||
// );
|
||||
// }
|
||||
}
|
||||
|
||||
for bullet in &self.bullets {
|
||||
|
|
|
@ -6,6 +6,7 @@ use crate::{nn::NN, world::World};
|
|||
pub struct Population {
|
||||
size: usize,
|
||||
gen: i32,
|
||||
best: bool,
|
||||
pub worlds: Vec<World>,
|
||||
}
|
||||
|
||||
|
@ -30,12 +31,21 @@ impl Population {
|
|||
self.gen += 1;
|
||||
self.next_gen();
|
||||
}
|
||||
if is_key_pressed(KeyCode::Z) {
|
||||
self.best = !self.best;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn draw(&self) {
|
||||
for world in self.worlds.iter().rev() {
|
||||
if !world.over {
|
||||
if self.best {
|
||||
if world.player.color.is_some() {
|
||||
world.draw();
|
||||
}
|
||||
} else if !world.over {
|
||||
world.draw();
|
||||
}
|
||||
}
|
||||
draw_text(
|
||||
&format!("Gen: {}", self.gen),
|
||||
-150. + screen_width() * 0.5,
|
||||
|
@ -44,8 +54,6 @@ impl Population {
|
|||
WHITE,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn next_gen(&mut self) {
|
||||
let total = self.worlds.iter().fold(0., |acc, x| acc + x.fitness);
|
||||
|
|
33
src/world.rs
33
src/world.rs
|
@ -7,9 +7,9 @@ use macroquad::{prelude::*, rand::gen_range};
|
|||
|
||||
#[derive(Default)]
|
||||
pub struct World {
|
||||
player: Player,
|
||||
pub player: Player,
|
||||
asteroids: Vec<Asteroid>,
|
||||
pub score: u32,
|
||||
pub score: f32,
|
||||
pub over: bool,
|
||||
max_asteroids: usize,
|
||||
pub fitness: f32,
|
||||
|
@ -20,7 +20,7 @@ impl World {
|
|||
Self {
|
||||
player: Player::new(),
|
||||
max_asteroids: 28,
|
||||
score: 1,
|
||||
score: 0.,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
@ -29,7 +29,7 @@ impl World {
|
|||
Self {
|
||||
player: Player::simulate(brain),
|
||||
max_asteroids: 28,
|
||||
score: 1,
|
||||
score: 0.,
|
||||
asteroids: vec![
|
||||
Asteroid::new_to(vec2(0., 0.), 1.5, AsteroidSize::Large),
|
||||
Asteroid::new(AsteroidSize::Large),
|
||||
|
@ -79,17 +79,13 @@ impl World {
|
|||
asteroid.update();
|
||||
if self.player.check_player_collision(asteroid) {
|
||||
self.over = true;
|
||||
self.fitness = (self.score as f32 + 1.)
|
||||
* (self.score as f32 / self.player.shots as f32)
|
||||
* (self.score as f32 / self.player.shots as f32)
|
||||
* self.player.lifespan as f32
|
||||
* 0.01;
|
||||
// self.fitness = self.player.lifespan as f32 * self.player.lifespan as f32 * 0.001;
|
||||
self.fitness =
|
||||
(self.score / self.player.shots as f32).powi(2) * self.player.lifespan as f32;
|
||||
|
||||
// println!("{} {} {}", self.score, self.player.lifespan, self.fitness);
|
||||
}
|
||||
if self.player.check_bullet_collisions(asteroid) {
|
||||
self.score += 1;
|
||||
self.score += 1.;
|
||||
match asteroid.size {
|
||||
AsteroidSize::Large => {
|
||||
let rand = vec2(gen_range(-0.8, 0.8), gen_range(-0.8, 0.8));
|
||||
|
@ -130,7 +126,8 @@ impl World {
|
|||
// AsteroidSize::Small => 1,
|
||||
// }
|
||||
// }) < self.max_asteroids
|
||||
if self.player.lifespan % 400 == 0 {
|
||||
// {
|
||||
if self.player.lifespan % 200 == 0 {
|
||||
self.asteroids
|
||||
.push(Asteroid::new_to(self.player.pos, 1.5, AsteroidSize::Large));
|
||||
}
|
||||
|
@ -141,12 +138,14 @@ impl World {
|
|||
for asteroid in &self.asteroids {
|
||||
asteroid.draw();
|
||||
}
|
||||
// draw_text(
|
||||
// &format!("Score {}", self.score),
|
||||
draw_text(
|
||||
&format!("{}", (self.score / self.player.shots as f32).powi(2) as f32),
|
||||
// 20. - screen_width() * 0.5,
|
||||
// 30. - screen_height() * 0.5,
|
||||
// 32.,
|
||||
// WHITE,
|
||||
// );
|
||||
self.player.pos.x - 20.,
|
||||
self.player.pos.y - 20.,
|
||||
12.,
|
||||
WHITE,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue