neural network, feed forward

This commit is contained in:
sparshg 2022-10-09 23:44:22 +05:30
parent 3d633c435d
commit 80fb255c36
6 changed files with 266 additions and 22 deletions

206
Cargo.lock generated
View File

@ -19,6 +19,15 @@ dependencies = [
"version_check", "version_check",
] ]
[[package]]
name = "approx"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6"
dependencies = [
"num-traits",
]
[[package]] [[package]]
name = "audir-sles" name = "audir-sles"
version = "0.1.0" version = "0.1.0"
@ -128,7 +137,9 @@ name = "genetic"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"macroquad", "macroquad",
"macroquad-particles", "nalgebra",
"rand",
"rand_distr",
] ]
[[package]] [[package]]
@ -194,6 +205,12 @@ version = "0.2.134"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "329c933548736bc49fd575ee68c89e8be4d260064184389a5b77517cddd99ffb" checksum = "329c933548736bc49fd575ee68c89e8be4d260064184389a5b77517cddd99ffb"
[[package]]
name = "libm"
version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "292a948cd991e376cf75541fe5b97a1081d713c618b4f1b9500f8844e49eb565"
[[package]] [[package]]
name = "macroquad" name = "macroquad"
version = "0.3.24" version = "0.3.24"
@ -210,15 +227,6 @@ dependencies = [
"quad-snd", "quad-snd",
] ]
[[package]]
name = "macroquad-particles"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8505dc035e4d68799a53a459afef0bc345410039325b5928af2e0d83eb15de40"
dependencies = [
"macroquad",
]
[[package]] [[package]]
name = "macroquad_macro" name = "macroquad_macro"
version = "0.1.7" version = "0.1.7"
@ -234,6 +242,15 @@ dependencies = [
"libc", "libc",
] ]
[[package]]
name = "matrixmultiply"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "add85d4dd35074e6fedc608f8c8f513a3548619a9024b751949ef0e8e45a4d84"
dependencies = [
"rawpointer",
]
[[package]] [[package]]
name = "maybe-uninit" name = "maybe-uninit"
version = "2.0.0" version = "2.0.0"
@ -261,12 +278,49 @@ dependencies = [
"adler", "adler",
] ]
[[package]]
name = "nalgebra"
version = "0.31.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e9e0a04ce089f9401aac565c740ed30c46291260f27d4911fdbaa6ca65fa3044"
dependencies = [
"approx",
"matrixmultiply",
"nalgebra-macros",
"num-complex",
"num-rational",
"num-traits",
"rand",
"simba",
"typenum",
]
[[package]]
name = "nalgebra-macros"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "01fcc0b8149b4632adc89ac3b7b31a12fb6099a0317a4eb2ebff574ef7de7218"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]] [[package]]
name = "ndk-sys" name = "ndk-sys"
version = "0.2.2" version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e1bcdd74c20ad5d95aacd60ef9ba40fdf77f767051040541df557b7a9b2a2121" checksum = "e1bcdd74c20ad5d95aacd60ef9ba40fdf77f767051040541df557b7a9b2a2121"
[[package]]
name = "num-complex"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7ae39348c8bc5fbd7f40c727a9925f03517afd2ab27d46702108b6a7e5414c19"
dependencies = [
"num-traits",
]
[[package]] [[package]]
name = "num-integer" name = "num-integer"
version = "0.1.45" version = "0.1.45"
@ -295,6 +349,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd" checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd"
dependencies = [ dependencies = [
"autocfg", "autocfg",
"libm",
] ]
[[package]] [[package]]
@ -321,6 +376,12 @@ version = "1.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e82dad04139b71a90c080c8463fe0dc7902db5192d939bd0950f074d014339e1" checksum = "e82dad04139b71a90c080c8463fe0dc7902db5192d939bd0950f074d014339e1"
[[package]]
name = "paste"
version = "1.0.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b1de2e551fb905ac83f73f7aedf2f0cb4a0da7e35efa24a202a936269f1f18e1"
[[package]] [[package]]
name = "png" name = "png"
version = "0.17.6" version = "0.17.6"
@ -333,6 +394,21 @@ dependencies = [
"miniz_oxide", "miniz_oxide",
] ]
[[package]]
name = "ppv-lite86"
version = "0.2.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872"
[[package]]
name = "proc-macro2"
version = "1.0.46"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94e2ef8dbfc347b10c094890f778ee2e36ca9bb4262e86dc99cd217e35f3470b"
dependencies = [
"unicode-ident",
]
[[package]] [[package]]
name = "quad-alsa-sys" name = "quad-alsa-sys"
version = "0.3.2" version = "0.3.2"
@ -361,6 +437,83 @@ dependencies = [
"winapi", "winapi",
] ]
[[package]]
name = "quote"
version = "1.0.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbe448f377a7d6961e30f5955f9b8d106c3f5e449d493ee1b125c1d43c2b5179"
dependencies = [
"proc-macro2",
]
[[package]]
name = "rand"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
dependencies = [
"libc",
"rand_chacha",
"rand_core",
]
[[package]]
name = "rand_chacha"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
dependencies = [
"ppv-lite86",
"rand_core",
]
[[package]]
name = "rand_core"
version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
dependencies = [
"getrandom",
]
[[package]]
name = "rand_distr"
version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31"
dependencies = [
"num-traits",
"rand",
]
[[package]]
name = "rawpointer"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
[[package]]
name = "safe_arch"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "794821e4ccb0d9f979512f9c1973480123f9bd62a90d74ab0f9426fcf8f4a529"
dependencies = [
"bytemuck",
]
[[package]]
name = "simba"
version = "0.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c48e45e5961033db030b56ad67aef22e9c908c493a6e8348c0a0f6b93433cd77"
dependencies = [
"approx",
"num-complex",
"num-traits",
"paste",
"wide",
]
[[package]] [[package]]
name = "smallvec" name = "smallvec"
version = "0.6.14" version = "0.6.14"
@ -370,12 +523,35 @@ dependencies = [
"maybe-uninit", "maybe-uninit",
] ]
[[package]]
name = "syn"
version = "1.0.102"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3fcd952facd492f9be3ef0d0b7032a6e442ee9b361d4acc2b1d0c4aaa5f613a1"
dependencies = [
"proc-macro2",
"quote",
"unicode-ident",
]
[[package]] [[package]]
name = "ttf-parser" name = "ttf-parser"
version = "0.15.2" version = "0.15.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b3e06c9b9d80ed6b745c7159c40b311ad2916abb34a49e9be2653b90db0d8dd" checksum = "7b3e06c9b9d80ed6b745c7159c40b311ad2916abb34a49e9be2653b90db0d8dd"
[[package]]
name = "typenum"
version = "1.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dcf81ac59edc17cc8697ff311e8f5ef2d99fcbd9817b34cec66f90b6c3dfd987"
[[package]]
name = "unicode-ident"
version = "1.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ceab39d59e4c9499d4e5a8ee0e2735b891bb7308ac83dfb4e80cad195c9f6f3"
[[package]] [[package]]
name = "version_check" name = "version_check"
version = "0.9.4" version = "0.9.4"
@ -388,6 +564,16 @@ version = "0.11.0+wasi-snapshot-preview1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"
[[package]]
name = "wide"
version = "0.7.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ae41ecad2489a1655c8ef8489444b0b113c0a0c795944a3572a0931cf7d2525c"
dependencies = [
"bytemuck",
"safe_arch",
]
[[package]] [[package]]
name = "winapi" name = "winapi"
version = "0.3.9" version = "0.3.9"

View File

@ -7,4 +7,6 @@ edition = "2021"
[dependencies] [dependencies]
macroquad = "0.3.24" macroquad = "0.3.24"
macroquad-particles = "0.1.1" nalgebra = { version = "0.31.1", features = ["rand-no-std"] }
rand = "0.8.5"
rand_distr = "0.4.3"

View File

@ -1,11 +1,11 @@
mod asteroids; mod asteroids;
mod nn;
mod player; mod player;
mod utils; mod utils;
mod world; mod world;
use asteroids::Asteroid;
use macroquad::prelude::*; use macroquad::prelude::*;
use player::Player; use nn::NN;
use world::World; use world::World;
#[macroquad::main("Camera")] #[macroquad::main("Camera")]
@ -17,13 +17,14 @@ async fn main() {
}; };
set_camera(&cam); set_camera(&cam);
let mut world = World::new(); let mut world = World::new();
let nn = NN::new(vec![2, 3, 2]);
nn.feed_forward(vec![2., 3.]);
loop { loop {
clear_background(BLACK); // clear_background(BLACK);
if !world.over { // if !world.over {
world.update(); // world.update();
} // }
world.draw(); // world.draw();
next_frame().await next_frame().await
} }
} }

55
src/nn.rs Normal file
View File

@ -0,0 +1,55 @@
use nalgebra::*;
use rand_distr::StandardNormal;
extern crate rand as r;
enum ActivationFunc {
Sigmoid,
Tanh,
ReLU,
}
pub struct NN {
config: Vec<usize>,
weights: Vec<DMatrix<f32>>,
activ_func: ActivationFunc,
}
impl NN {
// Vec of number of neurons in input, hidden 1, hidden 2, ..., output layers
pub fn new(config: Vec<usize>) -> Self {
let mut rng = r::thread_rng();
Self {
config: config
.iter()
.enumerate()
.map(|(i, &x)| if i != config.len() - 1 { x + 1 } else { x })
.collect(),
// He-et-al Initialization
weights: config
.iter()
.zip(config.iter().skip(1))
.map(|(&curr, &last)| {
DMatrix::<f32>::from_distribution(last, curr + 1, &StandardNormal, &mut rng)
* (2. / last as f32).sqrt()
})
.collect(),
activ_func: ActivationFunc::ReLU,
}
}
pub fn feed_forward(&self, inputs: Vec<f32>) {
let mut y = DMatrix::from_vec(inputs.len(), 1, inputs);
for i in 0..self.config.len() - 1 {
println!("{} {}", y, self.weights[i]);
y = (&self.weights[i] * y.insert_row(self.config[i] - 1, 1.)).map(|x| {
match self.activ_func {
ActivationFunc::ReLU => x.max(0.),
ActivationFunc::Sigmoid => 1. / (1. + (-x).exp()),
ActivationFunc::Tanh => x.tanh(),
}
});
}
println!("{}", y);
}
}

View File

@ -19,8 +19,8 @@ pub struct Player {
impl Player { impl Player {
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
dir: vec2(0., 1.), dir: vec2(0., -1.),
rot: PI / 2., rot: -PI / 2.,
drag: 0.001, drag: 0.001,
shot_interval: 0.3, shot_interval: 0.3,
alive: true, alive: true,

View File

@ -45,7 +45,7 @@ impl World {
)); ));
} }
AsteroidSize::Medium => { AsteroidSize::Medium => {
let rand = vec2(gen_range(-30., 30.), gen_range(-30., 30.)); let rand = vec2(gen_range(-40., 40.), gen_range(-40., 40.));
to_add.push(Asteroid::new_from( to_add.push(Asteroid::new_from(
asteroid.pos, asteroid.pos,
asteroid.vel + rand, asteroid.vel + rand,