activ func
This commit is contained in:
parent
8d317346f8
commit
b7ad50da37
52
src/main.rs
52
src/main.rs
|
@ -4,7 +4,7 @@ mod player;
|
||||||
mod population;
|
mod population;
|
||||||
mod world;
|
mod world;
|
||||||
|
|
||||||
use nn::NN;
|
use nn::{ActivationFunc, NN};
|
||||||
use tinyfiledialogs::*;
|
use tinyfiledialogs::*;
|
||||||
|
|
||||||
use macroquad::{
|
use macroquad::{
|
||||||
|
@ -62,12 +62,22 @@ async fn main() {
|
||||||
let mut paused = false;
|
let mut paused = false;
|
||||||
let mut bias = false;
|
let mut bias = false;
|
||||||
let mut size: u32 = 100;
|
let mut size: u32 = 100;
|
||||||
|
|
||||||
let mut hlayers: Vec<usize> = vec![6, 6, 0];
|
let mut hlayers: Vec<usize> = vec![6, 6, 0];
|
||||||
let mut prev_hlayers = hlayers.clone();
|
let mut prev_hlayers = hlayers.clone();
|
||||||
|
|
||||||
let mut mut_rate = 0.05;
|
let mut mut_rate = 0.05;
|
||||||
let mut prev_mut_rate = 0.05;
|
let mut prev_mut_rate = 0.05;
|
||||||
let mut pop = Population::new(size as usize, hlayers.clone(), mut_rate);
|
|
||||||
let mut activ: usize = 0;
|
let mut activ: usize = 0;
|
||||||
|
let mut prev_activ: usize = 0;
|
||||||
|
let activs = [
|
||||||
|
ActivationFunc::ReLU,
|
||||||
|
ActivationFunc::Sigmoid,
|
||||||
|
ActivationFunc::Tanh,
|
||||||
|
];
|
||||||
|
|
||||||
|
let mut pop = Population::new(size as usize, hlayers.clone(), mut_rate, activs[activ]);
|
||||||
|
|
||||||
let ui_thick = 34.;
|
let ui_thick = 34.;
|
||||||
let nums = &[
|
let nums = &[
|
||||||
|
@ -262,6 +272,9 @@ async fn main() {
|
||||||
.position(vec2(0., 0.))
|
.position(vec2(0., 0.))
|
||||||
.ui(ui, |ui| {
|
.ui(ui, |ui| {
|
||||||
ui.label(None, &format!("Generation: {}", pop.gen));
|
ui.label(None, &format!("Generation: {}", pop.gen));
|
||||||
|
ui.push_skin(&skin2);
|
||||||
|
ui.label(vec2(200., 8.), &format!("{: >4}x", speedup));
|
||||||
|
ui.pop_skin();
|
||||||
ui.same_line(242.);
|
ui.same_line(242.);
|
||||||
if widgets::Button::new("Load Model").ui(ui) {
|
if widgets::Button::new("Load Model").ui(ui) {
|
||||||
if let Some(path) = open_file_dialog("Load Model", "model.json", None) {
|
if let Some(path) = open_file_dialog("Load Model", "model.json", None) {
|
||||||
|
@ -275,10 +288,19 @@ async fn main() {
|
||||||
.map(|x| x - 1)
|
.map(|x| x - 1)
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
hlayers.resize(3, 0);
|
hlayers.resize(3, 0);
|
||||||
prev_hlayers = hlayers.clone();
|
|
||||||
mut_rate = brain.mut_rate;
|
mut_rate = brain.mut_rate;
|
||||||
prev_mut_rate = brain.mut_rate;
|
activ = activs.iter().position(|&x| x == brain.activ_func).unwrap();
|
||||||
pop = Population::new(size as usize, hlayers.clone(), mut_rate);
|
|
||||||
|
prev_hlayers = hlayers.clone();
|
||||||
|
prev_mut_rate = mut_rate;
|
||||||
|
prev_activ = activ;
|
||||||
|
|
||||||
|
pop = Population::new(
|
||||||
|
size as usize,
|
||||||
|
hlayers.clone(),
|
||||||
|
mut_rate,
|
||||||
|
activs[activ],
|
||||||
|
);
|
||||||
pop.worlds[0] = World::simulate(brain);
|
pop.worlds[0] = World::simulate(brain);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -336,7 +358,12 @@ async fn main() {
|
||||||
};
|
};
|
||||||
ui.same_line(0.);
|
ui.same_line(0.);
|
||||||
if widgets::Button::new(restart).ui(ui) {
|
if widgets::Button::new(restart).ui(ui) {
|
||||||
pop = Population::new(size as usize, hlayers.clone(), mut_rate);
|
pop = Population::new(
|
||||||
|
size as usize,
|
||||||
|
hlayers.clone(),
|
||||||
|
mut_rate,
|
||||||
|
activs[activ],
|
||||||
|
);
|
||||||
};
|
};
|
||||||
});
|
});
|
||||||
ui.push_skin(&skin2);
|
ui.push_skin(&skin2);
|
||||||
|
@ -354,7 +381,12 @@ async fn main() {
|
||||||
ui.combo_box(hash!(), "Layer 2", nums, &mut hlayers[1]);
|
ui.combo_box(hash!(), "Layer 2", nums, &mut hlayers[1]);
|
||||||
ui.combo_box(hash!(), "Layer 3", nums, &mut hlayers[2]);
|
ui.combo_box(hash!(), "Layer 3", nums, &mut hlayers[2]);
|
||||||
if prev_hlayers != hlayers {
|
if prev_hlayers != hlayers {
|
||||||
pop = Population::new(size as usize, hlayers.clone(), mut_rate);
|
pop = Population::new(
|
||||||
|
size as usize,
|
||||||
|
hlayers.clone(),
|
||||||
|
mut_rate,
|
||||||
|
activs[activ],
|
||||||
|
);
|
||||||
prev_hlayers = hlayers.clone();
|
prev_hlayers = hlayers.clone();
|
||||||
}
|
}
|
||||||
ui.label(None, " ");
|
ui.label(None, " ");
|
||||||
|
@ -366,7 +398,11 @@ async fn main() {
|
||||||
}
|
}
|
||||||
ui.label(None, " ");
|
ui.label(None, " ");
|
||||||
ui.label(None, "Activation Func");
|
ui.label(None, "Activation Func");
|
||||||
ui.combo_box(hash!(), "«Select»", &["ReLU", "Sigm"], &mut activ);
|
ui.combo_box(hash!(), "«Select»", &["ReLU", "Sigm", "Tanh"], &mut activ);
|
||||||
|
if prev_activ != activ {
|
||||||
|
pop.change_activ(activs[activ]);
|
||||||
|
prev_activ = activ;
|
||||||
|
}
|
||||||
});
|
});
|
||||||
ui.pop_skin();
|
ui.pop_skin();
|
||||||
},
|
},
|
||||||
|
|
12
src/nn.rs
12
src/nn.rs
|
@ -5,16 +5,15 @@ use rand_distr::StandardNormal;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
extern crate rand as r;
|
extern crate rand as r;
|
||||||
|
|
||||||
#[derive(PartialEq, Debug, Clone, Copy, Default, Serialize, Deserialize)]
|
#[derive(PartialEq, Debug, Clone, Copy, Serialize, Deserialize)]
|
||||||
|
|
||||||
pub enum ActivationFunc {
|
pub enum ActivationFunc {
|
||||||
|
ReLU,
|
||||||
Sigmoid,
|
Sigmoid,
|
||||||
Tanh,
|
Tanh,
|
||||||
#[default]
|
|
||||||
ReLU,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
pub struct NN {
|
pub struct NN {
|
||||||
pub config: Vec<usize>,
|
pub config: Vec<usize>,
|
||||||
pub weights: Vec<DMatrix<f32>>,
|
pub weights: Vec<DMatrix<f32>>,
|
||||||
|
@ -24,7 +23,7 @@ pub struct NN {
|
||||||
|
|
||||||
impl NN {
|
impl NN {
|
||||||
// Vec of number of neurons in input, hidden 1, hidden 2, ..., output layers
|
// Vec of number of neurons in input, hidden 1, hidden 2, ..., output layers
|
||||||
pub fn new(config: Vec<usize>, mut_rate: f32) -> Self {
|
pub fn new(config: Vec<usize>, mut_rate: f32, activ: ActivationFunc) -> Self {
|
||||||
let mut rng = r::thread_rng();
|
let mut rng = r::thread_rng();
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
|
@ -46,7 +45,7 @@ impl NN {
|
||||||
.collect(),
|
.collect(),
|
||||||
|
|
||||||
mut_rate,
|
mut_rate,
|
||||||
..Default::default()
|
activ_func: activ,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -67,7 +66,6 @@ impl NN {
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
.collect(),
|
.collect(),
|
||||||
..Default::default()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,11 @@ use std::{f32::consts::PI, f64::consts::TAU};
|
||||||
|
|
||||||
use macroquad::{prelude::*, rand::gen_range};
|
use macroquad::{prelude::*, rand::gen_range};
|
||||||
|
|
||||||
use crate::{asteroids::Asteroid, nn::NN, HEIGHT, WIDTH};
|
use crate::{
|
||||||
|
asteroids::Asteroid,
|
||||||
|
nn::{ActivationFunc, NN},
|
||||||
|
HEIGHT, WIDTH,
|
||||||
|
};
|
||||||
#[derive(Default)]
|
#[derive(Default)]
|
||||||
pub struct Player {
|
pub struct Player {
|
||||||
pub pos: Vec2,
|
pub pos: Vec2,
|
||||||
|
@ -25,14 +29,18 @@ pub struct Player {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Player {
|
impl Player {
|
||||||
pub fn new(config: Option<Vec<usize>>, mut_rate: Option<f32>) -> Self {
|
pub fn new(
|
||||||
|
config: Option<Vec<usize>>,
|
||||||
|
mut_rate: Option<f32>,
|
||||||
|
activ: Option<ActivationFunc>,
|
||||||
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
brain: match config {
|
brain: match config {
|
||||||
Some(mut c) => {
|
Some(mut c) => {
|
||||||
c.retain(|&x| x != 0);
|
c.retain(|&x| x != 0);
|
||||||
c.insert(0, 5);
|
c.insert(0, 5);
|
||||||
c.push(4);
|
c.push(4);
|
||||||
Some(NN::new(c, mut_rate.unwrap()))
|
Some(NN::new(c, mut_rate.unwrap(), activ.unwrap()))
|
||||||
}
|
}
|
||||||
_ => None,
|
_ => None,
|
||||||
},
|
},
|
||||||
|
@ -112,6 +120,7 @@ impl Player {
|
||||||
self.last_shot += 1;
|
self.last_shot += 1;
|
||||||
self.acc = 0.;
|
self.acc = 0.;
|
||||||
self.outputs = vec![0.; 4];
|
self.outputs = vec![0.; 4];
|
||||||
|
let mut keys = vec![false; 4];
|
||||||
if let Some(ast) = self.asteroid.as_ref() {
|
if let Some(ast) = self.asteroid.as_ref() {
|
||||||
self.inputs = vec![
|
self.inputs = vec![
|
||||||
(ast.pos - self.pos).length() / HEIGHT,
|
(ast.pos - self.pos).length() / HEIGHT,
|
||||||
|
@ -135,9 +144,19 @@ impl Player {
|
||||||
|
|
||||||
if let Some(brain) = &self.brain {
|
if let Some(brain) = &self.brain {
|
||||||
self.outputs = brain.feed_forward(&self.inputs);
|
self.outputs = brain.feed_forward(&self.inputs);
|
||||||
|
keys = self
|
||||||
|
.outputs
|
||||||
|
.iter()
|
||||||
|
.map(|&x| {
|
||||||
|
x > if brain.activ_func == ActivationFunc::Sigmoid {
|
||||||
|
0.85
|
||||||
|
} else {
|
||||||
|
0.
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
let keys: Vec<bool> = self.outputs.iter().map(|&x| x > 0.).collect();
|
|
||||||
if keys[0] {
|
if keys[0] {
|
||||||
// RIGHT
|
// RIGHT
|
||||||
self.rot = (self.rot + 0.1 + TAU as f32) % TAU as f32;
|
self.rot = (self.rot + 0.1 + TAU as f32) % TAU as f32;
|
||||||
|
@ -251,7 +270,7 @@ impl Bullet {
|
||||||
fn update(&mut self) {
|
fn update(&mut self) {
|
||||||
self.pos += self.vel;
|
self.pos += self.vel;
|
||||||
}
|
}
|
||||||
fn draw(&self, color: Color) {
|
fn draw(&self, c: Color) {
|
||||||
draw_circle(self.pos.x, self.pos.y, 2., color);
|
draw_circle(self.pos.x, self.pos.y, 2., Color::new(c.r, c.g, c.b, 0.9));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,10 @@
|
||||||
use macroquad::{prelude::*, rand::gen_range};
|
use macroquad::{prelude::*, rand::gen_range};
|
||||||
|
|
||||||
use crate::{nn::NN, world::World, HEIGHT, WIDTH};
|
use crate::{
|
||||||
|
nn::{ActivationFunc, NN},
|
||||||
|
world::World,
|
||||||
|
HEIGHT, WIDTH,
|
||||||
|
};
|
||||||
|
|
||||||
#[derive(Default)]
|
#[derive(Default)]
|
||||||
pub struct Population {
|
pub struct Population {
|
||||||
|
@ -14,12 +18,12 @@ pub struct Population {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Population {
|
impl Population {
|
||||||
pub fn new(size: usize, hlayers: Vec<usize>, mut_rate: f32) -> Self {
|
pub fn new(size: usize, hlayers: Vec<usize>, mut_rate: f32, activ: ActivationFunc) -> Self {
|
||||||
Self {
|
Self {
|
||||||
size,
|
size,
|
||||||
hlayers: hlayers.clone(),
|
hlayers: hlayers.clone(),
|
||||||
worlds: (0..size)
|
worlds: (0..size)
|
||||||
.map(|_| World::new(Some(hlayers.clone()), Some(mut_rate)))
|
.map(|_| World::new(Some(hlayers.clone()), Some(mut_rate), Some(activ)))
|
||||||
.collect(),
|
.collect(),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
}
|
}
|
||||||
|
@ -65,6 +69,12 @@ impl Population {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn change_activ(&mut self, activ: ActivationFunc) {
|
||||||
|
for world in &mut self.worlds {
|
||||||
|
world.player.brain.as_mut().unwrap().activ_func = activ;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn draw(&self) {
|
pub fn draw(&self) {
|
||||||
for world in self.worlds.iter().rev() {
|
for world in self.worlds.iter().rev() {
|
||||||
if self.focus {
|
if self.focus {
|
||||||
|
|
12
src/world.rs
12
src/world.rs
|
@ -1,6 +1,6 @@
|
||||||
use crate::{
|
use crate::{
|
||||||
asteroids::{Asteroid, AsteroidSize},
|
asteroids::{Asteroid, AsteroidSize},
|
||||||
nn::NN,
|
nn::{ActivationFunc, NN},
|
||||||
player::Player,
|
player::Player,
|
||||||
};
|
};
|
||||||
use macroquad::{prelude::*, rand::gen_range};
|
use macroquad::{prelude::*, rand::gen_range};
|
||||||
|
@ -17,9 +17,13 @@ pub struct World {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl World {
|
impl World {
|
||||||
pub fn new(hlayers: Option<Vec<usize>>, mut_rate: Option<f32>) -> Self {
|
pub fn new(
|
||||||
|
hlayers: Option<Vec<usize>>,
|
||||||
|
mut_rate: Option<f32>,
|
||||||
|
activ: Option<ActivationFunc>,
|
||||||
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
player: Player::new(hlayers, mut_rate),
|
player: Player::new(hlayers, mut_rate, activ),
|
||||||
score: 1.,
|
score: 1.,
|
||||||
asteroids: vec![
|
asteroids: vec![
|
||||||
Asteroid::new_to(vec2(0., 0.), 1.5, AsteroidSize::Large),
|
Asteroid::new_to(vec2(0., 0.), 1.5, AsteroidSize::Large),
|
||||||
|
@ -33,7 +37,7 @@ impl World {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
pub fn simulate(brain: NN) -> Self {
|
pub fn simulate(brain: NN) -> Self {
|
||||||
let mut w = World::new(None, None);
|
let mut w = World::new(None, None, None);
|
||||||
w.player.brain = Some(brain);
|
w.player.brain = Some(brain);
|
||||||
w
|
w
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue