activ func

This commit is contained in:
sparshg 2023-01-12 15:25:26 +05:30
parent 8d317346f8
commit b7ad50da37
5 changed files with 95 additions and 28 deletions

View File

@ -4,7 +4,7 @@ mod player;
mod population; mod population;
mod world; mod world;
use nn::NN; use nn::{ActivationFunc, NN};
use tinyfiledialogs::*; use tinyfiledialogs::*;
use macroquad::{ use macroquad::{
@ -62,12 +62,22 @@ async fn main() {
let mut paused = false; let mut paused = false;
let mut bias = false; let mut bias = false;
let mut size: u32 = 100; let mut size: u32 = 100;
let mut hlayers: Vec<usize> = vec![6, 6, 0]; let mut hlayers: Vec<usize> = vec![6, 6, 0];
let mut prev_hlayers = hlayers.clone(); let mut prev_hlayers = hlayers.clone();
let mut mut_rate = 0.05; let mut mut_rate = 0.05;
let mut prev_mut_rate = 0.05; let mut prev_mut_rate = 0.05;
let mut pop = Population::new(size as usize, hlayers.clone(), mut_rate);
let mut activ: usize = 0; let mut activ: usize = 0;
let mut prev_activ: usize = 0;
let activs = [
ActivationFunc::ReLU,
ActivationFunc::Sigmoid,
ActivationFunc::Tanh,
];
let mut pop = Population::new(size as usize, hlayers.clone(), mut_rate, activs[activ]);
let ui_thick = 34.; let ui_thick = 34.;
let nums = &[ let nums = &[
@ -262,6 +272,9 @@ async fn main() {
.position(vec2(0., 0.)) .position(vec2(0., 0.))
.ui(ui, |ui| { .ui(ui, |ui| {
ui.label(None, &format!("Generation: {}", pop.gen)); ui.label(None, &format!("Generation: {}", pop.gen));
ui.push_skin(&skin2);
ui.label(vec2(200., 8.), &format!("{: >4}x", speedup));
ui.pop_skin();
ui.same_line(242.); ui.same_line(242.);
if widgets::Button::new("Load Model").ui(ui) { if widgets::Button::new("Load Model").ui(ui) {
if let Some(path) = open_file_dialog("Load Model", "model.json", None) { if let Some(path) = open_file_dialog("Load Model", "model.json", None) {
@ -275,10 +288,19 @@ async fn main() {
.map(|x| x - 1) .map(|x| x - 1)
.collect::<Vec<_>>(); .collect::<Vec<_>>();
hlayers.resize(3, 0); hlayers.resize(3, 0);
prev_hlayers = hlayers.clone();
mut_rate = brain.mut_rate; mut_rate = brain.mut_rate;
prev_mut_rate = brain.mut_rate; activ = activs.iter().position(|&x| x == brain.activ_func).unwrap();
pop = Population::new(size as usize, hlayers.clone(), mut_rate);
prev_hlayers = hlayers.clone();
prev_mut_rate = mut_rate;
prev_activ = activ;
pop = Population::new(
size as usize,
hlayers.clone(),
mut_rate,
activs[activ],
);
pop.worlds[0] = World::simulate(brain); pop.worlds[0] = World::simulate(brain);
} }
} }
@ -336,7 +358,12 @@ async fn main() {
}; };
ui.same_line(0.); ui.same_line(0.);
if widgets::Button::new(restart).ui(ui) { if widgets::Button::new(restart).ui(ui) {
pop = Population::new(size as usize, hlayers.clone(), mut_rate); pop = Population::new(
size as usize,
hlayers.clone(),
mut_rate,
activs[activ],
);
}; };
}); });
ui.push_skin(&skin2); ui.push_skin(&skin2);
@ -354,7 +381,12 @@ async fn main() {
ui.combo_box(hash!(), "Layer 2", nums, &mut hlayers[1]); ui.combo_box(hash!(), "Layer 2", nums, &mut hlayers[1]);
ui.combo_box(hash!(), "Layer 3", nums, &mut hlayers[2]); ui.combo_box(hash!(), "Layer 3", nums, &mut hlayers[2]);
if prev_hlayers != hlayers { if prev_hlayers != hlayers {
pop = Population::new(size as usize, hlayers.clone(), mut_rate); pop = Population::new(
size as usize,
hlayers.clone(),
mut_rate,
activs[activ],
);
prev_hlayers = hlayers.clone(); prev_hlayers = hlayers.clone();
} }
ui.label(None, " "); ui.label(None, " ");
@ -366,7 +398,11 @@ async fn main() {
} }
ui.label(None, " "); ui.label(None, " ");
ui.label(None, "Activation Func"); ui.label(None, "Activation Func");
ui.combo_box(hash!(), "«Select»", &["ReLU", "Sigm"], &mut activ); ui.combo_box(hash!(), "«Select»", &["ReLU", "Sigm", "Tanh"], &mut activ);
if prev_activ != activ {
pop.change_activ(activs[activ]);
prev_activ = activ;
}
}); });
ui.pop_skin(); ui.pop_skin();
}, },

View File

@ -5,16 +5,15 @@ use rand_distr::StandardNormal;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
extern crate rand as r; extern crate rand as r;
#[derive(PartialEq, Debug, Clone, Copy, Default, Serialize, Deserialize)] #[derive(PartialEq, Debug, Clone, Copy, Serialize, Deserialize)]
pub enum ActivationFunc { pub enum ActivationFunc {
ReLU,
Sigmoid, Sigmoid,
Tanh, Tanh,
#[default]
ReLU,
} }
#[derive(Clone, Debug, Default, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
pub struct NN { pub struct NN {
pub config: Vec<usize>, pub config: Vec<usize>,
pub weights: Vec<DMatrix<f32>>, pub weights: Vec<DMatrix<f32>>,
@ -24,7 +23,7 @@ pub struct NN {
impl NN { impl NN {
// Vec of number of neurons in input, hidden 1, hidden 2, ..., output layers // Vec of number of neurons in input, hidden 1, hidden 2, ..., output layers
pub fn new(config: Vec<usize>, mut_rate: f32) -> Self { pub fn new(config: Vec<usize>, mut_rate: f32, activ: ActivationFunc) -> Self {
let mut rng = r::thread_rng(); let mut rng = r::thread_rng();
Self { Self {
@ -46,7 +45,7 @@ impl NN {
.collect(), .collect(),
mut_rate, mut_rate,
..Default::default() activ_func: activ,
} }
} }
@ -67,7 +66,6 @@ impl NN {
) )
}) })
.collect(), .collect(),
..Default::default()
} }
} }

View File

@ -2,7 +2,11 @@ use std::{f32::consts::PI, f64::consts::TAU};
use macroquad::{prelude::*, rand::gen_range}; use macroquad::{prelude::*, rand::gen_range};
use crate::{asteroids::Asteroid, nn::NN, HEIGHT, WIDTH}; use crate::{
asteroids::Asteroid,
nn::{ActivationFunc, NN},
HEIGHT, WIDTH,
};
#[derive(Default)] #[derive(Default)]
pub struct Player { pub struct Player {
pub pos: Vec2, pub pos: Vec2,
@ -25,14 +29,18 @@ pub struct Player {
} }
impl Player { impl Player {
pub fn new(config: Option<Vec<usize>>, mut_rate: Option<f32>) -> Self { pub fn new(
config: Option<Vec<usize>>,
mut_rate: Option<f32>,
activ: Option<ActivationFunc>,
) -> Self {
Self { Self {
brain: match config { brain: match config {
Some(mut c) => { Some(mut c) => {
c.retain(|&x| x != 0); c.retain(|&x| x != 0);
c.insert(0, 5); c.insert(0, 5);
c.push(4); c.push(4);
Some(NN::new(c, mut_rate.unwrap())) Some(NN::new(c, mut_rate.unwrap(), activ.unwrap()))
} }
_ => None, _ => None,
}, },
@ -112,6 +120,7 @@ impl Player {
self.last_shot += 1; self.last_shot += 1;
self.acc = 0.; self.acc = 0.;
self.outputs = vec![0.; 4]; self.outputs = vec![0.; 4];
let mut keys = vec![false; 4];
if let Some(ast) = self.asteroid.as_ref() { if let Some(ast) = self.asteroid.as_ref() {
self.inputs = vec![ self.inputs = vec![
(ast.pos - self.pos).length() / HEIGHT, (ast.pos - self.pos).length() / HEIGHT,
@ -135,9 +144,19 @@ impl Player {
if let Some(brain) = &self.brain { if let Some(brain) = &self.brain {
self.outputs = brain.feed_forward(&self.inputs); self.outputs = brain.feed_forward(&self.inputs);
keys = self
.outputs
.iter()
.map(|&x| {
x > if brain.activ_func == ActivationFunc::Sigmoid {
0.85
} else {
0.
}
})
.collect();
} }
} }
let keys: Vec<bool> = self.outputs.iter().map(|&x| x > 0.).collect();
if keys[0] { if keys[0] {
// RIGHT // RIGHT
self.rot = (self.rot + 0.1 + TAU as f32) % TAU as f32; self.rot = (self.rot + 0.1 + TAU as f32) % TAU as f32;
@ -251,7 +270,7 @@ impl Bullet {
fn update(&mut self) { fn update(&mut self) {
self.pos += self.vel; self.pos += self.vel;
} }
fn draw(&self, color: Color) { fn draw(&self, c: Color) {
draw_circle(self.pos.x, self.pos.y, 2., color); draw_circle(self.pos.x, self.pos.y, 2., Color::new(c.r, c.g, c.b, 0.9));
} }
} }

View File

@ -1,6 +1,10 @@
use macroquad::{prelude::*, rand::gen_range}; use macroquad::{prelude::*, rand::gen_range};
use crate::{nn::NN, world::World, HEIGHT, WIDTH}; use crate::{
nn::{ActivationFunc, NN},
world::World,
HEIGHT, WIDTH,
};
#[derive(Default)] #[derive(Default)]
pub struct Population { pub struct Population {
@ -14,12 +18,12 @@ pub struct Population {
} }
impl Population { impl Population {
pub fn new(size: usize, hlayers: Vec<usize>, mut_rate: f32) -> Self { pub fn new(size: usize, hlayers: Vec<usize>, mut_rate: f32, activ: ActivationFunc) -> Self {
Self { Self {
size, size,
hlayers: hlayers.clone(), hlayers: hlayers.clone(),
worlds: (0..size) worlds: (0..size)
.map(|_| World::new(Some(hlayers.clone()), Some(mut_rate))) .map(|_| World::new(Some(hlayers.clone()), Some(mut_rate), Some(activ)))
.collect(), .collect(),
..Default::default() ..Default::default()
} }
@ -65,6 +69,12 @@ impl Population {
} }
} }
pub fn change_activ(&mut self, activ: ActivationFunc) {
for world in &mut self.worlds {
world.player.brain.as_mut().unwrap().activ_func = activ;
}
}
pub fn draw(&self) { pub fn draw(&self) {
for world in self.worlds.iter().rev() { for world in self.worlds.iter().rev() {
if self.focus { if self.focus {

View File

@ -1,6 +1,6 @@
use crate::{ use crate::{
asteroids::{Asteroid, AsteroidSize}, asteroids::{Asteroid, AsteroidSize},
nn::NN, nn::{ActivationFunc, NN},
player::Player, player::Player,
}; };
use macroquad::{prelude::*, rand::gen_range}; use macroquad::{prelude::*, rand::gen_range};
@ -17,9 +17,13 @@ pub struct World {
} }
impl World { impl World {
pub fn new(hlayers: Option<Vec<usize>>, mut_rate: Option<f32>) -> Self { pub fn new(
hlayers: Option<Vec<usize>>,
mut_rate: Option<f32>,
activ: Option<ActivationFunc>,
) -> Self {
Self { Self {
player: Player::new(hlayers, mut_rate), player: Player::new(hlayers, mut_rate, activ),
score: 1., score: 1.,
asteroids: vec![ asteroids: vec![
Asteroid::new_to(vec2(0., 0.), 1.5, AsteroidSize::Large), Asteroid::new_to(vec2(0., 0.), 1.5, AsteroidSize::Large),
@ -33,7 +37,7 @@ impl World {
} }
} }
pub fn simulate(brain: NN) -> Self { pub fn simulate(brain: NN) -> Self {
let mut w = World::new(None, None); let mut w = World::new(None, None, None);
w.player.brain = Some(brain); w.player.brain = Some(brain);
w w
} }