import/export models

This commit is contained in:
sparshg 2023-01-06 14:40:29 +05:30
parent b9cfedc00d
commit 02df5694f9
6 changed files with 116 additions and 49 deletions

137
Cargo.lock generated
View File

@ -60,15 +60,15 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
[[package]] [[package]]
name = "bumpalo" name = "bumpalo"
version = "3.11.0" version = "3.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c1ad822118d20d2c234f427000d5acc36eabe1e29a348c89b63dd60b13f28e5d" checksum = "572f695136211188308f16ad2ca5c851a712c464060ae6974944458eb83880ba"
[[package]] [[package]]
name = "bytemuck" name = "bytemuck"
version = "1.12.1" version = "1.12.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2f5715e491b5a1598fc2bef5a606847b5dc1d48ea625bd3c02c00de8285591da" checksum = "aaa3a8d9a1ca92e282c96a32d6511b695d7d994d1d102ba85d279f9b2756947f"
[[package]] [[package]]
name = "byteorder" name = "byteorder"
@ -114,9 +114,9 @@ checksum = "0c87e182de0887fd5361989c677c4e8f5000cd9491d6d563161a8f3a5519fc7f"
[[package]] [[package]]
name = "flate2" name = "flate2"
version = "1.0.24" version = "1.0.25"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f82b0f4c27ad9f8bfd1f3208d882da2b09c301bc1c828fd3a00d0216d2fbbff6" checksum = "a8a2db397cb1c8772f31494cb8917e48cd1e64f0fa7efac59fbd741a0a8ce841"
dependencies = [ dependencies = [
"crc32fast", "crc32fast",
"miniz_oxide", "miniz_oxide",
@ -140,13 +140,15 @@ dependencies = [
"nalgebra", "nalgebra",
"rand", "rand",
"rand_distr", "rand_distr",
"serde",
"serde_json",
] ]
[[package]] [[package]]
name = "getrandom" name = "getrandom"
version = "0.2.7" version = "0.2.8"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4eb1a864a501629691edf6c15a593b7a51eebaa1e8468e9ddc623de7c9b58ec6" checksum = "c05aeb6a22b8f62540c194aac980f2115af067bfe15a0734d7277a768d396b31"
dependencies = [ dependencies = [
"cfg-if", "cfg-if",
"libc", "libc",
@ -176,9 +178,9 @@ checksum = "4d13cdbd5dbb29f9c88095bbdc2590c9cba0d0a1269b983fef6b2cdd7e9f4db1"
[[package]] [[package]]
name = "image" name = "image"
version = "0.24.4" version = "0.24.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bd8e4fb07cf672b1642304e731ef8a6a4c7891d67bb4fd4f5ce58cd6ed86803c" checksum = "69b7ea949b537b0fd0af141fff8c77690f2ce96f4f41f042ccb6c69c6c965945"
dependencies = [ dependencies = [
"bytemuck", "bytemuck",
"byteorder", "byteorder",
@ -188,6 +190,12 @@ dependencies = [
"png", "png",
] ]
[[package]]
name = "itoa"
version = "1.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fad582f4b9e86b6caa621cabeb0963332d92eea04729ab12892c2533951e6440"
[[package]] [[package]]
name = "lewton" name = "lewton"
version = "0.9.4" version = "0.9.4"
@ -201,21 +209,21 @@ dependencies = [
[[package]] [[package]]
name = "libc" name = "libc"
version = "0.2.134" version = "0.2.139"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "329c933548736bc49fd575ee68c89e8be4d260064184389a5b77517cddd99ffb" checksum = "201de327520df007757c1f0adce6e827fe8562fbc28bfd9c15571c66ca1f5f79"
[[package]] [[package]]
name = "libm" name = "libm"
version = "0.2.5" version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "292a948cd991e376cf75541fe5b97a1081d713c618b4f1b9500f8844e49eb565" checksum = "348108ab3fba42ec82ff6e9564fc4ca0247bdccdc68dd8af9764bbc79c3c8ffb"
[[package]] [[package]]
name = "macroquad" name = "macroquad"
version = "0.3.24" version = "0.3.25"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "19c4f0576d6468cb31de5ba0b3c8eb56cfd95ed3edfee380ac339309b4830074" checksum = "f3790f7fd2e4c480108cbfc86488f023b72e1e0bb6ffd5c6cba38049c7e2fbfc"
dependencies = [ dependencies = [
"bumpalo", "bumpalo",
"fontdue", "fontdue",
@ -259,9 +267,9 @@ checksum = "60302e4db3a61da70c0cb7991976248362f30319e88850c487b9b95bbf059e00"
[[package]] [[package]]
name = "miniquad" name = "miniquad"
version = "0.3.13" version = "0.3.14"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a07975b18f290b99365f042dd80db3e03908539ca6bc47e749c5eef4ee262fd4" checksum = "46381fe09fbf91bfa402a3e4fc26a104c9130562d51f89964c46adbc00591496"
dependencies = [ dependencies = [
"libc", "libc",
"ndk-sys", "ndk-sys",
@ -271,18 +279,18 @@ dependencies = [
[[package]] [[package]]
name = "miniz_oxide" name = "miniz_oxide"
version = "0.5.4" version = "0.6.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "96590ba8f175222643a85693f33d26e9c8a015f599c216509b1a6894af675d34" checksum = "b275950c28b37e794e8c55d88aeb5e139d0ce23fdbbeda68f8d7174abdf9e8fa"
dependencies = [ dependencies = [
"adler", "adler",
] ]
[[package]] [[package]]
name = "nalgebra" name = "nalgebra"
version = "0.31.1" version = "0.31.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e9e0a04ce089f9401aac565c740ed30c46291260f27d4911fdbaa6ca65fa3044" checksum = "20bd243ab3dbb395b39ee730402d2e5405e448c75133ec49cc977762c4cba3d1"
dependencies = [ dependencies = [
"approx", "approx",
"matrixmultiply", "matrixmultiply",
@ -292,6 +300,7 @@ dependencies = [
"num-traits", "num-traits",
"rand", "rand",
"rand_distr", "rand_distr",
"serde",
"simba", "simba",
"typenum", "typenum",
] ]
@ -320,6 +329,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7ae39348c8bc5fbd7f40c727a9925f03517afd2ab27d46702108b6a7e5414c19" checksum = "7ae39348c8bc5fbd7f40c727a9925f03517afd2ab27d46702108b6a7e5414c19"
dependencies = [ dependencies = [
"num-traits", "num-traits",
"serde",
] ]
[[package]] [[package]]
@ -373,21 +383,21 @@ dependencies = [
[[package]] [[package]]
name = "once_cell" name = "once_cell"
version = "1.15.0" version = "1.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e82dad04139b71a90c080c8463fe0dc7902db5192d939bd0950f074d014339e1" checksum = "6f61fba1741ea2b3d6a1e3178721804bb716a68a6aeba1149b5d52e3d464ea66"
[[package]] [[package]]
name = "paste" name = "paste"
version = "1.0.9" version = "1.0.11"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b1de2e551fb905ac83f73f7aedf2f0cb4a0da7e35efa24a202a936269f1f18e1" checksum = "d01a5bd0424d00070b0098dd17ebca6f961a959dead1dbcbbbc1d1cd8d3deeba"
[[package]] [[package]]
name = "png" name = "png"
version = "0.17.6" version = "0.17.7"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f0e7f4c94ec26ff209cee506314212639d6c91b80afb82984819fafce9df01c" checksum = "5d708eaf860a19b19ce538740d2b4bdeeb8337fa53f7738455e706623ad5c638"
dependencies = [ dependencies = [
"bitflags", "bitflags",
"crc32fast", "crc32fast",
@ -397,15 +407,15 @@ dependencies = [
[[package]] [[package]]
name = "ppv-lite86" name = "ppv-lite86"
version = "0.2.16" version = "0.2.17"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872" checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"
[[package]] [[package]]
name = "proc-macro2" name = "proc-macro2"
version = "1.0.46" version = "1.0.49"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94e2ef8dbfc347b10c094890f778ee2e36ca9bb4262e86dc99cd217e35f3470b" checksum = "57a8eca9f9c4ffde41714334dee777596264c7825420f521abc92b5b5deb63a5"
dependencies = [ dependencies = [
"unicode-ident", "unicode-ident",
] ]
@ -427,9 +437,9 @@ checksum = "658fa1faf7a4cc5f057c9ee5ef560f717ad9d8dc66d975267f709624d6e1ab88"
[[package]] [[package]]
name = "quad-snd" name = "quad-snd"
version = "0.2.5" version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e82e2e4a55292a75d8569ef0b3f7c24964074efe5767b359dbf028a0b3c53464" checksum = "b3e6bcab480b99b9afec58bcba02151eb9afa0a445ca9ce421b9c11c8dfafe1a"
dependencies = [ dependencies = [
"audir-sles", "audir-sles",
"audrey", "audrey",
@ -440,9 +450,9 @@ dependencies = [
[[package]] [[package]]
name = "quote" name = "quote"
version = "1.0.21" version = "1.0.23"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbe448f377a7d6961e30f5955f9b8d106c3f5e449d493ee1b125c1d43c2b5179" checksum = "8856d8364d252a14d474036ea1358d63c9e6965c8e5c1885c18f73d70bff9c7b"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
] ]
@ -493,6 +503,12 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
[[package]]
name = "ryu"
version = "1.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b4b9743ed687d4b4bcedf9ff5eaa7398495ae14e61cba0a295704edbc7decde"
[[package]] [[package]]
name = "safe_arch" name = "safe_arch"
version = "0.6.0" version = "0.6.0"
@ -503,10 +519,41 @@ dependencies = [
] ]
[[package]] [[package]]
name = "simba" name = "serde"
version = "0.7.2" version = "1.0.152"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c48e45e5961033db030b56ad67aef22e9c908c493a6e8348c0a0f6b93433cd77" checksum = "bb7d1f0d3021d347a83e556fc4683dea2ea09d87bccdf88ff5c12545d89d5efb"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
version = "1.0.152"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "af487d118eecd09402d70a5d72551860e788df87b464af30e5ea6a38c75c541e"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "serde_json"
version = "1.0.91"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "877c235533714907a8c2464236f5c4b2a17262ef1bd71f38f35ea592c8da6883"
dependencies = [
"itoa",
"ryu",
"serde",
]
[[package]]
name = "simba"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2f3fd720c48c53cace224ae62bef1bbff363a70c68c4802a78b5cc6159618176"
dependencies = [ dependencies = [
"approx", "approx",
"num-complex", "num-complex",
@ -526,9 +573,9 @@ dependencies = [
[[package]] [[package]]
name = "syn" name = "syn"
version = "1.0.102" version = "1.0.107"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3fcd952facd492f9be3ef0d0b7032a6e442ee9b361d4acc2b1d0c4aaa5f613a1" checksum = "1f4064b5b16e03ae50984a5a8ed5d4f8803e6bc1fd170a3cda91a1be4b18e3f5"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
@ -543,15 +590,15 @@ checksum = "7b3e06c9b9d80ed6b745c7159c40b311ad2916abb34a49e9be2653b90db0d8dd"
[[package]] [[package]]
name = "typenum" name = "typenum"
version = "1.15.0" version = "1.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dcf81ac59edc17cc8697ff311e8f5ef2d99fcbd9817b34cec66f90b6c3dfd987" checksum = "497961ef93d974e23eb6f433eb5fe1b7930b659f06d12dec6fc44a8f554c0bba"
[[package]] [[package]]
name = "unicode-ident" name = "unicode-ident"
version = "1.0.5" version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ceab39d59e4c9499d4e5a8ee0e2735b891bb7308ac83dfb4e80cad195c9f6f3" checksum = "84a22b9f218b40614adcb3f4ff08b703773ad44fa9423e4e0d346d5db86e4ebc"
[[package]] [[package]]
name = "version_check" name = "version_check"

View File

@ -7,6 +7,8 @@ edition = "2021"
[dependencies] [dependencies]
macroquad = "0.3.24" macroquad = "0.3.24"
nalgebra = { version = "0.31.1", features = ["rand-no-std", "rand"] } nalgebra = { version = "0.31.1", features = ["rand", "serde-serialize"] }
rand = "0.8.5" rand = "0.8.5"
rand_distr = "0.4.3" rand_distr = "0.4.3"
serde = { version = "1.0.152", features = ["derive"] }
serde_json = "1.0.91"

1
models/brain.json Normal file
View File

@ -0,0 +1 @@
{"config":[5,9,9,4],"weights":[[[-0.14426763,0.30192375,-0.33760154,0.56687576,0.5898287,-0.074536264,-0.6543682,-0.8446751,0.044597864,0.4060799,0.09555614,0.13340633,0.22922766,0.67137265,0.9608935,0.6129314,-0.24052018,-0.38112265,-0.5038841,-0.12603956,-0.4028914,0.18523669,0.07864928,0.06674278,0.2717425,0.6340071,0.61368537,0.8703921,-0.38861758,0.5550103,-0.26274413,0.5074116,-0.7821254,0.049322072,-0.2819178,-0.54028326,-0.4513356,-0.5652218,0.50929916,-0.42176604],8,5],[[-0.8103107,0.8446312,-0.40203577,0.54887617,-0.44233263,0.31166792,-0.33348525,0.40876722,0.26672268,-0.712453,-0.21693909,-0.35327882,-0.7854557,0.43809295,0.59911966,-0.071552694,0.3788359,-0.4907685,0.7868428,0.63808143,-0.50622714,0.08628023,-0.8824939,0.24896917,0.63522625,-0.50140214,0.9587381,0.5064759,0.040097475,0.24041378,-0.12401252,0.10650039,-0.8819831,0.29062068,-0.26787168,0.45351043,-0.2870677,-0.24404618,0.8434694,0.30426964,0.31458378,-0.8161984,-0.12195361,0.8177855,-0.57765794,0.89029014,0.75471187,0.5454526,-0.2778288,0.5250077,-0.71220773,0.58331454,0.368407,-1.0235562,-0.057651043,0.6541481,0.49730933,0.23280966,0.9288963,0.42811918,-0.45282865,0.22473057,0.9463216,0.5010747,-0.30762756,-0.5731837,0.8379146,0.06879502,-0.23158616,-0.6971844,0.4713143,0.1294421],8,9],[[0.45169508,-0.31192523,-0.74151075,-0.40177822,0.9622052,0.47277915,-0.76076436,0.00037801266,-0.36617398,0.89722514,0.89538515,0.99763775,1.3664851,-0.09313524,-0.5705385,0.059523106,-0.4970879,0.06191349,0.61592185,0.63810134,-0.97980994,-0.673075,0.025918722,0.5754194,1.2249401,-0.1614248,0.8438246,-0.6717344,0.83650386,-2.2065213,-0.051995337,0.19302869,-0.9978042,0.5303961,-1.4841839,0.118329406],4,9]],"activ_func":"ReLU","mut_rate":0.04}

View File

@ -2,9 +2,10 @@ use macroquad::rand::gen_range;
use nalgebra::*; use nalgebra::*;
use r::Rng; use r::Rng;
use rand_distr::StandardNormal; use rand_distr::StandardNormal;
use serde::{Deserialize, Serialize};
extern crate rand as r; extern crate rand as r;
#[derive(PartialEq, Debug, Clone, Copy, Default)] #[derive(PartialEq, Debug, Clone, Copy, Default, Serialize, Deserialize)]
enum ActivationFunc { enum ActivationFunc {
Sigmoid, Sigmoid,
@ -13,10 +14,10 @@ enum ActivationFunc {
ReLU, ReLU,
} }
#[derive(Clone, Debug, Default)] #[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct NN { pub struct NN {
pub config: Vec<usize>, pub config: Vec<usize>,
pub weights: Vec<DMatrix<f32>>, weights: Vec<DMatrix<f32>>,
activ_func: ActivationFunc, activ_func: ActivationFunc,
mut_rate: f32, mut_rate: f32,
} }
@ -93,4 +94,13 @@ impl NN {
} }
y.column(0).data.into_slice().to_vec() y.column(0).data.into_slice().to_vec()
} }
pub fn export(&self) -> String {
serde_json::to_string(self).unwrap()
}
pub fn import() -> NN {
let json = std::fs::read_to_string("models/brain.json").expect("Unable to read file");
serde_json::from_str(&json).unwrap()
}
} }

View File

@ -67,6 +67,7 @@ impl Population {
.map(|i| World::simulate(Some(self.worlds[i].see_brain().to_owned()))) .map(|i| World::simulate(Some(self.worlds[i].see_brain().to_owned())))
.collect::<Vec<_>>(); .collect::<Vec<_>>();
new_worlds[0].set_best(); new_worlds[0].set_best();
// new_worlds[0].export_brian();
while new_worlds.len() < self.size { while new_worlds.len() < self.size {
let rands = (gen_range(0., total), gen_range(0., total)); let rands = (gen_range(0., total), gen_range(0., total));
let mut sum = 0.; let mut sum = 0.;

View File

@ -38,6 +38,12 @@ impl World {
self.player.brain.as_ref().unwrap() self.player.brain.as_ref().unwrap()
} }
pub fn export_brian(&self) {
let json = self.player.brain.as_ref().unwrap().export();
std::fs::create_dir_all("models").expect("Unable to create directory");
std::fs::write("models/brain.json", json).expect("Unable to write file");
}
pub fn update(&mut self) { pub fn update(&mut self) {
let mut to_add: Vec<Asteroid> = Vec::new(); let mut to_add: Vec<Asteroid> = Vec::new();
for asteroid in &mut self.asteroids { for asteroid in &mut self.asteroids {