crossover, mutation

This commit is contained in:
sparshg 2022-10-10 01:16:27 +05:30
parent 44e2cb36f4
commit 91bc4e27a4
5 changed files with 37 additions and 35 deletions

1
Cargo.lock generated
View File

@ -291,6 +291,7 @@ dependencies = [
"num-rational", "num-rational",
"num-traits", "num-traits",
"rand", "rand",
"rand_distr",
"simba", "simba",
"typenum", "typenum",
] ]

View File

@ -7,6 +7,6 @@ edition = "2021"
[dependencies] [dependencies]
macroquad = "0.3.24" macroquad = "0.3.24"
nalgebra = { version = "0.31.1", features = ["rand-no-std"] } nalgebra = { version = "0.31.1", features = ["rand-no-std", "rand"] }
rand = "0.8.5" rand = "0.8.5"
rand_distr = "0.4.3" rand_distr = "0.4.3"

View File

@ -1,7 +1,6 @@
mod asteroids; mod asteroids;
mod nn; mod nn;
mod player; mod player;
mod utils;
mod world; mod world;
use macroquad::prelude::*; use macroquad::prelude::*;
@ -17,8 +16,11 @@ async fn main() {
}; };
set_camera(&cam); set_camera(&cam);
let mut world = World::new(); let mut world = World::new();
let nn = NN::new(vec![2, 3, 3]); let mut nn = NN::new(vec![1, 2, 1]);
println!("{:?}", nn.feed_forward(vec![2., 3.])); println!("{} {}", nn.weights[0], nn.weights[1]);
nn.mutation();
println!("{} {}", nn.weights[0], nn.weights[1]);
loop { loop {
// clear_background(BLACK); // clear_background(BLACK);
// if !world.over { // if !world.over {

View File

@ -1,7 +1,9 @@
use nalgebra::*; use nalgebra::*;
use r::Rng;
use rand_distr::StandardNormal; use rand_distr::StandardNormal;
extern crate rand as r; extern crate rand as r;
#[derive(PartialEq, Debug, Clone, Copy)]
enum ActivationFunc { enum ActivationFunc {
Sigmoid, Sigmoid,
Tanh, Tanh,
@ -9,8 +11,9 @@ enum ActivationFunc {
} }
pub struct NN { pub struct NN {
config: Vec<usize>, config: Vec<usize>,
weights: Vec<DMatrix<f32>>, pub weights: Vec<DMatrix<f32>>,
activ_func: ActivationFunc, activ_func: ActivationFunc,
mut_rate: f32,
} }
impl NN { impl NN {
@ -36,6 +39,32 @@ impl NN {
.collect(), .collect(),
activ_func: ActivationFunc::ReLU, activ_func: ActivationFunc::ReLU,
mut_rate: 0.05,
}
}
pub fn crossover(a: &NN, b: &NN) -> Self {
assert_eq!(a.config, b.config, "NN configs not same.");
Self {
config: a.config.to_owned(),
activ_func: a.activ_func,
mut_rate: a.mut_rate,
weights: a
.weights
.iter()
.zip(b.weights.iter())
.map(|(m1, m2)| m1.zip_map(m2, |ele1, ele2| if r::random() { ele1 } else { ele2 }))
.collect(),
}
}
pub fn mutation(&mut self) {
for weight in &mut self.weights {
for ele in weight {
if r::random() {
*ele = r::thread_rng().sample::<f32, StandardNormal>(StandardNormal) * 0.05;
}
}
} }
} }

View File

@ -1,30 +0,0 @@
use macroquad::prelude::*;
pub fn rotate_vec(vec: Vec2, angle: f32) -> Vec2 {
vec2(angle.cos(), angle.sin()).rotate(vec)
}
pub fn draw_polygon(x: f32, y: f32, points: Vec<Vec2>, color: Color) {
let points_length = points.len();
let mut vertices = Vec::with_capacity(points_length as usize + 2);
let mut indices = Vec::<u16>::with_capacity(points_length as usize * 3);
for (i, point) in points.iter().enumerate() {
let vertex = macroquad::models::Vertex {
position: Vec3::new(x + point.x, y + point.y, 0.0),
uv: Vec2::default(),
color,
};
vertices.push(vertex);
indices.extend_from_slice(&[0, i as u16 + 1, i as u16 + 2]);
}
let mesh = Mesh {
vertices,
indices,
texture: None,
};
draw_mesh(&mesh);
}