From 02df5694f9aa3827436cb0ac42a123a34383c668 Mon Sep 17 00:00:00 2001 From: sparshg <43041139+sparshg@users.noreply.github.com> Date: Fri, 6 Jan 2023 14:40:29 +0530 Subject: [PATCH] import/export models --- Cargo.lock | 137 +++++++++++++++++++++++++++++++--------------- Cargo.toml | 4 +- models/brain.json | 1 + src/nn.rs | 16 +++++- src/population.rs | 1 + src/world.rs | 6 ++ 6 files changed, 116 insertions(+), 49 deletions(-) create mode 100644 models/brain.json diff --git a/Cargo.lock b/Cargo.lock index 8b3499f..0bfc9df 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -60,15 +60,15 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bumpalo" -version = "3.11.0" +version = "3.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1ad822118d20d2c234f427000d5acc36eabe1e29a348c89b63dd60b13f28e5d" +checksum = "572f695136211188308f16ad2ca5c851a712c464060ae6974944458eb83880ba" [[package]] name = "bytemuck" -version = "1.12.1" +version = "1.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f5715e491b5a1598fc2bef5a606847b5dc1d48ea625bd3c02c00de8285591da" +checksum = "aaa3a8d9a1ca92e282c96a32d6511b695d7d994d1d102ba85d279f9b2756947f" [[package]] name = "byteorder" @@ -114,9 +114,9 @@ checksum = "0c87e182de0887fd5361989c677c4e8f5000cd9491d6d563161a8f3a5519fc7f" [[package]] name = "flate2" -version = "1.0.24" +version = "1.0.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f82b0f4c27ad9f8bfd1f3208d882da2b09c301bc1c828fd3a00d0216d2fbbff6" +checksum = "a8a2db397cb1c8772f31494cb8917e48cd1e64f0fa7efac59fbd741a0a8ce841" dependencies = [ "crc32fast", "miniz_oxide", @@ -140,13 +140,15 @@ dependencies = [ "nalgebra", "rand", "rand_distr", + "serde", + "serde_json", ] [[package]] name = "getrandom" -version = "0.2.7" +version = "0.2.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4eb1a864a501629691edf6c15a593b7a51eebaa1e8468e9ddc623de7c9b58ec6" +checksum = "c05aeb6a22b8f62540c194aac980f2115af067bfe15a0734d7277a768d396b31" dependencies = [ "cfg-if", "libc", @@ -176,9 +178,9 @@ checksum = "4d13cdbd5dbb29f9c88095bbdc2590c9cba0d0a1269b983fef6b2cdd7e9f4db1" [[package]] name = "image" -version = "0.24.4" +version = "0.24.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd8e4fb07cf672b1642304e731ef8a6a4c7891d67bb4fd4f5ce58cd6ed86803c" +checksum = "69b7ea949b537b0fd0af141fff8c77690f2ce96f4f41f042ccb6c69c6c965945" dependencies = [ "bytemuck", "byteorder", @@ -188,6 +190,12 @@ dependencies = [ "png", ] +[[package]] +name = "itoa" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fad582f4b9e86b6caa621cabeb0963332d92eea04729ab12892c2533951e6440" + [[package]] name = "lewton" version = "0.9.4" @@ -201,21 +209,21 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.134" +version = "0.2.139" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "329c933548736bc49fd575ee68c89e8be4d260064184389a5b77517cddd99ffb" +checksum = "201de327520df007757c1f0adce6e827fe8562fbc28bfd9c15571c66ca1f5f79" [[package]] name = "libm" -version = "0.2.5" +version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "292a948cd991e376cf75541fe5b97a1081d713c618b4f1b9500f8844e49eb565" +checksum = "348108ab3fba42ec82ff6e9564fc4ca0247bdccdc68dd8af9764bbc79c3c8ffb" [[package]] name = "macroquad" -version = "0.3.24" +version = "0.3.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19c4f0576d6468cb31de5ba0b3c8eb56cfd95ed3edfee380ac339309b4830074" +checksum = "f3790f7fd2e4c480108cbfc86488f023b72e1e0bb6ffd5c6cba38049c7e2fbfc" dependencies = [ "bumpalo", "fontdue", @@ -259,9 +267,9 @@ checksum = "60302e4db3a61da70c0cb7991976248362f30319e88850c487b9b95bbf059e00" [[package]] name = "miniquad" -version = "0.3.13" +version = "0.3.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a07975b18f290b99365f042dd80db3e03908539ca6bc47e749c5eef4ee262fd4" +checksum = "46381fe09fbf91bfa402a3e4fc26a104c9130562d51f89964c46adbc00591496" dependencies = [ "libc", "ndk-sys", @@ -271,18 +279,18 @@ dependencies = [ [[package]] name = "miniz_oxide" -version = "0.5.4" +version = "0.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96590ba8f175222643a85693f33d26e9c8a015f599c216509b1a6894af675d34" +checksum = "b275950c28b37e794e8c55d88aeb5e139d0ce23fdbbeda68f8d7174abdf9e8fa" dependencies = [ "adler", ] [[package]] name = "nalgebra" -version = "0.31.1" +version = "0.31.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e9e0a04ce089f9401aac565c740ed30c46291260f27d4911fdbaa6ca65fa3044" +checksum = "20bd243ab3dbb395b39ee730402d2e5405e448c75133ec49cc977762c4cba3d1" dependencies = [ "approx", "matrixmultiply", @@ -292,6 +300,7 @@ dependencies = [ "num-traits", "rand", "rand_distr", + "serde", "simba", "typenum", ] @@ -320,6 +329,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7ae39348c8bc5fbd7f40c727a9925f03517afd2ab27d46702108b6a7e5414c19" dependencies = [ "num-traits", + "serde", ] [[package]] @@ -373,21 +383,21 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.15.0" +version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e82dad04139b71a90c080c8463fe0dc7902db5192d939bd0950f074d014339e1" +checksum = "6f61fba1741ea2b3d6a1e3178721804bb716a68a6aeba1149b5d52e3d464ea66" [[package]] name = "paste" -version = "1.0.9" +version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1de2e551fb905ac83f73f7aedf2f0cb4a0da7e35efa24a202a936269f1f18e1" +checksum = "d01a5bd0424d00070b0098dd17ebca6f961a959dead1dbcbbbc1d1cd8d3deeba" [[package]] name = "png" -version = "0.17.6" +version = "0.17.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f0e7f4c94ec26ff209cee506314212639d6c91b80afb82984819fafce9df01c" +checksum = "5d708eaf860a19b19ce538740d2b4bdeeb8337fa53f7738455e706623ad5c638" dependencies = [ "bitflags", "crc32fast", @@ -397,15 +407,15 @@ dependencies = [ [[package]] name = "ppv-lite86" -version = "0.2.16" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872" +checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "proc-macro2" -version = "1.0.46" +version = "1.0.49" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94e2ef8dbfc347b10c094890f778ee2e36ca9bb4262e86dc99cd217e35f3470b" +checksum = "57a8eca9f9c4ffde41714334dee777596264c7825420f521abc92b5b5deb63a5" dependencies = [ "unicode-ident", ] @@ -427,9 +437,9 @@ checksum = "658fa1faf7a4cc5f057c9ee5ef560f717ad9d8dc66d975267f709624d6e1ab88" [[package]] name = "quad-snd" -version = "0.2.5" +version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e82e2e4a55292a75d8569ef0b3f7c24964074efe5767b359dbf028a0b3c53464" +checksum = "b3e6bcab480b99b9afec58bcba02151eb9afa0a445ca9ce421b9c11c8dfafe1a" dependencies = [ "audir-sles", "audrey", @@ -440,9 +450,9 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.21" +version = "1.0.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbe448f377a7d6961e30f5955f9b8d106c3f5e449d493ee1b125c1d43c2b5179" +checksum = "8856d8364d252a14d474036ea1358d63c9e6965c8e5c1885c18f73d70bff9c7b" dependencies = [ "proc-macro2", ] @@ -493,6 +503,12 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" +[[package]] +name = "ryu" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b4b9743ed687d4b4bcedf9ff5eaa7398495ae14e61cba0a295704edbc7decde" + [[package]] name = "safe_arch" version = "0.6.0" @@ -503,10 +519,41 @@ dependencies = [ ] [[package]] -name = "simba" -version = "0.7.2" +name = "serde" +version = "1.0.152" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c48e45e5961033db030b56ad67aef22e9c908c493a6e8348c0a0f6b93433cd77" +checksum = "bb7d1f0d3021d347a83e556fc4683dea2ea09d87bccdf88ff5c12545d89d5efb" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.152" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af487d118eecd09402d70a5d72551860e788df87b464af30e5ea6a38c75c541e" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.91" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "877c235533714907a8c2464236f5c4b2a17262ef1bd71f38f35ea592c8da6883" +dependencies = [ + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "simba" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f3fd720c48c53cace224ae62bef1bbff363a70c68c4802a78b5cc6159618176" dependencies = [ "approx", "num-complex", @@ -526,9 +573,9 @@ dependencies = [ [[package]] name = "syn" -version = "1.0.102" +version = "1.0.107" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fcd952facd492f9be3ef0d0b7032a6e442ee9b361d4acc2b1d0c4aaa5f613a1" +checksum = "1f4064b5b16e03ae50984a5a8ed5d4f8803e6bc1fd170a3cda91a1be4b18e3f5" dependencies = [ "proc-macro2", "quote", @@ -543,15 +590,15 @@ checksum = "7b3e06c9b9d80ed6b745c7159c40b311ad2916abb34a49e9be2653b90db0d8dd" [[package]] name = "typenum" -version = "1.15.0" +version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dcf81ac59edc17cc8697ff311e8f5ef2d99fcbd9817b34cec66f90b6c3dfd987" +checksum = "497961ef93d974e23eb6f433eb5fe1b7930b659f06d12dec6fc44a8f554c0bba" [[package]] name = "unicode-ident" -version = "1.0.5" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ceab39d59e4c9499d4e5a8ee0e2735b891bb7308ac83dfb4e80cad195c9f6f3" +checksum = "84a22b9f218b40614adcb3f4ff08b703773ad44fa9423e4e0d346d5db86e4ebc" [[package]] name = "version_check" diff --git a/Cargo.toml b/Cargo.toml index a09a5b9..1c26d7b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,8 @@ edition = "2021" [dependencies] macroquad = "0.3.24" -nalgebra = { version = "0.31.1", features = ["rand-no-std", "rand"] } +nalgebra = { version = "0.31.1", features = ["rand", "serde-serialize"] } rand = "0.8.5" rand_distr = "0.4.3" +serde = { version = "1.0.152", features = ["derive"] } +serde_json = "1.0.91" diff --git a/models/brain.json b/models/brain.json new file mode 100644 index 0000000..7bd34da --- /dev/null +++ b/models/brain.json @@ -0,0 +1 @@ +{"config":[5,9,9,4],"weights":[[[-0.14426763,0.30192375,-0.33760154,0.56687576,0.5898287,-0.074536264,-0.6543682,-0.8446751,0.044597864,0.4060799,0.09555614,0.13340633,0.22922766,0.67137265,0.9608935,0.6129314,-0.24052018,-0.38112265,-0.5038841,-0.12603956,-0.4028914,0.18523669,0.07864928,0.06674278,0.2717425,0.6340071,0.61368537,0.8703921,-0.38861758,0.5550103,-0.26274413,0.5074116,-0.7821254,0.049322072,-0.2819178,-0.54028326,-0.4513356,-0.5652218,0.50929916,-0.42176604],8,5],[[-0.8103107,0.8446312,-0.40203577,0.54887617,-0.44233263,0.31166792,-0.33348525,0.40876722,0.26672268,-0.712453,-0.21693909,-0.35327882,-0.7854557,0.43809295,0.59911966,-0.071552694,0.3788359,-0.4907685,0.7868428,0.63808143,-0.50622714,0.08628023,-0.8824939,0.24896917,0.63522625,-0.50140214,0.9587381,0.5064759,0.040097475,0.24041378,-0.12401252,0.10650039,-0.8819831,0.29062068,-0.26787168,0.45351043,-0.2870677,-0.24404618,0.8434694,0.30426964,0.31458378,-0.8161984,-0.12195361,0.8177855,-0.57765794,0.89029014,0.75471187,0.5454526,-0.2778288,0.5250077,-0.71220773,0.58331454,0.368407,-1.0235562,-0.057651043,0.6541481,0.49730933,0.23280966,0.9288963,0.42811918,-0.45282865,0.22473057,0.9463216,0.5010747,-0.30762756,-0.5731837,0.8379146,0.06879502,-0.23158616,-0.6971844,0.4713143,0.1294421],8,9],[[0.45169508,-0.31192523,-0.74151075,-0.40177822,0.9622052,0.47277915,-0.76076436,0.00037801266,-0.36617398,0.89722514,0.89538515,0.99763775,1.3664851,-0.09313524,-0.5705385,0.059523106,-0.4970879,0.06191349,0.61592185,0.63810134,-0.97980994,-0.673075,0.025918722,0.5754194,1.2249401,-0.1614248,0.8438246,-0.6717344,0.83650386,-2.2065213,-0.051995337,0.19302869,-0.9978042,0.5303961,-1.4841839,0.118329406],4,9]],"activ_func":"ReLU","mut_rate":0.04} \ No newline at end of file diff --git a/src/nn.rs b/src/nn.rs index 37f0e00..a901bc9 100644 --- a/src/nn.rs +++ b/src/nn.rs @@ -2,9 +2,10 @@ use macroquad::rand::gen_range; use nalgebra::*; use r::Rng; use rand_distr::StandardNormal; +use serde::{Deserialize, Serialize}; extern crate rand as r; -#[derive(PartialEq, Debug, Clone, Copy, Default)] +#[derive(PartialEq, Debug, Clone, Copy, Default, Serialize, Deserialize)] enum ActivationFunc { Sigmoid, @@ -13,10 +14,10 @@ enum ActivationFunc { ReLU, } -#[derive(Clone, Debug, Default)] +#[derive(Clone, Debug, Default, Serialize, Deserialize)] pub struct NN { pub config: Vec, - pub weights: Vec>, + weights: Vec>, activ_func: ActivationFunc, mut_rate: f32, } @@ -93,4 +94,13 @@ impl NN { } y.column(0).data.into_slice().to_vec() } + + pub fn export(&self) -> String { + serde_json::to_string(self).unwrap() + } + + pub fn import() -> NN { + let json = std::fs::read_to_string("models/brain.json").expect("Unable to read file"); + serde_json::from_str(&json).unwrap() + } } diff --git a/src/population.rs b/src/population.rs index f5480a2..2296a31 100644 --- a/src/population.rs +++ b/src/population.rs @@ -67,6 +67,7 @@ impl Population { .map(|i| World::simulate(Some(self.worlds[i].see_brain().to_owned()))) .collect::>(); new_worlds[0].set_best(); + // new_worlds[0].export_brian(); while new_worlds.len() < self.size { let rands = (gen_range(0., total), gen_range(0., total)); let mut sum = 0.; diff --git a/src/world.rs b/src/world.rs index 1c25858..09ac4f6 100644 --- a/src/world.rs +++ b/src/world.rs @@ -38,6 +38,12 @@ impl World { self.player.brain.as_ref().unwrap() } + pub fn export_brian(&self) { + let json = self.player.brain.as_ref().unwrap().export(); + std::fs::create_dir_all("models").expect("Unable to create directory"); + std::fs::write("models/brain.json", json).expect("Unable to write file"); + } + pub fn update(&mut self) { let mut to_add: Vec = Vec::new(); for asteroid in &mut self.asteroids {