Use a normal distribution for mutations

Rather than StandardNormal, weights are generated with a Normal
distribution centered on 0 with a standard deviation of 0.75,
and mutations are performed by adding a number from that distribution to
the existing weight, rather than replacing the weight entirely.
This commit is contained in:
Leonora Tindall 2023-04-03 22:32:53 -05:00
parent f04309f678
commit 49b720baec
3 changed files with 27 additions and 10 deletions

7
Cargo.lock generated
View File

@ -142,6 +142,7 @@ dependencies = [
name = "genetic"
version = "0.1.0"
dependencies = [
"lazy_static",
"macroquad",
"nalgebra",
"rand",
@ -203,6 +204,12 @@ version = "1.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fad582f4b9e86b6caa621cabeb0963332d92eea04729ab12892c2533951e6440"
[[package]]
name = "lazy_static"
version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
[[package]]
name = "lewton"
version = "0.9.4"

View File

@ -13,6 +13,7 @@ rand_distr = "0.4.3"
serde = { version = "1.0.152", features = ["derive"] }
serde_json = "1.0.91"
tinyfiledialogs = "3.9.1"
lazy_static = "1.4"
[profile.dev]
opt-level = 3

View File

@ -1,10 +1,14 @@
use macroquad::{prelude::*, rand::gen_range};
use nalgebra::*;
use r::Rng;
use rand_distr::StandardNormal;
use rand_distr::Normal;
use serde::{Deserialize, Serialize};
extern crate rand as r;
lazy_static::lazy_static! {
static ref CONNECTION_DISTRIBUTION: Normal<f32> = Normal::new(0.0, 0.75).unwrap();
}
#[derive(PartialEq, Debug, Clone, Copy, Serialize, Deserialize)]
pub enum ActivationFunc {
@ -39,8 +43,12 @@ impl NN {
.zip(config.iter().skip(1))
.map(|(&curr, &last)| {
// DMatrix::from_fn(last, curr + 1, |_, _| gen_range(-1., 1.))
DMatrix::<f32>::from_distribution(last, curr + 1, &StandardNormal, &mut rng)
* (2. / last as f32).sqrt()
DMatrix::<f32>::from_distribution(
last,
curr + 1,
&*CONNECTION_DISTRIBUTION,
&mut rng,
) * (2. / last as f32).sqrt()
})
.collect(),
@ -75,7 +83,9 @@ impl NN {
if gen_range(0., 1.) < self.mut_rate {
// *ele += gen_range(-1., 1.);
// *ele = gen_range(-1., 1.);
*ele = r::thread_rng().sample::<f32, StandardNormal>(StandardNormal);
*ele +=
r::thread_rng().sample::<f32, Normal<f32>>(CONNECTION_DISTRIBUTION.clone());
*ele = ele.min(10.0).max(-10.0);
}
}
}
@ -85,12 +95,11 @@ impl NN {
// println!("inputs: {:?}", inputs);
let mut y = DMatrix::from_vec(inputs.len(), 1, inputs.to_vec());
for i in 0..self.config.len() - 1 {
y = (&self.weights[i] * y.insert_row(self.config[i] - 1, 1.)).map(|x| {
match self.activ_func {
ActivationFunc::ReLU => x.max(0.),
ActivationFunc::Sigmoid => 1. / (1. + (-x).exp()),
ActivationFunc::Tanh => x.tanh(),
}
let row = y.insert_row(self.config[i] - 1, 1.);
y = (&self.weights[i] * row).map(|x| match self.activ_func {
ActivationFunc::ReLU => x.max(0.),
ActivationFunc::Sigmoid => 1. / (1. + (-x).exp()),
ActivationFunc::Tanh => x.tanh(),
});
}
y.column(0).data.into_slice().to_vec()