Use a normal distribution for mutations
Rather than using StandardNormal, weights are now generated from a Normal distribution centered on 0 with a standard deviation of 0.75, and mutations are performed by adding a sample from that distribution to the existing weight rather than replacing the weight entirely.
This commit is contained in:
		
parent f04309f678
commit 49b720baec
				|  | @ -142,6 +142,7 @@ dependencies = [ | |||
| name = "genetic" | ||||
| version = "0.1.0" | ||||
| dependencies = [ | ||||
|  "lazy_static", | ||||
|  "macroquad", | ||||
|  "nalgebra", | ||||
|  "rand", | ||||
|  | @ -203,6 +204,12 @@ version = "1.0.5" | |||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "fad582f4b9e86b6caa621cabeb0963332d92eea04729ab12892c2533951e6440" | ||||
| 
 | ||||
| [[package]] | ||||
| name = "lazy_static" | ||||
| version = "1.4.0" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" | ||||
| 
 | ||||
| [[package]] | ||||
| name = "lewton" | ||||
| version = "0.9.4" | ||||
|  |  | |||
|  | @ -13,6 +13,7 @@ rand_distr = "0.4.3" | |||
| serde = { version = "1.0.152", features = ["derive"] } | ||||
| serde_json = "1.0.91" | ||||
| tinyfiledialogs = "3.9.1" | ||||
| lazy_static = "1.4" | ||||
| 
 | ||||
| [profile.dev] | ||||
| opt-level = 3 | ||||
|  |  | |||
							
								
								
									
src/nn.rs — 29 lines changed
							|  | @ -1,10 +1,14 @@ | |||
| use macroquad::{prelude::*, rand::gen_range}; | ||||
| use nalgebra::*; | ||||
| use r::Rng; | ||||
| use rand_distr::StandardNormal; | ||||
| use rand_distr::Normal; | ||||
| use serde::{Deserialize, Serialize}; | ||||
| extern crate rand as r; | ||||
| 
 | ||||
// Shared Normal(mean = 0.0, std dev = 0.75) distribution, built once at first
// use. It is sampled both when generating the initial connection weights and
// when additively mutating an existing weight (see the mutation hunk below).
// Normal::new only fails for a non-finite/negative std dev, so unwrap is safe
// for these constants.
lazy_static::lazy_static! {
  static ref CONNECTION_DISTRIBUTION: Normal<f32> = Normal::new(0.0, 0.75).unwrap();
}
| 
 | ||||
| #[derive(PartialEq, Debug, Clone, Copy, Serialize, Deserialize)] | ||||
| 
 | ||||
| pub enum ActivationFunc { | ||||
|  | @ -39,8 +43,12 @@ impl NN { | |||
|                 .zip(config.iter().skip(1)) | ||||
|                 .map(|(&curr, &last)| { | ||||
|                     // DMatrix::from_fn(last, curr + 1, |_, _| gen_range(-1., 1.))
 | ||||
|                     DMatrix::<f32>::from_distribution(last, curr + 1, &StandardNormal, &mut rng) | ||||
|                         * (2. / last as f32).sqrt() | ||||
|                     DMatrix::<f32>::from_distribution( | ||||
|                         last, | ||||
|                         curr + 1, | ||||
|                         &*CONNECTION_DISTRIBUTION, | ||||
|                         &mut rng, | ||||
|                     ) * (2. / last as f32).sqrt() | ||||
|                 }) | ||||
|                 .collect(), | ||||
| 
 | ||||
|  | @ -75,7 +83,9 @@ impl NN { | |||
|                 if gen_range(0., 1.) < self.mut_rate { | ||||
|                     // *ele += gen_range(-1., 1.);
 | ||||
|                     // *ele = gen_range(-1., 1.);
 | ||||
|                     *ele = r::thread_rng().sample::<f32, StandardNormal>(StandardNormal); | ||||
|                     *ele += | ||||
|                         r::thread_rng().sample::<f32, Normal<f32>>(CONNECTION_DISTRIBUTION.clone()); | ||||
|                     *ele = ele.min(10.0).max(-10.0); | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|  | @ -85,12 +95,11 @@ impl NN { | |||
|         // println!("inputs: {:?}", inputs);
 | ||||
|         let mut y = DMatrix::from_vec(inputs.len(), 1, inputs.to_vec()); | ||||
|         for i in 0..self.config.len() - 1 { | ||||
|             y = (&self.weights[i] * y.insert_row(self.config[i] - 1, 1.)).map(|x| { | ||||
|                 match self.activ_func { | ||||
|                     ActivationFunc::ReLU => x.max(0.), | ||||
|                     ActivationFunc::Sigmoid => 1. / (1. + (-x).exp()), | ||||
|                     ActivationFunc::Tanh => x.tanh(), | ||||
|                 } | ||||
|             let row = y.insert_row(self.config[i] - 1, 1.); | ||||
|             y = (&self.weights[i] * row).map(|x| match self.activ_func { | ||||
|                 ActivationFunc::ReLU => x.max(0.), | ||||
|                 ActivationFunc::Sigmoid => 1. / (1. + (-x).exp()), | ||||
|                 ActivationFunc::Tanh => x.tanh(), | ||||
|             }); | ||||
|         } | ||||
|         y.column(0).data.into_slice().to_vec() | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue