load/save models

This commit is contained in:
sparshg 2023-01-09 01:39:10 +05:30
parent 7255bd617d
commit 31be57ab37
7 changed files with 43 additions and 17 deletions

17
Cargo.lock generated
View File

@ -76,6 +76,12 @@ version = "1.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610"
[[package]]
name = "cc"
version = "1.0.78"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a20104e2335ce8a659d6dd92a51a767a0c062599c73b343fd152cb401e828c3d"
[[package]] [[package]]
name = "cfg-if" name = "cfg-if"
version = "1.0.0" version = "1.0.0"
@ -142,6 +148,7 @@ dependencies = [
"rand_distr", "rand_distr",
"serde", "serde",
"serde_json", "serde_json",
"tinyfiledialogs",
] ]
[[package]] [[package]]
@ -582,6 +589,16 @@ dependencies = [
"unicode-ident", "unicode-ident",
] ]
[[package]]
name = "tinyfiledialogs"
version = "3.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e25fa0bc43a6566e2cc6d7ac96df3fa5a57beba34445bead1b368ba8fe9ca568"
dependencies = [
"cc",
"libc",
]
[[package]] [[package]]
name = "ttf-parser" name = "ttf-parser"
version = "0.15.2" version = "0.15.2"

View File

@ -12,3 +12,4 @@ rand = "0.8.5"
rand_distr = "0.4.3" rand_distr = "0.4.3"
serde = { version = "1.0.152", features = ["derive"] } serde = { version = "1.0.152", features = ["derive"] }
serde_json = "1.0.91" serde_json = "1.0.91"
tinyfiledialogs = "3.9.1"

View File

@ -1 +1 @@
{"config":[6,7,7,4],"weights":[[[-0.63086957,-0.53742135,-0.9276433,-0.34575018,0.7992718,-1.715718,-0.73926973,-1.0049589,0.8922695,-0.89036644,-1.9465111,1.1345955,0.16628692,0.35796213,-0.19456862,1.2559475,-1.4939085,0.7122434,-0.35949436,0.59764737,1.3104885,-0.23108178,1.1664017,-0.562161,-1.0272486,0.71920836,-0.48468617,1.9857628,-1.5419438,0.7855448,0.7199551,-0.055748053,0.89273596,-1.5976299,0.67290825,1.321863],6,6],[[0.96074045,-1.3427482,-0.34363145,-0.467356,0.40982306,1.6549394,-0.98944664,1.3190286,1.114854,-1.9362743,1.4953789,-0.92963094,2.176349,-0.6917859,0.4418182,-0.578999,0.28320304,0.047711473,0.9644473,-0.27468953,-0.6578115,-1.835753,1.0188148,0.4024565,-0.283786,1.7134846,0.6505742,-0.47571746,-0.2544604,2.5863714,-0.3020923,1.8471942,-0.7066354,0.4471431,0.2741386,-1.7216014,-0.32693514,1.760901,0.20739792,-1.222214,1.7166008,2.3254254],6,7],[[-0.8019232,0.43523917,0.45905063,-1.0787292,0.94543713,1.4523034,1.061355,-0.3872152,-1.3330262,1.1092484,0.40466794,0.8725662,-0.5520403,-0.095518686,-0.03717842,0.44755453,-2.0717077,2.0572546,0.6617041,0.019890757,-0.9302492,-0.38139746,0.22157091,-1.4106625,-0.8257789,0.054379098,0.927862,1.8578752],4,7]],"activ_func":"ReLU","mut_rate":0.05} {"config":[6,7,7,4],"weights":[[[1.0073581,0.10285417,0.99449056,-0.3444486,0.6855302,0.27510634,-1.0100973,-0.6120519,2.5831788,-0.20424974,0.0436417,-1.3984715,-0.006550881,-0.4939134,0.11111599,0.90777147,0.3497712,0.1878394,1.5924022,0.4055183,0.13070123,-0.20338647,-0.3423283,0.0065261708,-3.129609,0.10328761,-0.04636088,-0.2210973,1.6255523,0.0877632,-1.0272317,-0.25930727,0.5464621,-0.6345097,0.36178154,0.11928975],6,6],[[0.70730317,-1.5186486,0.48550358,-0.07556238,0.30853745,0.28583515,-1.1104876,1.281638,-0.73964846,0.3207366,-0.2936539,-0.1681668,-0.5059147,-1.286904,1.3346463,-0.31040213,-1.2456195,-1.487566,-0.7792251,0.21069501,-0.58090514,-0.7379585,-0.76974285,0.84114456,-0.83502007,0.7163775,1.0413684,-0.0083767455,0.34308165,-0.07222111,1.0065539,-0.26538268,-2.956921,0.14103593,-0.16258997,1.608054,1.0216693,-1.1089768,0.1331676,1.0608563,-0.18297681,-0.42292616],6,7],[[0.06483343,1.0197396,0.3112982,0.9256023,0.92714655,0.0707627,-0.84999484,0.27492166,1.59565,-0.96241134,-0.83382475,-1.3709176,-0.12743561,0.77063257,0.32503816,0.53178865,0.36163366,-0.8444256,-0.54038405,1.0471557,-1.9363129,0.9999369,-0.7392465,0.18372212,-0.97619456,0.32837945,-1.3330086,2.2873125],4,7]],"activ_func":"ReLU","mut_rate":0.05}

View File

@ -4,11 +4,15 @@ mod player;
mod population; mod population;
mod world; mod world;
use nn::NN;
use tinyfiledialogs::*;
use macroquad::{ use macroquad::{
prelude::*, prelude::*,
ui::{hash, root_ui, widgets, Skin}, ui::{hash, root_ui, widgets, Skin},
}; };
use population::Population; use population::Population;
use world::World;
pub const WIDTH: f32 = 800.; pub const WIDTH: f32 = 800.;
pub const HEIGHT: f32 = 780.; pub const HEIGHT: f32 = 780.;
@ -179,9 +183,20 @@ async fn main() {
.ui(ui, |ui| { .ui(ui, |ui| {
ui.label(None, &format!("Generation: {}", pop.gen)); ui.label(None, &format!("Generation: {}", pop.gen));
ui.same_line(314.); ui.same_line(314.);
widgets::Button::new("Load Model").ui(ui); if widgets::Button::new("Load Model").ui(ui) {
if let Some(path) = open_file_dialog("Load Model", "brain.json", None) {
let brain = NN::import(&path);
size = 1;
pop = Population::new(1);
pop.worlds[0] = World::simulate(Some(brain));
}
}
ui.same_line(0.); ui.same_line(0.);
widgets::Button::new("Save Model").ui(ui); if widgets::Button::new("Save Model").ui(ui) {
if let Some(path) = save_file_dialog("Save Model", "brain.json") {
pop.worlds[0].export_brain(&path);
}
}
ui.same_line(0.); ui.same_line(0.);
if widgets::Button::new(fast).ui(ui) { if widgets::Button::new(fast).ui(ui) {
speedup = !speedup; speedup = !speedup;
@ -198,7 +213,7 @@ async fn main() {
widgets::Group::new(hash!(), Vec2::new(100., ui_thick)) widgets::Group::new(hash!(), Vec2::new(100., ui_thick))
.position(Vec2::new(140., 0.)) .position(Vec2::new(140., 0.))
.ui(ui, |ui| { .ui(ui, |ui| {
ui.drag(hash!(), "", Some((2, 500)), &mut size); ui.drag(hash!(), "", Some((1, 500)), &mut size);
}); });
ui.same_line(307.); ui.same_line(307.);
widgets::Button::new("Debug").ui(ui); widgets::Button::new("Debug").ui(ui);

View File

@ -194,8 +194,8 @@ impl NN {
serde_json::to_string(self).unwrap() serde_json::to_string(self).unwrap()
} }
pub fn import() -> NN { pub fn import(path: &str) -> NN {
let json = std::fs::read_to_string("models/brain.json").expect("Unable to read file"); let json = std::fs::read_to_string(path).expect("Unable to read file");
serde_json::from_str(&json).unwrap() serde_json::from_str(&json).unwrap()
} }
} }

View File

@ -34,9 +34,6 @@ impl Population {
if is_key_pressed(KeyCode::Z) { if is_key_pressed(KeyCode::Z) {
self.best = !self.best; self.best = !self.best;
} }
if is_key_pressed(KeyCode::Space) {
self.worlds[0].export_brain();
}
} }
pub fn draw(&self) { pub fn draw(&self) {

View File

@ -38,10 +38,9 @@ impl World {
self.player.brain.as_ref().unwrap() self.player.brain.as_ref().unwrap()
} }
pub fn export_brain(&self) { pub fn export_brain(&self, path: &str) {
let json = self.player.brain.as_ref().unwrap().export(); let json = self.player.brain.as_ref().unwrap().export();
std::fs::create_dir_all("models").expect("Unable to create directory"); std::fs::write(path, json).expect("Unable to write file");
std::fs::write("models/brain.json", json).expect("Unable to write file");
} }
pub fn update(&mut self) { pub fn update(&mut self) {
@ -110,11 +109,8 @@ impl World {
asteroid.draw(); asteroid.draw();
} }
draw_text( draw_text(
&format!( &format!("{:.2}", self.fitness),
"{}", self.player.pos.x - 22.,
(self.score / self.player.shots as f32).powi(2) * self.player.lifespan as f32
),
self.player.pos.x - 20.,
self.player.pos.y - 20., self.player.pos.y - 20.,
12., 12.,
WHITE, WHITE,