From 31be57ab379f6734b645c070ccf1240fdf2c7999 Mon Sep 17 00:00:00 2001 From: sparshg <43041139+sparshg@users.noreply.github.com> Date: Mon, 9 Jan 2023 01:39:10 +0530 Subject: [PATCH] load/save models --- Cargo.lock | 17 +++++++++++++++++ Cargo.toml | 1 + models/brain.json | 2 +- src/main.rs | 21 ++++++++++++++++++--- src/nn.rs | 4 ++-- src/population.rs | 3 --- src/world.rs | 12 ++++-------- 7 files changed, 43 insertions(+), 17 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0bfc9df..9f945db 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -76,6 +76,12 @@ version = "1.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" +[[package]] +name = "cc" +version = "1.0.78" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a20104e2335ce8a659d6dd92a51a767a0c062599c73b343fd152cb401e828c3d" + [[package]] name = "cfg-if" version = "1.0.0" @@ -142,6 +148,7 @@ dependencies = [ "rand_distr", "serde", "serde_json", + "tinyfiledialogs", ] [[package]] @@ -582,6 +589,16 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "tinyfiledialogs" +version = "3.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e25fa0bc43a6566e2cc6d7ac96df3fa5a57beba34445bead1b368ba8fe9ca568" +dependencies = [ + "cc", + "libc", +] + [[package]] name = "ttf-parser" version = "0.15.2" diff --git a/Cargo.toml b/Cargo.toml index 1c26d7b..1e6acf9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,3 +12,4 @@ rand = "0.8.5" rand_distr = "0.4.3" serde = { version = "1.0.152", features = ["derive"] } serde_json = "1.0.91" +tinyfiledialogs = "3.9.1" diff --git a/models/brain.json b/models/brain.json index bf9d2ea..a7089ca 100644 --- a/models/brain.json +++ b/models/brain.json @@ -1 +1 @@ -{"config":[6,7,7,4],"weights":[[[-0.63086957,-0.53742135,-0.9276433,-0.34575018,0.7992718,-1.715718,-0.73926973,-1.0049589,0.8922695,-0.89036644,-1.9465111,1.1345955,0.16628692,0.35796213,-0.19456862,1.2559475,-1.4939085,0.7122434,-0.35949436,0.59764737,1.3104885,-0.23108178,1.1664017,-0.562161,-1.0272486,0.71920836,-0.48468617,1.9857628,-1.5419438,0.7855448,0.7199551,-0.055748053,0.89273596,-1.5976299,0.67290825,1.321863],6,6],[[0.96074045,-1.3427482,-0.34363145,-0.467356,0.40982306,1.6549394,-0.98944664,1.3190286,1.114854,-1.9362743,1.4953789,-0.92963094,2.176349,-0.6917859,0.4418182,-0.578999,0.28320304,0.047711473,0.9644473,-0.27468953,-0.6578115,-1.835753,1.0188148,0.4024565,-0.283786,1.7134846,0.6505742,-0.47571746,-0.2544604,2.5863714,-0.3020923,1.8471942,-0.7066354,0.4471431,0.2741386,-1.7216014,-0.32693514,1.760901,0.20739792,-1.222214,1.7166008,2.3254254],6,7],[[-0.8019232,0.43523917,0.45905063,-1.0787292,0.94543713,1.4523034,1.061355,-0.3872152,-1.3330262,1.1092484,0.40466794,0.8725662,-0.5520403,-0.095518686,-0.03717842,0.44755453,-2.0717077,2.0572546,0.6617041,0.019890757,-0.9302492,-0.38139746,0.22157091,-1.4106625,-0.8257789,0.054379098,0.927862,1.8578752],4,7]],"activ_func":"ReLU","mut_rate":0.05} \ No newline at end of file +{"config":[6,7,7,4],"weights":[[[1.0073581,0.10285417,0.99449056,-0.3444486,0.6855302,0.27510634,-1.0100973,-0.6120519,2.5831788,-0.20424974,0.0436417,-1.3984715,-0.006550881,-0.4939134,0.11111599,0.90777147,0.3497712,0.1878394,1.5924022,0.4055183,0.13070123,-0.20338647,-0.3423283,0.0065261708,-3.129609,0.10328761,-0.04636088,-0.2210973,1.6255523,0.0877632,-1.0272317,-0.25930727,0.5464621,-0.6345097,0.36178154,0.11928975],6,6],[[0.70730317,-1.5186486,0.48550358,-0.07556238,0.30853745,0.28583515,-1.1104876,1.281638,-0.73964846,0.3207366,-0.2936539,-0.1681668,-0.5059147,-1.286904,1.3346463,-0.31040213,-1.2456195,-1.487566,-0.7792251,0.21069501,-0.58090514,-0.7379585,-0.76974285,0.84114456,-0.83502007,0.7163775,1.0413684,-0.0083767455,0.34308165,-0.07222111,1.0065539,-0.26538268,-2.956921,0.14103593,-0.16258997,1.608054,1.0216693,-1.1089768,0.1331676,1.0608563,-0.18297681,-0.42292616],6,7],[[0.06483343,1.0197396,0.3112982,0.9256023,0.92714655,0.0707627,-0.84999484,0.27492166,1.59565,-0.96241134,-0.83382475,-1.3709176,-0.12743561,0.77063257,0.32503816,0.53178865,0.36163366,-0.8444256,-0.54038405,1.0471557,-1.9363129,0.9999369,-0.7392465,0.18372212,-0.97619456,0.32837945,-1.3330086,2.2873125],4,7]],"activ_func":"ReLU","mut_rate":0.05} \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index e6c978d..b45179f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,11 +4,15 @@ mod player; mod population; mod world; +use nn::NN; +use tinyfiledialogs::*; + use macroquad::{ prelude::*, ui::{hash, root_ui, widgets, Skin}, }; use population::Population; +use world::World; pub const WIDTH: f32 = 800.; pub const HEIGHT: f32 = 780.; @@ -179,9 +183,20 @@ async fn main() { .ui(ui, |ui| { ui.label(None, &format!("Generation: {}", pop.gen)); ui.same_line(314.); - widgets::Button::new("Load Model").ui(ui); + if widgets::Button::new("Load Model").ui(ui) { + if let Some(path) = open_file_dialog("Load Model", "brain.json", None) { + let brain = NN::import(&path); + size = 1; + pop = Population::new(1); + pop.worlds[0] = World::simulate(Some(brain)); + } + } ui.same_line(0.); - widgets::Button::new("Save Model").ui(ui); + if widgets::Button::new("Save Model").ui(ui) { + if let Some(path) = save_file_dialog("Save Model", "brain.json") { + pop.worlds[0].export_brain(&path); + } + } ui.same_line(0.); if widgets::Button::new(fast).ui(ui) { speedup = !speedup; @@ -198,7 +213,7 @@ async fn main() { widgets::Group::new(hash!(), Vec2::new(100., ui_thick)) .position(Vec2::new(140., 0.)) .ui(ui, |ui| { - ui.drag(hash!(), "", Some((2, 500)), &mut size); + ui.drag(hash!(), "", Some((1, 500)), &mut size); }); ui.same_line(307.); widgets::Button::new("Debug").ui(ui); diff --git a/src/nn.rs b/src/nn.rs index 3913fd8..740436c 100644 --- a/src/nn.rs +++ b/src/nn.rs @@ -194,8 +194,8 @@ impl NN { serde_json::to_string(self).unwrap() } - pub fn import() -> NN { - let json = std::fs::read_to_string("models/brain.json").expect("Unable to read file"); + pub fn import(path: &str) -> NN { + let json = std::fs::read_to_string(path).expect("Unable to read file"); serde_json::from_str(&json).unwrap() } } diff --git a/src/population.rs b/src/population.rs index 2c83786..d5e2cf9 100644 --- a/src/population.rs +++ b/src/population.rs @@ -34,9 +34,6 @@ impl Population { if is_key_pressed(KeyCode::Z) { self.best = !self.best; } - if is_key_pressed(KeyCode::Space) { - self.worlds[0].export_brain(); - } } pub fn draw(&self) { diff --git a/src/world.rs b/src/world.rs index 141cd69..c7944f7 100644 --- a/src/world.rs +++ b/src/world.rs @@ -38,10 +38,9 @@ impl World { self.player.brain.as_ref().unwrap() } - pub fn export_brain(&self) { + pub fn export_brain(&self, path: &str) { let json = self.player.brain.as_ref().unwrap().export(); - std::fs::create_dir_all("models").expect("Unable to create directory"); - std::fs::write("models/brain.json", json).expect("Unable to write file"); + std::fs::write(path, json).expect("Unable to write file"); } pub fn update(&mut self) { @@ -110,11 +109,8 @@ impl World { asteroid.draw(); } draw_text( - &format!( - "{}", - (self.score / self.player.shots as f32).powi(2) * self.player.lifespan as f32 - ), - self.player.pos.x - 20., + &format!("{:.2}", self.fitness), + self.player.pos.x - 22., self.player.pos.y - 20., 12., WHITE,