2023-01-06 22:47:57 +00:00
|
|
|
use macroquad::{prelude::*, rand::gen_range};
|
2022-10-09 18:14:22 +00:00
|
|
|
use nalgebra::*;
|
2022-10-09 19:46:27 +00:00
|
|
|
use r::Rng;
|
2022-10-09 18:14:22 +00:00
|
|
|
use rand_distr::StandardNormal;
|
2023-01-06 09:10:29 +00:00
|
|
|
use serde::{Deserialize, Serialize};
|
2022-10-09 18:14:22 +00:00
|
|
|
extern crate rand as r;
|
|
|
|
|
2023-01-12 09:55:26 +00:00
|
|
|
#[derive(PartialEq, Debug, Clone, Copy, Serialize, Deserialize)]
|
2023-01-04 19:57:59 +00:00
|
|
|
|
2023-01-11 21:31:03 +00:00
|
|
|
pub enum ActivationFunc {
|
2023-01-12 09:55:26 +00:00
|
|
|
ReLU,
|
2022-10-09 18:14:22 +00:00
|
|
|
Sigmoid,
|
|
|
|
Tanh,
|
|
|
|
}
|
2022-10-10 18:36:14 +00:00
|
|
|
|
2023-01-12 09:55:26 +00:00
|
|
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
2022-10-09 18:14:22 +00:00
|
|
|
pub struct NN {
|
2022-10-10 18:36:14 +00:00
|
|
|
pub config: Vec<usize>,
|
2023-01-06 22:47:57 +00:00
|
|
|
pub weights: Vec<DMatrix<f32>>,
|
2023-01-11 21:31:03 +00:00
|
|
|
pub activ_func: ActivationFunc,
|
|
|
|
pub mut_rate: f32,
|
2022-10-09 18:14:22 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
impl NN {
|
|
|
|
// Vec of number of neurons in input, hidden 1, hidden 2, ..., output layers
|
2023-01-12 09:55:26 +00:00
|
|
|
pub fn new(config: Vec<usize>, mut_rate: f32, activ: ActivationFunc) -> Self {
|
2022-10-09 18:14:22 +00:00
|
|
|
let mut rng = r::thread_rng();
|
|
|
|
|
|
|
|
Self {
|
|
|
|
config: config
|
|
|
|
.iter()
|
|
|
|
.enumerate()
|
|
|
|
.map(|(i, &x)| if i != config.len() - 1 { x + 1 } else { x })
|
|
|
|
.collect(),
|
|
|
|
|
|
|
|
// He-et-al Initialization
|
|
|
|
weights: config
|
|
|
|
.iter()
|
|
|
|
.zip(config.iter().skip(1))
|
|
|
|
.map(|(&curr, &last)| {
|
2023-01-04 19:57:59 +00:00
|
|
|
// DMatrix::from_fn(last, curr + 1, |_, _| gen_range(-1., 1.))
|
2022-10-09 18:14:22 +00:00
|
|
|
DMatrix::<f32>::from_distribution(last, curr + 1, &StandardNormal, &mut rng)
|
|
|
|
* (2. / last as f32).sqrt()
|
|
|
|
})
|
|
|
|
.collect(),
|
2022-10-09 18:33:00 +00:00
|
|
|
|
2023-01-11 21:31:03 +00:00
|
|
|
mut_rate,
|
2023-01-12 09:55:26 +00:00
|
|
|
activ_func: activ,
|
2022-10-09 19:46:27 +00:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
pub fn crossover(a: &NN, b: &NN) -> Self {
|
|
|
|
assert_eq!(a.config, b.config, "NN configs not same.");
|
|
|
|
Self {
|
|
|
|
config: a.config.to_owned(),
|
|
|
|
activ_func: a.activ_func,
|
|
|
|
mut_rate: a.mut_rate,
|
|
|
|
weights: a
|
|
|
|
.weights
|
|
|
|
.iter()
|
|
|
|
.zip(b.weights.iter())
|
2023-01-11 18:56:03 +00:00
|
|
|
.map(|(m1, m2)| {
|
|
|
|
m1.zip_map(
|
|
|
|
m2,
|
|
|
|
|ele1, ele2| if gen_range(0., 1.) < 0.5 { ele1 } else { ele2 },
|
|
|
|
)
|
|
|
|
})
|
2022-10-09 19:46:27 +00:00
|
|
|
.collect(),
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2022-10-10 18:36:14 +00:00
|
|
|
pub fn mutate(&mut self) {
|
2022-10-09 19:46:27 +00:00
|
|
|
for weight in &mut self.weights {
|
|
|
|
for ele in weight {
|
2022-10-22 19:40:19 +00:00
|
|
|
if gen_range(0., 1.) < self.mut_rate {
|
|
|
|
// *ele += gen_range(-1., 1.);
|
2023-01-06 22:47:57 +00:00
|
|
|
// *ele = gen_range(-1., 1.);
|
|
|
|
*ele = r::thread_rng().sample::<f32, StandardNormal>(StandardNormal);
|
2022-10-09 19:46:27 +00:00
|
|
|
}
|
|
|
|
}
|
2022-10-09 18:14:22 +00:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2023-01-07 21:59:59 +00:00
|
|
|
pub fn feed_forward(&self, inputs: &Vec<f32>) -> Vec<f32> {
|
2023-01-04 19:57:59 +00:00
|
|
|
// println!("inputs: {:?}", inputs);
|
2023-01-07 21:59:59 +00:00
|
|
|
let mut y = DMatrix::from_vec(inputs.len(), 1, inputs.to_vec());
|
2023-01-08 19:07:35 +00:00
|
|
|
for i in 0..self.config.len() - 1 {
|
2022-10-09 18:14:22 +00:00
|
|
|
y = (&self.weights[i] * y.insert_row(self.config[i] - 1, 1.)).map(|x| {
|
|
|
|
match self.activ_func {
|
|
|
|
ActivationFunc::ReLU => x.max(0.),
|
|
|
|
ActivationFunc::Sigmoid => 1. / (1. + (-x).exp()),
|
|
|
|
ActivationFunc::Tanh => x.tanh(),
|
|
|
|
}
|
|
|
|
});
|
|
|
|
}
|
2022-10-09 18:33:00 +00:00
|
|
|
y.column(0).data.into_slice().to_vec()
|
2022-10-09 18:14:22 +00:00
|
|
|
}
|
2023-01-06 09:10:29 +00:00
|
|
|
|
2023-01-08 19:07:35 +00:00
|
|
|
pub fn draw(&self, width: f32, height: f32, inputs: &Vec<f32>, outputs: &Vec<f32>, bias: bool) {
|
2023-01-06 22:47:57 +00:00
|
|
|
draw_rectangle_lines(-width * 0.5, -height * 0.5, width, height, 2., WHITE);
|
|
|
|
|
|
|
|
let width = width * 0.8;
|
|
|
|
let height = height * 0.8;
|
|
|
|
let vspace = height / (self.config.iter().max().unwrap() - 1) as f32;
|
2023-01-08 20:40:34 +00:00
|
|
|
let mut p1s: Vec<(f32, f32)>;
|
2023-01-06 22:47:57 +00:00
|
|
|
let mut p2s: Vec<(f32, f32)> = Vec::new();
|
|
|
|
for (i, layer) in self
|
|
|
|
.config
|
|
|
|
.iter()
|
2023-01-08 19:07:35 +00:00
|
|
|
.take(self.config.len() - 1)
|
|
|
|
.map(|x| x - if bias { 0 } else { 1 })
|
|
|
|
.chain(self.config.last().map(|&x| x))
|
2023-01-06 22:47:57 +00:00
|
|
|
.enumerate()
|
|
|
|
{
|
|
|
|
p1s = p2s;
|
|
|
|
p2s = Vec::new();
|
2023-01-08 19:07:35 +00:00
|
|
|
for neuron in 0..layer {
|
2023-01-06 22:47:57 +00:00
|
|
|
p2s.push((
|
|
|
|
i as f32 * width / (self.config.len() - 1) as f32 - width * 0.5,
|
|
|
|
neuron as f32 * vspace - (vspace * (layer - 1) as f32) * 0.5,
|
|
|
|
));
|
|
|
|
}
|
2023-01-07 21:59:59 +00:00
|
|
|
for (k, j, p1, p2) in p1s.iter().enumerate().flat_map(|(k, x)| {
|
|
|
|
p2s.iter()
|
2023-01-08 19:07:35 +00:00
|
|
|
.take(
|
|
|
|
p2s.len()
|
|
|
|
- if i == self.config.len() - 1 || !bias {
|
|
|
|
0
|
|
|
|
} else {
|
|
|
|
1
|
|
|
|
},
|
|
|
|
)
|
2023-01-07 21:59:59 +00:00
|
|
|
.enumerate()
|
|
|
|
.map(move |(j, y)| (k, j, *x, *y))
|
|
|
|
}) {
|
|
|
|
let weight = *self.weights[i - 1].index((j, k));
|
|
|
|
let c = if weight < 0. { 0. } else { 1. };
|
2023-01-06 22:47:57 +00:00
|
|
|
draw_line(
|
|
|
|
p1.0,
|
|
|
|
p1.1,
|
|
|
|
p2.0,
|
|
|
|
p2.1,
|
2023-01-08 19:07:35 +00:00
|
|
|
1.5,
|
2023-01-07 21:59:59 +00:00
|
|
|
Color::new(1., c, c, weight.abs()),
|
2023-01-06 22:47:57 +00:00
|
|
|
);
|
|
|
|
}
|
2023-01-08 19:07:35 +00:00
|
|
|
|
|
|
|
let mut inputs = inputs.to_vec();
|
|
|
|
inputs.push(1.);
|
|
|
|
|
|
|
|
for (j, p) in p1s.iter().enumerate() {
|
2023-01-06 22:47:57 +00:00
|
|
|
draw_circle(p.0, p.1, 10., WHITE);
|
2023-01-08 19:07:35 +00:00
|
|
|
draw_circle(p.0, p.1, 8., BLACK);
|
|
|
|
draw_circle(
|
|
|
|
p.0,
|
|
|
|
p.1,
|
|
|
|
8.,
|
|
|
|
if i == 1 && inputs.len() > 1 {
|
|
|
|
let c = if inputs[j] < 0. { 0. } else { 1. };
|
|
|
|
Color::new(1., c, c, inputs[j].abs())
|
|
|
|
} else {
|
|
|
|
BLACK
|
|
|
|
},
|
|
|
|
);
|
2023-01-08 19:26:00 +00:00
|
|
|
if i == 1 && inputs.len() > 1 {
|
|
|
|
draw_text(
|
|
|
|
&format!("{:.2}", inputs[j]),
|
|
|
|
p.0 - if inputs[j] < 0. { 50. } else { 42. },
|
|
|
|
p.1 + 4.,
|
|
|
|
16.,
|
|
|
|
WHITE,
|
|
|
|
);
|
|
|
|
}
|
2023-01-06 22:47:57 +00:00
|
|
|
}
|
|
|
|
}
|
2023-01-08 19:07:35 +00:00
|
|
|
for (j, p) in p2s.iter().enumerate() {
|
2023-01-06 22:47:57 +00:00
|
|
|
draw_circle(p.0, p.1, 10., WHITE);
|
2023-01-08 19:07:35 +00:00
|
|
|
draw_circle(p.0, p.1, 8., BLACK);
|
|
|
|
if !outputs.is_empty() {
|
|
|
|
draw_circle(p.0, p.1, 8., Color::new(1., 1., 1., outputs[j]));
|
2023-01-08 19:26:00 +00:00
|
|
|
draw_text(
|
|
|
|
&format!("{:.2}", outputs[j]),
|
|
|
|
p.0 + 14.,
|
|
|
|
p.1 + 4.,
|
|
|
|
16.,
|
|
|
|
WHITE,
|
|
|
|
);
|
2023-01-08 19:07:35 +00:00
|
|
|
}
|
2023-01-06 22:47:57 +00:00
|
|
|
}
|
2023-01-09 19:25:18 +00:00
|
|
|
draw_rectangle(width * 0.47, height * 0.47, 10., 10., RED);
|
|
|
|
let params = TextParams {
|
|
|
|
font_size: 40,
|
|
|
|
font_scale: 0.5,
|
|
|
|
..Default::default()
|
|
|
|
};
|
|
|
|
draw_text_ex("-ve", width * 0.47 + 20., height * 0.47 + 10., params);
|
|
|
|
draw_rectangle(width * 0.47, height * 0.47 + 20., 10., 10., WHITE);
|
|
|
|
draw_text_ex("+ve", width * 0.47 + 20., height * 0.47 + 30., params);
|
2023-01-06 22:47:57 +00:00
|
|
|
}
|
|
|
|
|
2023-01-06 09:10:29 +00:00
|
|
|
pub fn export(&self) -> String {
|
|
|
|
serde_json::to_string(self).unwrap()
|
|
|
|
}
|
|
|
|
|
2023-01-08 20:09:10 +00:00
|
|
|
pub fn import(path: &str) -> NN {
|
|
|
|
let json = std::fs::read_to_string(path).expect("Unable to read file");
|
2023-01-06 09:10:29 +00:00
|
|
|
serde_json::from_str(&json).unwrap()
|
|
|
|
}
|
2022-10-09 18:14:22 +00:00
|
|
|
}
|