From 80fb255c360a5dfb468e585d6b13593e58b2a905 Mon Sep 17 00:00:00 2001 From: sparshg <43041139+sparshg@users.noreply.github.com> Date: Sun, 9 Oct 2022 23:44:22 +0530 Subject: [PATCH] neural network, feed forward --- Cargo.lock | 206 +++++++++++++++++++++++++++++++++++++++++++++++--- Cargo.toml | 4 +- src/main.rs | 17 +++-- src/nn.rs | 55 ++++++++++++++ src/player.rs | 4 +- src/world.rs | 2 +- 6 files changed, 266 insertions(+), 22 deletions(-) create mode 100644 src/nn.rs diff --git a/Cargo.lock b/Cargo.lock index 3aea678..556a31d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -19,6 +19,15 @@ dependencies = [ "version_check", ] +[[package]] +name = "approx" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6" +dependencies = [ + "num-traits", +] + [[package]] name = "audir-sles" version = "0.1.0" @@ -128,7 +137,9 @@ name = "genetic" version = "0.1.0" dependencies = [ "macroquad", - "macroquad-particles", + "nalgebra", + "rand", + "rand_distr", ] [[package]] @@ -194,6 +205,12 @@ version = "0.2.134" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "329c933548736bc49fd575ee68c89e8be4d260064184389a5b77517cddd99ffb" +[[package]] +name = "libm" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "292a948cd991e376cf75541fe5b97a1081d713c618b4f1b9500f8844e49eb565" + [[package]] name = "macroquad" version = "0.3.24" @@ -210,15 +227,6 @@ dependencies = [ "quad-snd", ] -[[package]] -name = "macroquad-particles" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8505dc035e4d68799a53a459afef0bc345410039325b5928af2e0d83eb15de40" -dependencies = [ - "macroquad", -] - [[package]] name = "macroquad_macro" version = "0.1.7" @@ -234,6 +242,15 @@ dependencies = [ "libc", ] +[[package]] +name = "matrixmultiply" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "add85d4dd35074e6fedc608f8c8f513a3548619a9024b751949ef0e8e45a4d84" +dependencies = [ + "rawpointer", +] + [[package]] name = "maybe-uninit" version = "2.0.0" @@ -261,12 +278,49 @@ dependencies = [ "adler", ] +[[package]] +name = "nalgebra" +version = "0.31.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e9e0a04ce089f9401aac565c740ed30c46291260f27d4911fdbaa6ca65fa3044" +dependencies = [ + "approx", + "matrixmultiply", + "nalgebra-macros", + "num-complex", + "num-rational", + "num-traits", + "rand", + "simba", + "typenum", +] + +[[package]] +name = "nalgebra-macros" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01fcc0b8149b4632adc89ac3b7b31a12fb6099a0317a4eb2ebff574ef7de7218" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "ndk-sys" version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e1bcdd74c20ad5d95aacd60ef9ba40fdf77f767051040541df557b7a9b2a2121" +[[package]] +name = "num-complex" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ae39348c8bc5fbd7f40c727a9925f03517afd2ab27d46702108b6a7e5414c19" +dependencies = [ + "num-traits", +] + [[package]] name = "num-integer" version = "0.1.45" @@ -295,6 +349,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd" dependencies = [ "autocfg", + "libm", ] [[package]] @@ -321,6 +376,12 @@ version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e82dad04139b71a90c080c8463fe0dc7902db5192d939bd0950f074d014339e1" +[[package]] +name = "paste" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1de2e551fb905ac83f73f7aedf2f0cb4a0da7e35efa24a202a936269f1f18e1" + [[package]] name = "png" version = "0.17.6" @@ -333,6 +394,21 @@ dependencies = [ "miniz_oxide", ] +[[package]] +name = "ppv-lite86" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872" + +[[package]] +name = "proc-macro2" +version = "1.0.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94e2ef8dbfc347b10c094890f778ee2e36ca9bb4262e86dc99cd217e35f3470b" +dependencies = [ + "unicode-ident", +] + [[package]] name = "quad-alsa-sys" version = "0.3.2" @@ -361,6 +437,83 @@ dependencies = [ "winapi", ] +[[package]] +name = "quote" +version = "1.0.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbe448f377a7d6961e30f5955f9b8d106c3f5e449d493ee1b125c1d43c2b5179" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + +[[package]] +name = "rand_distr" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31" +dependencies = [ + "num-traits", + "rand", +] + +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + +[[package]] +name = "safe_arch" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "794821e4ccb0d9f979512f9c1973480123f9bd62a90d74ab0f9426fcf8f4a529" +dependencies = [ + "bytemuck", +] + +[[package]] +name = "simba" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c48e45e5961033db030b56ad67aef22e9c908c493a6e8348c0a0f6b93433cd77" +dependencies = [ + "approx", + "num-complex", + "num-traits", + "paste", + "wide", +] + [[package]] name = "smallvec" version = "0.6.14" @@ -370,12 +523,35 @@ dependencies = [ "maybe-uninit", ] +[[package]] +name = "syn" +version = "1.0.102" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fcd952facd492f9be3ef0d0b7032a6e442ee9b361d4acc2b1d0c4aaa5f613a1" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + [[package]] name = "ttf-parser" version = "0.15.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7b3e06c9b9d80ed6b745c7159c40b311ad2916abb34a49e9be2653b90db0d8dd" +[[package]] +name = "typenum" +version = "1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dcf81ac59edc17cc8697ff311e8f5ef2d99fcbd9817b34cec66f90b6c3dfd987" + +[[package]] +name = "unicode-ident" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ceab39d59e4c9499d4e5a8ee0e2735b891bb7308ac83dfb4e80cad195c9f6f3" + [[package]] name = "version_check" version = "0.9.4" @@ -388,6 +564,16 @@ version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +[[package]] +name = "wide" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae41ecad2489a1655c8ef8489444b0b113c0a0c795944a3572a0931cf7d2525c" +dependencies = [ + "bytemuck", + "safe_arch", +] + [[package]] name = "winapi" version = "0.3.9" diff --git a/Cargo.toml b/Cargo.toml index 0e0df32..6632eba 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,4 +7,6 @@ edition = "2021" [dependencies] macroquad = "0.3.24" -macroquad-particles = "0.1.1" +nalgebra = { version = "0.31.1", features = ["rand-no-std"] } +rand = "0.8.5" +rand_distr = "0.4.3" diff --git a/src/main.rs b/src/main.rs index f492157..f15b63b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,11 +1,11 @@ mod asteroids; +mod nn; mod player; mod utils; mod world; -use asteroids::Asteroid; use macroquad::prelude::*; -use player::Player; +use nn::NN; use world::World; #[macroquad::main("Camera")] @@ -17,13 +17,14 @@ async fn main() { }; set_camera(&cam); let mut world = World::new(); - + let nn = NN::new(vec![2, 3, 2]); + nn.feed_forward(vec![2., 3.]); loop { - clear_background(BLACK); - if !world.over { - world.update(); - } - world.draw(); + // clear_background(BLACK); + // if !world.over { + // world.update(); + // } + // world.draw(); next_frame().await } } diff --git a/src/nn.rs b/src/nn.rs new file mode 100644 index 0000000..2be989b --- /dev/null +++ b/src/nn.rs @@ -0,0 +1,55 @@ +use nalgebra::*; +use rand_distr::StandardNormal; +extern crate rand as r; + +enum ActivationFunc { + Sigmoid, + Tanh, + ReLU, +} +pub struct NN { + config: Vec, + weights: Vec>, + activ_func: ActivationFunc, +} + +impl NN { + // Vec of number of neurons in input, hidden 1, hidden 2, ..., output layers + pub fn new(config: Vec) -> Self { + let mut rng = r::thread_rng(); + + Self { + config: config + .iter() + .enumerate() + .map(|(i, &x)| if i != config.len() - 1 { x + 1 } else { x }) + .collect(), + + // He-et-al Initialization + weights: config + .iter() + .zip(config.iter().skip(1)) + .map(|(&curr, &last)| { + DMatrix::::from_distribution(last, curr + 1, &StandardNormal, &mut rng) + * (2. / last as f32).sqrt() + }) + .collect(), + activ_func: ActivationFunc::ReLU, + } + } + + pub fn feed_forward(&self, inputs: Vec) { + let mut y = DMatrix::from_vec(inputs.len(), 1, inputs); + for i in 0..self.config.len() - 1 { + println!("{} {}", y, self.weights[i]); + y = (&self.weights[i] * y.insert_row(self.config[i] - 1, 1.)).map(|x| { + match self.activ_func { + ActivationFunc::ReLU => x.max(0.), + ActivationFunc::Sigmoid => 1. / (1. + (-x).exp()), + ActivationFunc::Tanh => x.tanh(), + } + }); + } + println!("{}", y); + } +} diff --git a/src/player.rs b/src/player.rs index d67c816..3e994c4 100644 --- a/src/player.rs +++ b/src/player.rs @@ -19,8 +19,8 @@ pub struct Player { impl Player { pub fn new() -> Self { Self { - dir: vec2(0., 1.), - rot: PI / 2., + dir: vec2(0., -1.), + rot: -PI / 2., drag: 0.001, shot_interval: 0.3, alive: true, diff --git a/src/world.rs b/src/world.rs index 5f510ef..9c08c13 100644 --- a/src/world.rs +++ b/src/world.rs @@ -45,7 +45,7 @@ impl World { )); } AsteroidSize::Medium => { - let rand = vec2(gen_range(-30., 30.), gen_range(-30., 30.)); + let rand = vec2(gen_range(-40., 40.), gen_range(-40., 40.)); to_add.push(Asteroid::new_from( asteroid.pos, asteroid.vel + rand,