fixed bugs

This commit is contained in:
sparshg 2022-10-14 16:27:58 +05:30
parent 666c9a8b8e
commit db2df770c6
7 changed files with 78 additions and 24 deletions

27
.vscode/launch.json vendored Normal file
View File

@ -0,0 +1,27 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"type": "lldb",
"request": "launch",
"name": "Debug unit tests in library 'yourprogram'",
"cargo": {
"args": [
"test",
"--no-run",
"--lib",
"--package=yourprogram"
],
"filter": {
"name": "yourprogram",
"kind": "lib"
}
},
"args": [],
"cwd": "${workspaceFolder}"
}
]
}

View File

@ -89,7 +89,7 @@ impl Asteroid {
AsteroidSize::Medium => 1.2, AsteroidSize::Medium => 1.2,
AsteroidSize::Small => 0.8, AsteroidSize::Small => 0.8,
}, },
Color::new(1., 1., 1., 0.5), Color::new(1., 1., 1., 0.4),
); );
} }
} }

View File

@ -16,8 +16,11 @@ async fn main() {
..Default::default() ..Default::default()
}; };
set_camera(&cam); set_camera(&cam);
let mut pop = Population::new(100); let mut pop = Population::new(10);
let mut speedup = false; let mut speedup = false;
// for _ in 0..100000 * 5 {
// pop.update();
// }
loop { loop {
clear_background(BLACK); clear_background(BLACK);
if is_key_pressed(KeyCode::S) { if is_key_pressed(KeyCode::S) {

View File

@ -1,3 +1,4 @@
use macroquad::rand::gen_range;
use nalgebra::*; use nalgebra::*;
use r::Rng; use r::Rng;
use rand_distr::StandardNormal; use rand_distr::StandardNormal;
@ -10,10 +11,10 @@ enum ActivationFunc {
ReLU, ReLU,
} }
#[derive(Clone)] #[derive(Clone, Debug)]
pub struct NN { pub struct NN {
pub config: Vec<usize>, pub config: Vec<usize>,
weights: Vec<DMatrix<f32>>, pub weights: Vec<DMatrix<f32>>,
activ_func: ActivationFunc, activ_func: ActivationFunc,
mut_rate: f32, mut_rate: f32,
} }
@ -63,8 +64,8 @@ impl NN {
pub fn mutate(&mut self) { pub fn mutate(&mut self) {
for weight in &mut self.weights { for weight in &mut self.weights {
for ele in weight { for ele in weight {
if r::random() { if gen_range(0., 1.) < 0.05 {
*ele = r::thread_rng().sample::<f32, StandardNormal>(StandardNormal) * 0.05; *ele = r::thread_rng().sample::<f32, StandardNormal>(StandardNormal);
} }
} }
} }

View File

@ -40,7 +40,7 @@ impl Player {
pub fn simulate(brain: NN, max_asteroids: usize) -> Self { pub fn simulate(brain: NN, max_asteroids: usize) -> Self {
assert_eq!( assert_eq!(
brain.config[0] - 1, brain.config[0] - 1,
max_asteroids + 5, max_asteroids * 3 + 5,
"NN input size must match max_asteroids" "NN input size must match max_asteroids"
); );
let mut p = Player::new(); let mut p = Player::new();
@ -77,8 +77,7 @@ impl Player {
self.lifespan += 1; self.lifespan += 1;
let mut mag = 0.; let mut mag = 0.;
let mut keys = vec![false, false, false, false]; let mut keys = vec![false, false, false, false];
self.asteroids_data.resize(self.max_asteroids * 3, 0.);
self.asteroids_data.resize(self.max_asteroids, 0.);
let mut inputs = vec![ let mut inputs = vec![
self.pos.x / screen_width() + 0.5, self.pos.x / screen_width() + 0.5,
self.pos.y / screen_height() + 0.5, self.pos.y / screen_height() + 0.5,
@ -124,10 +123,10 @@ impl Player {
self.vel += mag * self.dir - self.drag * self.vel.length() * self.vel; self.vel += mag * self.dir - self.drag * self.vel.length() * self.vel;
self.pos += self.vel; self.pos += self.vel;
if self.pos.x.abs() > screen_width() / 2. + 10. { if self.pos.x.abs() > screen_width() * 0.5 + 10. {
self.pos.x *= -1.; self.pos.x *= -1.;
} }
if self.pos.y.abs() > screen_height() / 2. + 10. { if self.pos.y.abs() > screen_height() * 0.5 + 10. {
self.pos.y *= -1.; self.pos.y *= -1.;
} }

View File

@ -14,7 +14,7 @@ impl Population {
Self { Self {
size, size,
worlds: (0..size) worlds: (0..size)
.map(|_| World::simulate(NN::new(vec![33, 16, 4]))) .map(|_| World::simulate(NN::new(vec![89, 16, 4])))
.collect(), .collect(),
..Default::default() ..Default::default()
} }
@ -52,26 +52,38 @@ impl Population {
pub fn next_gen(&mut self) { pub fn next_gen(&mut self) {
let total = self.worlds.iter().fold(0., |acc, x| acc + x.fitness()); let total = self.worlds.iter().fold(0., |acc, x| acc + x.fitness());
// self.worlds.sort_by(|a, b| b.fitness().cmp(&a.fitness())); self.worlds
let mut new_worlds = Vec::new(); .sort_by(|a, b| b.fitness().partial_cmp(&a.fitness()).unwrap());
// (0..self.size / 10) let mut new_worlds = (0..self.size / 10)
// .map(|i| World::simulate(self.worlds[i].see_brain().to_owned())) .map(|i| World::simulate(self.worlds[i].see_brain().to_owned()))
// .collect::<Vec<_>>(); .collect::<Vec<_>>();
// println!(
// "Total fitness: {} {} {}",
// total,
// self.worlds[0].fitness(),
// self.worlds[1].fitness()
// );
while new_worlds.len() < self.size { while new_worlds.len() < self.size {
let rands = (gen_range(0., total), gen_range(0., total)); let rands = (gen_range(0., total), gen_range(0., total));
// println!("rands: {} {}", rands.0, rands.1);
let mut sum = 0.; let mut sum = 0.;
let (mut a, mut b) = (None, None); let (mut a, mut b) = (None, None);
for world in &self.worlds { for world in &self.worlds {
sum += world.fitness(); sum += world.fitness();
if sum >= rands.0 { if a.is_none() && sum >= rands.0 {
a = Some(world.see_brain()); a = Some(world.see_brain());
} }
if sum >= rands.1 { if b.is_none() && sum >= rands.1 {
b = Some(world.see_brain()); b = Some(world.see_brain());
} }
} }
// println!("{}", &a.unwrap().weights[0]);
// println!("{}", &b.unwrap().weights[0]);
let mut new_brain = NN::crossover(a.unwrap(), b.unwrap()); let mut new_brain = NN::crossover(a.unwrap(), b.unwrap());
// println!("{}", &a.unwrap().weights[0]);
// println!("{}", &b.unwrap().weights[0]);
// println!("{}", &new_brain.weights[0]);
new_brain.mutate(); new_brain.mutate();
new_worlds.push(World::simulate(new_brain)); new_worlds.push(World::simulate(new_brain));
} }

View File

@ -36,12 +36,24 @@ impl World {
} }
pub fn fitness(&self) -> f32 { pub fn fitness(&self) -> f32 {
self.score as f32 // println!(
+ self.player.lifespan as f32 * 0.01 // "{} {} {}",
+ if self.player.shots > 0 { // self.score as f32,
self.score as f32 / self.player.shots as f32 * 10. // self.player.lifespan as f32 * 0.001,
// if self.player.shots > 0 {
// self.score as f32 / self.player.shots as f32 * 5.
// } else {
// 0.
// }
// );
(self.score + 1) as f32
* 10.
* self.player.lifespan as f32
* if self.player.shots > 0 {
(self.score as f32 / self.player.shots as f32)
* (self.score as f32 / self.player.shots as f32)
} else { } else {
0. 1.
} }
} }