Use a normal distribution for mutations
Rather than StandardNormal, weights are generated with a Normal distribution centered on 0 with a standard deviation of 0.75, and mutations are performed by adding a number from that distribution to the existing weight, rather than replacing the weight entirely.
This commit is contained in:
parent
f04309f678
commit
49b720baec
|
@ -142,6 +142,7 @@ dependencies = [
|
|||
name = "genetic"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"lazy_static",
|
||||
"macroquad",
|
||||
"nalgebra",
|
||||
"rand",
|
||||
|
@ -203,6 +204,12 @@ version = "1.0.5"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fad582f4b9e86b6caa621cabeb0963332d92eea04729ab12892c2533951e6440"
|
||||
|
||||
[[package]]
|
||||
name = "lazy_static"
|
||||
version = "1.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
|
||||
|
||||
[[package]]
|
||||
name = "lewton"
|
||||
version = "0.9.4"
|
||||
|
|
|
@ -13,6 +13,7 @@ rand_distr = "0.4.3"
|
|||
serde = { version = "1.0.152", features = ["derive"] }
|
||||
serde_json = "1.0.91"
|
||||
tinyfiledialogs = "3.9.1"
|
||||
lazy_static = "1.4"
|
||||
|
||||
[profile.dev]
|
||||
opt-level = 3
|
||||
|
|
23
src/nn.rs
23
src/nn.rs
|
@ -1,10 +1,14 @@
|
|||
use macroquad::{prelude::*, rand::gen_range};
|
||||
use nalgebra::*;
|
||||
use r::Rng;
|
||||
use rand_distr::StandardNormal;
|
||||
use rand_distr::Normal;
|
||||
use serde::{Deserialize, Serialize};
|
||||
extern crate rand as r;
|
||||
|
||||
lazy_static::lazy_static! {
|
||||
static ref CONNECTION_DISTRIBUTION: Normal<f32> = Normal::new(0.0, 0.75).unwrap();
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Debug, Clone, Copy, Serialize, Deserialize)]
|
||||
|
||||
pub enum ActivationFunc {
|
||||
|
@ -39,8 +43,12 @@ impl NN {
|
|||
.zip(config.iter().skip(1))
|
||||
.map(|(&curr, &last)| {
|
||||
// DMatrix::from_fn(last, curr + 1, |_, _| gen_range(-1., 1.))
|
||||
DMatrix::<f32>::from_distribution(last, curr + 1, &StandardNormal, &mut rng)
|
||||
* (2. / last as f32).sqrt()
|
||||
DMatrix::<f32>::from_distribution(
|
||||
last,
|
||||
curr + 1,
|
||||
&*CONNECTION_DISTRIBUTION,
|
||||
&mut rng,
|
||||
) * (2. / last as f32).sqrt()
|
||||
})
|
||||
.collect(),
|
||||
|
||||
|
@ -75,7 +83,9 @@ impl NN {
|
|||
if gen_range(0., 1.) < self.mut_rate {
|
||||
// *ele += gen_range(-1., 1.);
|
||||
// *ele = gen_range(-1., 1.);
|
||||
*ele = r::thread_rng().sample::<f32, StandardNormal>(StandardNormal);
|
||||
*ele +=
|
||||
r::thread_rng().sample::<f32, Normal<f32>>(CONNECTION_DISTRIBUTION.clone());
|
||||
*ele = ele.min(10.0).max(-10.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -85,12 +95,11 @@ impl NN {
|
|||
// println!("inputs: {:?}", inputs);
|
||||
let mut y = DMatrix::from_vec(inputs.len(), 1, inputs.to_vec());
|
||||
for i in 0..self.config.len() - 1 {
|
||||
y = (&self.weights[i] * y.insert_row(self.config[i] - 1, 1.)).map(|x| {
|
||||
match self.activ_func {
|
||||
let row = y.insert_row(self.config[i] - 1, 1.);
|
||||
y = (&self.weights[i] * row).map(|x| match self.activ_func {
|
||||
ActivationFunc::ReLU => x.max(0.),
|
||||
ActivationFunc::Sigmoid => 1. / (1. + (-x).exp()),
|
||||
ActivationFunc::Tanh => x.tanh(),
|
||||
}
|
||||
});
|
||||
}
|
||||
y.column(0).data.into_slice().to_vec()
|
||||
|
|
Loading…
Reference in New Issue