From 49b720baecd2c42bca4fb1bae026637f6e2f707f Mon Sep 17 00:00:00 2001 From: Leonora Tindall Date: Mon, 3 Apr 2023 22:32:53 -0500 Subject: [PATCH] Use a normal distribution for mutations Rather than StandardNormal, weights are generated with a Normal distribution centered on 0 with a standard deviation of 0.75, and mutations are performed by adding a number from that distribution to the existing weight, rather than replacing the weight entirely. --- Cargo.lock | 7 +++++++ Cargo.toml | 1 + src/nn.rs | 29 +++++++++++++++++++---------- 3 files changed, 27 insertions(+), 10 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9f945db..17ab57a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -142,6 +142,7 @@ dependencies = [ name = "genetic" version = "0.1.0" dependencies = [ + "lazy_static", "macroquad", "nalgebra", "rand", @@ -203,6 +204,12 @@ version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fad582f4b9e86b6caa621cabeb0963332d92eea04729ab12892c2533951e6440" +[[package]] +name = "lazy_static" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" + [[package]] name = "lewton" version = "0.9.4" diff --git a/Cargo.toml b/Cargo.toml index cb49d0d..9c1878b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,7 @@ rand_distr = "0.4.3" serde = { version = "1.0.152", features = ["derive"] } serde_json = "1.0.91" tinyfiledialogs = "3.9.1" +lazy_static = "1.4" [profile.dev] opt-level = 3 diff --git a/src/nn.rs b/src/nn.rs index 79400cf..093ebe2 100644 --- a/src/nn.rs +++ b/src/nn.rs @@ -1,10 +1,14 @@ use macroquad::{prelude::*, rand::gen_range}; use nalgebra::*; use r::Rng; -use rand_distr::StandardNormal; +use rand_distr::Normal; use serde::{Deserialize, Serialize}; extern crate rand as r; +lazy_static::lazy_static! { + static ref CONNECTION_DISTRIBUTION: Normal = Normal::new(0.0, 0.75).unwrap(); +} + #[derive(PartialEq, Debug, Clone, Copy, Serialize, Deserialize)] pub enum ActivationFunc { @@ -39,8 +43,12 @@ impl NN { .zip(config.iter().skip(1)) .map(|(&curr, &last)| { // DMatrix::from_fn(last, curr + 1, |_, _| gen_range(-1., 1.)) - DMatrix::::from_distribution(last, curr + 1, &StandardNormal, &mut rng) - * (2. / last as f32).sqrt() + DMatrix::::from_distribution( + last, + curr + 1, + &*CONNECTION_DISTRIBUTION, + &mut rng, + ) * (2. / last as f32).sqrt() }) .collect(), @@ -75,7 +83,9 @@ impl NN { if gen_range(0., 1.) < self.mut_rate { // *ele += gen_range(-1., 1.); // *ele = gen_range(-1., 1.); - *ele = r::thread_rng().sample::(StandardNormal); + *ele += + r::thread_rng().sample::>(CONNECTION_DISTRIBUTION.clone()); + *ele = ele.min(10.0).max(-10.0); } } } @@ -85,12 +95,11 @@ impl NN { // println!("inputs: {:?}", inputs); let mut y = DMatrix::from_vec(inputs.len(), 1, inputs.to_vec()); for i in 0..self.config.len() - 1 { - y = (&self.weights[i] * y.insert_row(self.config[i] - 1, 1.)).map(|x| { - match self.activ_func { - ActivationFunc::ReLU => x.max(0.), - ActivationFunc::Sigmoid => 1. / (1. + (-x).exp()), - ActivationFunc::Tanh => x.tanh(), - } + let row = y.insert_row(self.config[i] - 1, 1.); + y = (&self.weights[i] * row).map(|x| match self.activ_func { + ActivationFunc::ReLU => x.max(0.), + ActivationFunc::Sigmoid => 1. / (1. + (-x).exp()), + ActivationFunc::Tanh => x.tanh(), }); } y.column(0).data.into_slice().to_vec()