@ -1,10 +1,14 @@
use macroquad ::{ prelude ::* , rand ::gen_range } ;
use nalgebra ::* ;
use r ::Rng ;
use rand_distr ::Standard Normal;
use rand_distr ::Normal ;
use serde ::{ Deserialize , Serialize } ;
extern crate rand as r ;
lazy_static ::lazy_static ! {
static ref CONNECTION_DISTRIBUTION : Normal < f32 > = Normal ::new ( 0.0 , 0.75 ) . unwrap ( ) ;
}
#[ derive(PartialEq, Debug, Clone, Copy, Serialize, Deserialize) ]
pub enum ActivationFunc {
@ -39,8 +43,12 @@ impl NN {
. zip ( config . iter ( ) . skip ( 1 ) )
. map ( | ( & curr , & last ) | {
// DMatrix::from_fn(last, curr + 1, |_, _| gen_range(-1., 1.))
DMatrix ::< f32 > ::from_distribution ( last , curr + 1 , & StandardNormal , & mut rng )
* ( 2. / last as f32 ) . sqrt ( )
DMatrix ::< f32 > ::from_distribution (
last ,
curr + 1 ,
& * CONNECTION_DISTRIBUTION ,
& mut rng ,
) * ( 2. / last as f32 ) . sqrt ( )
} )
. collect ( ) ,
@ -75,7 +83,9 @@ impl NN {
if gen_range ( 0. , 1. ) < self . mut_rate {
// *ele += gen_range(-1., 1.);
// *ele = gen_range(-1., 1.);
* ele = r ::thread_rng ( ) . sample ::< f32 , StandardNormal > ( StandardNormal ) ;
* ele + =
r ::thread_rng ( ) . sample ::< f32 , Normal < f32 > > ( CONNECTION_DISTRIBUTION . clone ( ) ) ;
* ele = ele . min ( 10.0 ) . max ( - 10.0 ) ;
}
}
}
@ -85,12 +95,11 @@ impl NN {
// println!("inputs: {:?}", inputs);
let mut y = DMatrix ::from_vec ( inputs . len ( ) , 1 , inputs . to_vec ( ) ) ;
for i in 0 . . self . config . len ( ) - 1 {
y = ( & self . weights [ i ] * y . insert_row ( self . config [ i ] - 1 , 1. ) ) . map ( | x | {
match self . activ_func {
ActivationFunc ::ReLU = > x . max ( 0. ) ,
ActivationFunc ::Sigmoid = > 1. / ( 1. + ( - x ) . exp ( ) ) ,
ActivationFunc ::Tanh = > x . tanh ( ) ,
}
let row = y . insert_row ( self . config [ i ] - 1 , 1. ) ;
y = ( & self . weights [ i ] * row ) . map ( | x | match self . activ_func {
ActivationFunc ::ReLU = > x . max ( 0. ) ,
ActivationFunc ::Sigmoid = > 1. / ( 1. + ( - x ) . exp ( ) ) ,
ActivationFunc ::Tanh = > x . tanh ( ) ,
} ) ;
}
y . column ( 0 ) . data . into_slice ( ) . to_vec ( )