Use a normal distribution for mutations
Rather than StandardNormal, weights are generated with a Normal distribution centered on 0 with a standard deviation of 0.75, and mutations are performed by adding a number from that distribution to the existing weight, rather than replacing the weight entirely.
This commit is contained in:
parent
f04309f678
commit
49b720baec
|
@ -142,6 +142,7 @@ dependencies = [
|
||||||
name = "genetic"
|
name = "genetic"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"lazy_static",
|
||||||
"macroquad",
|
"macroquad",
|
||||||
"nalgebra",
|
"nalgebra",
|
||||||
"rand",
|
"rand",
|
||||||
|
@ -203,6 +204,12 @@ version = "1.0.5"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "fad582f4b9e86b6caa621cabeb0963332d92eea04729ab12892c2533951e6440"
|
checksum = "fad582f4b9e86b6caa621cabeb0963332d92eea04729ab12892c2533951e6440"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "lazy_static"
|
||||||
|
version = "1.4.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "lewton"
|
name = "lewton"
|
||||||
version = "0.9.4"
|
version = "0.9.4"
|
||||||
|
|
|
@ -13,6 +13,7 @@ rand_distr = "0.4.3"
|
||||||
serde = { version = "1.0.152", features = ["derive"] }
|
serde = { version = "1.0.152", features = ["derive"] }
|
||||||
serde_json = "1.0.91"
|
serde_json = "1.0.91"
|
||||||
tinyfiledialogs = "3.9.1"
|
tinyfiledialogs = "3.9.1"
|
||||||
|
lazy_static = "1.4"
|
||||||
|
|
||||||
[profile.dev]
|
[profile.dev]
|
||||||
opt-level = 3
|
opt-level = 3
|
||||||
|
|
29
src/nn.rs
29
src/nn.rs
|
@ -1,10 +1,14 @@
|
||||||
use macroquad::{prelude::*, rand::gen_range};
|
use macroquad::{prelude::*, rand::gen_range};
|
||||||
use nalgebra::*;
|
use nalgebra::*;
|
||||||
use r::Rng;
|
use r::Rng;
|
||||||
use rand_distr::StandardNormal;
|
use rand_distr::Normal;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
extern crate rand as r;
|
extern crate rand as r;
|
||||||
|
|
||||||
|
lazy_static::lazy_static! {
|
||||||
|
static ref CONNECTION_DISTRIBUTION: Normal<f32> = Normal::new(0.0, 0.75).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(PartialEq, Debug, Clone, Copy, Serialize, Deserialize)]
|
#[derive(PartialEq, Debug, Clone, Copy, Serialize, Deserialize)]
|
||||||
|
|
||||||
pub enum ActivationFunc {
|
pub enum ActivationFunc {
|
||||||
|
@ -39,8 +43,12 @@ impl NN {
|
||||||
.zip(config.iter().skip(1))
|
.zip(config.iter().skip(1))
|
||||||
.map(|(&curr, &last)| {
|
.map(|(&curr, &last)| {
|
||||||
// DMatrix::from_fn(last, curr + 1, |_, _| gen_range(-1., 1.))
|
// DMatrix::from_fn(last, curr + 1, |_, _| gen_range(-1., 1.))
|
||||||
DMatrix::<f32>::from_distribution(last, curr + 1, &StandardNormal, &mut rng)
|
DMatrix::<f32>::from_distribution(
|
||||||
* (2. / last as f32).sqrt()
|
last,
|
||||||
|
curr + 1,
|
||||||
|
&*CONNECTION_DISTRIBUTION,
|
||||||
|
&mut rng,
|
||||||
|
) * (2. / last as f32).sqrt()
|
||||||
})
|
})
|
||||||
.collect(),
|
.collect(),
|
||||||
|
|
||||||
|
@ -75,7 +83,9 @@ impl NN {
|
||||||
if gen_range(0., 1.) < self.mut_rate {
|
if gen_range(0., 1.) < self.mut_rate {
|
||||||
// *ele += gen_range(-1., 1.);
|
// *ele += gen_range(-1., 1.);
|
||||||
// *ele = gen_range(-1., 1.);
|
// *ele = gen_range(-1., 1.);
|
||||||
*ele = r::thread_rng().sample::<f32, StandardNormal>(StandardNormal);
|
*ele +=
|
||||||
|
r::thread_rng().sample::<f32, Normal<f32>>(CONNECTION_DISTRIBUTION.clone());
|
||||||
|
*ele = ele.min(10.0).max(-10.0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -85,12 +95,11 @@ impl NN {
|
||||||
// println!("inputs: {:?}", inputs);
|
// println!("inputs: {:?}", inputs);
|
||||||
let mut y = DMatrix::from_vec(inputs.len(), 1, inputs.to_vec());
|
let mut y = DMatrix::from_vec(inputs.len(), 1, inputs.to_vec());
|
||||||
for i in 0..self.config.len() - 1 {
|
for i in 0..self.config.len() - 1 {
|
||||||
y = (&self.weights[i] * y.insert_row(self.config[i] - 1, 1.)).map(|x| {
|
let row = y.insert_row(self.config[i] - 1, 1.);
|
||||||
match self.activ_func {
|
y = (&self.weights[i] * row).map(|x| match self.activ_func {
|
||||||
ActivationFunc::ReLU => x.max(0.),
|
ActivationFunc::ReLU => x.max(0.),
|
||||||
ActivationFunc::Sigmoid => 1. / (1. + (-x).exp()),
|
ActivationFunc::Sigmoid => 1. / (1. + (-x).exp()),
|
||||||
ActivationFunc::Tanh => x.tanh(),
|
ActivationFunc::Tanh => x.tanh(),
|
||||||
}
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
y.column(0).data.into_slice().to_vec()
|
y.column(0).data.into_slice().to_vec()
|
||||||
|
|
Loading…
Reference in New Issue