crossover, mutation
This commit is contained in:
parent
44e2cb36f4
commit
91bc4e27a4
|
@ -291,6 +291,7 @@ dependencies = [
|
||||||
"num-rational",
|
"num-rational",
|
||||||
"num-traits",
|
"num-traits",
|
||||||
"rand",
|
"rand",
|
||||||
|
"rand_distr",
|
||||||
"simba",
|
"simba",
|
||||||
"typenum",
|
"typenum",
|
||||||
]
|
]
|
||||||
|
|
|
@ -7,6 +7,6 @@ edition = "2021"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
macroquad = "0.3.24"
|
macroquad = "0.3.24"
|
||||||
nalgebra = { version = "0.31.1", features = ["rand-no-std"] }
|
nalgebra = { version = "0.31.1", features = ["rand-no-std", "rand"] }
|
||||||
rand = "0.8.5"
|
rand = "0.8.5"
|
||||||
rand_distr = "0.4.3"
|
rand_distr = "0.4.3"
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
mod asteroids;
|
mod asteroids;
|
||||||
mod nn;
|
mod nn;
|
||||||
mod player;
|
mod player;
|
||||||
mod utils;
|
|
||||||
mod world;
|
mod world;
|
||||||
|
|
||||||
use macroquad::prelude::*;
|
use macroquad::prelude::*;
|
||||||
|
@ -17,8 +16,11 @@ async fn main() {
|
||||||
};
|
};
|
||||||
set_camera(&cam);
|
set_camera(&cam);
|
||||||
let mut world = World::new();
|
let mut world = World::new();
|
||||||
let nn = NN::new(vec![2, 3, 3]);
|
let mut nn = NN::new(vec![1, 2, 1]);
|
||||||
println!("{:?}", nn.feed_forward(vec![2., 3.]));
|
println!("{} {}", nn.weights[0], nn.weights[1]);
|
||||||
|
nn.mutation();
|
||||||
|
println!("{} {}", nn.weights[0], nn.weights[1]);
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
// clear_background(BLACK);
|
// clear_background(BLACK);
|
||||||
// if !world.over {
|
// if !world.over {
|
||||||
|
|
31
src/nn.rs
31
src/nn.rs
|
@ -1,7 +1,9 @@
|
||||||
use nalgebra::*;
|
use nalgebra::*;
|
||||||
|
use r::Rng;
|
||||||
use rand_distr::StandardNormal;
|
use rand_distr::StandardNormal;
|
||||||
extern crate rand as r;
|
extern crate rand as r;
|
||||||
|
|
||||||
|
#[derive(PartialEq, Debug, Clone, Copy)]
|
||||||
enum ActivationFunc {
|
enum ActivationFunc {
|
||||||
Sigmoid,
|
Sigmoid,
|
||||||
Tanh,
|
Tanh,
|
||||||
|
@ -9,8 +11,9 @@ enum ActivationFunc {
|
||||||
}
|
}
|
||||||
pub struct NN {
|
pub struct NN {
|
||||||
config: Vec<usize>,
|
config: Vec<usize>,
|
||||||
weights: Vec<DMatrix<f32>>,
|
pub weights: Vec<DMatrix<f32>>,
|
||||||
activ_func: ActivationFunc,
|
activ_func: ActivationFunc,
|
||||||
|
mut_rate: f32,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl NN {
|
impl NN {
|
||||||
|
@ -36,6 +39,32 @@ impl NN {
|
||||||
.collect(),
|
.collect(),
|
||||||
|
|
||||||
activ_func: ActivationFunc::ReLU,
|
activ_func: ActivationFunc::ReLU,
|
||||||
|
mut_rate: 0.05,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn crossover(a: &NN, b: &NN) -> Self {
|
||||||
|
assert_eq!(a.config, b.config, "NN configs not same.");
|
||||||
|
Self {
|
||||||
|
config: a.config.to_owned(),
|
||||||
|
activ_func: a.activ_func,
|
||||||
|
mut_rate: a.mut_rate,
|
||||||
|
weights: a
|
||||||
|
.weights
|
||||||
|
.iter()
|
||||||
|
.zip(b.weights.iter())
|
||||||
|
.map(|(m1, m2)| m1.zip_map(m2, |ele1, ele2| if r::random() { ele1 } else { ele2 }))
|
||||||
|
.collect(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn mutation(&mut self) {
|
||||||
|
for weight in &mut self.weights {
|
||||||
|
for ele in weight {
|
||||||
|
if r::random() {
|
||||||
|
*ele = r::thread_rng().sample::<f32, StandardNormal>(StandardNormal) * 0.05;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
30
src/utils.rs
30
src/utils.rs
|
@ -1,30 +0,0 @@
|
||||||
use macroquad::prelude::*;
|
|
||||||
|
|
||||||
pub fn rotate_vec(vec: Vec2, angle: f32) -> Vec2 {
|
|
||||||
vec2(angle.cos(), angle.sin()).rotate(vec)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn draw_polygon(x: f32, y: f32, points: Vec<Vec2>, color: Color) {
|
|
||||||
let points_length = points.len();
|
|
||||||
let mut vertices = Vec::with_capacity(points_length as usize + 2);
|
|
||||||
let mut indices = Vec::<u16>::with_capacity(points_length as usize * 3);
|
|
||||||
|
|
||||||
for (i, point) in points.iter().enumerate() {
|
|
||||||
let vertex = macroquad::models::Vertex {
|
|
||||||
position: Vec3::new(x + point.x, y + point.y, 0.0),
|
|
||||||
uv: Vec2::default(),
|
|
||||||
color,
|
|
||||||
};
|
|
||||||
|
|
||||||
vertices.push(vertex);
|
|
||||||
indices.extend_from_slice(&[0, i as u16 + 1, i as u16 + 2]);
|
|
||||||
}
|
|
||||||
|
|
||||||
let mesh = Mesh {
|
|
||||||
vertices,
|
|
||||||
indices,
|
|
||||||
texture: None,
|
|
||||||
};
|
|
||||||
|
|
||||||
draw_mesh(&mesh);
|
|
||||||
}
|
|
Loading…
Reference in New Issue