neural network, feed forward
This commit is contained in:
parent
3d633c435d
commit
80fb255c36
|
@ -19,6 +19,15 @@ dependencies = [
|
|||
"version_check",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "approx"
|
||||
version = "0.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6"
|
||||
dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "audir-sles"
|
||||
version = "0.1.0"
|
||||
|
@ -128,7 +137,9 @@ name = "genetic"
|
|||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"macroquad",
|
||||
"macroquad-particles",
|
||||
"nalgebra",
|
||||
"rand",
|
||||
"rand_distr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -194,6 +205,12 @@ version = "0.2.134"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "329c933548736bc49fd575ee68c89e8be4d260064184389a5b77517cddd99ffb"
|
||||
|
||||
[[package]]
|
||||
name = "libm"
|
||||
version = "0.2.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "292a948cd991e376cf75541fe5b97a1081d713c618b4f1b9500f8844e49eb565"
|
||||
|
||||
[[package]]
|
||||
name = "macroquad"
|
||||
version = "0.3.24"
|
||||
|
@ -210,15 +227,6 @@ dependencies = [
|
|||
"quad-snd",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "macroquad-particles"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8505dc035e4d68799a53a459afef0bc345410039325b5928af2e0d83eb15de40"
|
||||
dependencies = [
|
||||
"macroquad",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "macroquad_macro"
|
||||
version = "0.1.7"
|
||||
|
@ -234,6 +242,15 @@ dependencies = [
|
|||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "matrixmultiply"
|
||||
version = "0.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "add85d4dd35074e6fedc608f8c8f513a3548619a9024b751949ef0e8e45a4d84"
|
||||
dependencies = [
|
||||
"rawpointer",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "maybe-uninit"
|
||||
version = "2.0.0"
|
||||
|
@ -261,12 +278,49 @@ dependencies = [
|
|||
"adler",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nalgebra"
|
||||
version = "0.31.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e9e0a04ce089f9401aac565c740ed30c46291260f27d4911fdbaa6ca65fa3044"
|
||||
dependencies = [
|
||||
"approx",
|
||||
"matrixmultiply",
|
||||
"nalgebra-macros",
|
||||
"num-complex",
|
||||
"num-rational",
|
||||
"num-traits",
|
||||
"rand",
|
||||
"simba",
|
||||
"typenum",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nalgebra-macros"
|
||||
version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "01fcc0b8149b4632adc89ac3b7b31a12fb6099a0317a4eb2ebff574ef7de7218"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ndk-sys"
|
||||
version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e1bcdd74c20ad5d95aacd60ef9ba40fdf77f767051040541df557b7a9b2a2121"
|
||||
|
||||
[[package]]
|
||||
name = "num-complex"
|
||||
version = "0.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7ae39348c8bc5fbd7f40c727a9925f03517afd2ab27d46702108b6a7e5414c19"
|
||||
dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-integer"
|
||||
version = "0.1.45"
|
||||
|
@ -295,6 +349,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"libm",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -321,6 +376,12 @@ version = "1.15.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e82dad04139b71a90c080c8463fe0dc7902db5192d939bd0950f074d014339e1"
|
||||
|
||||
[[package]]
|
||||
name = "paste"
|
||||
version = "1.0.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b1de2e551fb905ac83f73f7aedf2f0cb4a0da7e35efa24a202a936269f1f18e1"
|
||||
|
||||
[[package]]
|
||||
name = "png"
|
||||
version = "0.17.6"
|
||||
|
@ -333,6 +394,21 @@ dependencies = [
|
|||
"miniz_oxide",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ppv-lite86"
|
||||
version = "0.2.16"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872"
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro2"
|
||||
version = "1.0.46"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "94e2ef8dbfc347b10c094890f778ee2e36ca9bb4262e86dc99cd217e35f3470b"
|
||||
dependencies = [
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quad-alsa-sys"
|
||||
version = "0.3.2"
|
||||
|
@ -361,6 +437,83 @@ dependencies = [
|
|||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quote"
|
||||
version = "1.0.21"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bbe448f377a7d6961e30f5955f9b8d106c3f5e449d493ee1b125c1d43c2b5179"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand"
|
||||
version = "0.8.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"rand_chacha",
|
||||
"rand_core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_chacha"
|
||||
version = "0.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
|
||||
dependencies = [
|
||||
"ppv-lite86",
|
||||
"rand_core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_core"
|
||||
version = "0.6.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
|
||||
dependencies = [
|
||||
"getrandom",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_distr"
|
||||
version = "0.4.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31"
|
||||
dependencies = [
|
||||
"num-traits",
|
||||
"rand",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rawpointer"
|
||||
version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
|
||||
|
||||
[[package]]
|
||||
name = "safe_arch"
|
||||
version = "0.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "794821e4ccb0d9f979512f9c1973480123f9bd62a90d74ab0f9426fcf8f4a529"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "simba"
|
||||
version = "0.7.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c48e45e5961033db030b56ad67aef22e9c908c493a6e8348c0a0f6b93433cd77"
|
||||
dependencies = [
|
||||
"approx",
|
||||
"num-complex",
|
||||
"num-traits",
|
||||
"paste",
|
||||
"wide",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "smallvec"
|
||||
version = "0.6.14"
|
||||
|
@ -370,12 +523,35 @@ dependencies = [
|
|||
"maybe-uninit",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "1.0.102"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3fcd952facd492f9be3ef0d0b7032a6e442ee9b361d4acc2b1d0c4aaa5f613a1"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ttf-parser"
|
||||
version = "0.15.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7b3e06c9b9d80ed6b745c7159c40b311ad2916abb34a49e9be2653b90db0d8dd"
|
||||
|
||||
[[package]]
|
||||
name = "typenum"
|
||||
version = "1.15.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dcf81ac59edc17cc8697ff311e8f5ef2d99fcbd9817b34cec66f90b6c3dfd987"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-ident"
|
||||
version = "1.0.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6ceab39d59e4c9499d4e5a8ee0e2735b891bb7308ac83dfb4e80cad195c9f6f3"
|
||||
|
||||
[[package]]
|
||||
name = "version_check"
|
||||
version = "0.9.4"
|
||||
|
@ -388,6 +564,16 @@ version = "0.11.0+wasi-snapshot-preview1"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"
|
||||
|
||||
[[package]]
|
||||
name = "wide"
|
||||
version = "0.7.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ae41ecad2489a1655c8ef8489444b0b113c0a0c795944a3572a0931cf7d2525c"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"safe_arch",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "winapi"
|
||||
version = "0.3.9"
|
||||
|
|
|
@ -7,4 +7,6 @@ edition = "2021"
|
|||
|
||||
[dependencies]
|
||||
macroquad = "0.3.24"
|
||||
macroquad-particles = "0.1.1"
|
||||
nalgebra = { version = "0.31.1", features = ["rand-no-std"] }
|
||||
rand = "0.8.5"
|
||||
rand_distr = "0.4.3"
|
||||
|
|
17
src/main.rs
17
src/main.rs
|
@ -1,11 +1,11 @@
|
|||
mod asteroids;
|
||||
mod nn;
|
||||
mod player;
|
||||
mod utils;
|
||||
mod world;
|
||||
|
||||
use asteroids::Asteroid;
|
||||
use macroquad::prelude::*;
|
||||
use player::Player;
|
||||
use nn::NN;
|
||||
use world::World;
|
||||
|
||||
#[macroquad::main("Camera")]
|
||||
|
@ -17,13 +17,14 @@ async fn main() {
|
|||
};
|
||||
set_camera(&cam);
|
||||
let mut world = World::new();
|
||||
|
||||
let nn = NN::new(vec![2, 3, 2]);
|
||||
nn.feed_forward(vec![2., 3.]);
|
||||
loop {
|
||||
clear_background(BLACK);
|
||||
if !world.over {
|
||||
world.update();
|
||||
}
|
||||
world.draw();
|
||||
// clear_background(BLACK);
|
||||
// if !world.over {
|
||||
// world.update();
|
||||
// }
|
||||
// world.draw();
|
||||
next_frame().await
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,55 @@
|
|||
use nalgebra::*;
|
||||
use rand_distr::StandardNormal;
|
||||
extern crate rand as r;
|
||||
|
||||
enum ActivationFunc {
|
||||
Sigmoid,
|
||||
Tanh,
|
||||
ReLU,
|
||||
}
|
||||
pub struct NN {
|
||||
config: Vec<usize>,
|
||||
weights: Vec<DMatrix<f32>>,
|
||||
activ_func: ActivationFunc,
|
||||
}
|
||||
|
||||
impl NN {
|
||||
// Vec of number of neurons in input, hidden 1, hidden 2, ..., output layers
|
||||
pub fn new(config: Vec<usize>) -> Self {
|
||||
let mut rng = r::thread_rng();
|
||||
|
||||
Self {
|
||||
config: config
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, &x)| if i != config.len() - 1 { x + 1 } else { x })
|
||||
.collect(),
|
||||
|
||||
// He-et-al Initialization
|
||||
weights: config
|
||||
.iter()
|
||||
.zip(config.iter().skip(1))
|
||||
.map(|(&curr, &last)| {
|
||||
DMatrix::<f32>::from_distribution(last, curr + 1, &StandardNormal, &mut rng)
|
||||
* (2. / last as f32).sqrt()
|
||||
})
|
||||
.collect(),
|
||||
activ_func: ActivationFunc::ReLU,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn feed_forward(&self, inputs: Vec<f32>) {
|
||||
let mut y = DMatrix::from_vec(inputs.len(), 1, inputs);
|
||||
for i in 0..self.config.len() - 1 {
|
||||
println!("{} {}", y, self.weights[i]);
|
||||
y = (&self.weights[i] * y.insert_row(self.config[i] - 1, 1.)).map(|x| {
|
||||
match self.activ_func {
|
||||
ActivationFunc::ReLU => x.max(0.),
|
||||
ActivationFunc::Sigmoid => 1. / (1. + (-x).exp()),
|
||||
ActivationFunc::Tanh => x.tanh(),
|
||||
}
|
||||
});
|
||||
}
|
||||
println!("{}", y);
|
||||
}
|
||||
}
|
|
@ -19,8 +19,8 @@ pub struct Player {
|
|||
impl Player {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
dir: vec2(0., 1.),
|
||||
rot: PI / 2.,
|
||||
dir: vec2(0., -1.),
|
||||
rot: -PI / 2.,
|
||||
drag: 0.001,
|
||||
shot_interval: 0.3,
|
||||
alive: true,
|
||||
|
|
|
@ -45,7 +45,7 @@ impl World {
|
|||
));
|
||||
}
|
||||
AsteroidSize::Medium => {
|
||||
let rand = vec2(gen_range(-30., 30.), gen_range(-30., 30.));
|
||||
let rand = vec2(gen_range(-40., 40.), gen_range(-40., 40.));
|
||||
to_add.push(Asteroid::new_from(
|
||||
asteroid.pos,
|
||||
asteroid.vel + rand,
|
||||
|
|
Loading…
Reference in New Issue