From 7e3acb4d57e4287bf1f4a98d1a86f7497a2ad1cf Mon Sep 17 00:00:00 2001 From: sparshg <43041139+sparshg@users.noreply.github.com> Date: Mon, 9 Jan 2023 00:37:35 +0530 Subject: [PATCH] player visual --- models/brain.json | 2 +- src/main.rs | 163 +++++++++++++++++++++++----------------------- src/nn.rs | 52 ++++++++++----- src/player.rs | 31 ++++----- src/population.rs | 4 +- src/world.rs | 66 ++++++++++++++++++- 6 files changed, 199 insertions(+), 119 deletions(-) diff --git a/models/brain.json b/models/brain.json index e6d5137..bf9d2ea 100644 --- a/models/brain.json +++ b/models/brain.json @@ -1 +1 @@ -{"config":[6,9,9,4],"weights":[[[0.3488082,-0.09199099,-0.11929926,-1.0631673,-0.23543529,0.12705088,-1.0842022,-1.1801648,1.3452269,-0.105297334,0.7070266,-0.49821422,-0.9919794,-0.4586555,0.38327622,1.2620807,-0.8927275,0.72946376,0.36548716,0.3453985,0.24704376,1.1178607,-0.73745847,-0.36780706,0.5647091,-0.29108286,1.710524,1.0728852,-0.8066526,-0.28913066,0.14346941,-1.0912626,0.36901304,0.7923526,-0.51800287,-0.72875357,0.8539478,-1.473583,0.68293977,0.18473642,0.0003245327,-0.58371824,0.48150238,0.3494165,-0.23288698,-1.0439657,-0.26875693,0.5452296],8,6],[[-2.3292994,-0.26192483,-0.90176463,2.324304,-1.9353858,-0.14891693,-0.52935755,0.76884156,0.1082592,-0.9176799,0.6898619,0.7002196,0.19165382,-0.00026388586,-2.0727108,-0.43361717,0.4825783,0.5469626,1.5679779,1.6802235,0.5569048,0.2322176,-1.2066526,0.7200245,-0.54737276,-0.15166411,0.19801892,0.040810376,1.3895321,-0.08859847,-1.2233515,-0.063391574,0.10386248,-1.1793425,0.47050527,-1.7380185,-1.7678633,0.42901033,-0.017297065,0.4843002,-1.3651237,-0.24331652,0.2636839,0.7167474,-1.2047871,2.1309357,-1.3384317,2.6571567,0.044456493,-1.8444118,-0.52083886,0.32806322,0.088446766,-0.009452653,0.20716749,-0.911177,-0.74860054,0.16590308,0.46789008,0.9035914,0.64725244,-0.30468005,0.27064824,0.39474502,2.02378,-1.4614658,-0.84020156,0.69931465,-0.51446456,0.24423209,1.2651527,-0.45960972],8,9],[[-0.585348,0.5383882,-0.97360545,0.032165576,-0.049561385,0.24148509,-0.8100511,-0.6426556,1.059696,0.85932785,0.6353762,-0.9978614,1.9525132,-0.10112733,-0.5321224,0.443344,-0.37738746,-0.99622214,-1.3957877,0.6381588,-0.47502336,1.0980062,0.047556207,-0.008916257,-0.8662841,-0.45772696,1.6571641,0.9795985,0.098027825,-1.6333683,0.58037895,1.2487193,-1.0500598,-0.44345784,-2.4462621,0.5654464],4,9]],"activ_func":"ReLU","mut_rate":0.05} \ No newline at end of file +{"config":[6,7,7,4],"weights":[[[-0.63086957,-0.53742135,-0.9276433,-0.34575018,0.7992718,-1.715718,-0.73926973,-1.0049589,0.8922695,-0.89036644,-1.9465111,1.1345955,0.16628692,0.35796213,-0.19456862,1.2559475,-1.4939085,0.7122434,-0.35949436,0.59764737,1.3104885,-0.23108178,1.1664017,-0.562161,-1.0272486,0.71920836,-0.48468617,1.9857628,-1.5419438,0.7855448,0.7199551,-0.055748053,0.89273596,-1.5976299,0.67290825,1.321863],6,6],[[0.96074045,-1.3427482,-0.34363145,-0.467356,0.40982306,1.6549394,-0.98944664,1.3190286,1.114854,-1.9362743,1.4953789,-0.92963094,2.176349,-0.6917859,0.4418182,-0.578999,0.28320304,0.047711473,0.9644473,-0.27468953,-0.6578115,-1.835753,1.0188148,0.4024565,-0.283786,1.7134846,0.6505742,-0.47571746,-0.2544604,2.5863714,-0.3020923,1.8471942,-0.7066354,0.4471431,0.2741386,-1.7216014,-0.32693514,1.760901,0.20739792,-1.222214,1.7166008,2.3254254],6,7],[[-0.8019232,0.43523917,0.45905063,-1.0787292,0.94543713,1.4523034,1.061355,-0.3872152,-1.3330262,1.1092484,0.40466794,0.8725662,-0.5520403,-0.095518686,-0.03717842,0.44755453,-2.0717077,2.0572546,0.6617041,0.019890757,-0.9302492,-0.38139746,0.22157091,-1.4106625,-0.8257789,0.054379098,0.927862,1.8578752],4,7]],"activ_func":"ReLU","mut_rate":0.05} \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index e57fd66..8e80eb2 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,8 +4,6 @@ mod player; mod population; mod world; -use std::borrow::BorrowMut; - use macroquad::{ prelude::*, ui::{hash, root_ui, widgets, Skin}, @@ -38,7 +36,7 @@ async fn main() { offset: vec2((2. * th + WIDTH) / screen_width() - 1., 0.), ..Default::default() }; - let maincam = Camera2D { + let netcam = Camera2D { zoom: vec2(2. / screen_width(), -2. / screen_height()), offset: vec2( (th + WIDTH) / screen_width(), @@ -46,25 +44,23 @@ async fn main() { ), ..Default::default() }; - // let mut cam = Camera2D::from_display_rect(Rect { - // x: 0., - // y: 0., - // w: 1600., - // h: 1200., - // }); - // cam.offset = vec2(1., -1.); - // // { - // zoom: vec2(2. / 800., -2. / 600.), - // // offset: vec2(-19. / 60., 0.), - // ..Default::default() - // }; - let mut pop = Population::new(100); + let statcam = Camera2D { + zoom: vec2(2. / screen_width(), -2. / screen_height()), + offset: vec2( + (th + WIDTH) / screen_width(), + ((th + HEIGHT) * 0.5) / screen_height(), + ), + ..Default::default() + }; + let mut speedup = false; let mut paused = false; - let mut checkbox = false; - let mut combobox = 0; - let mut text = String::new(); - let mut number = 0.0; + let mut bias = false; + let mut showall = false; + let mut size = 100; + let mut pop = Population::new(size as usize); + + let ui_thick = 34.; let skin = { let boxed = root_ui() @@ -88,6 +84,7 @@ async fn main() { bytes: vec![0; 4], }) .background_margin(RectOffset::new(0., 0., 0., 0.)) + .color_inactive(WHITE) .build(); let button_style = boxed .color_hovered(RED) @@ -96,18 +93,26 @@ async fn main() { .text_color_hovered(WHITE) .text_color_clicked(WHITE) .margin(RectOffset::new(10., 10., 8., 8.)) + .color_inactive(WHITE) .build(); let label_style = root_ui() .style_builder() .text_color(WHITE) .font_size(24) .margin(RectOffset::new(5., 5., 4., 4.)) + .color_inactive(WHITE) + .build(); + let group_style = root_ui() + .style_builder() + .color(Color::new(0., 0., 0., 0.)) .build(); Skin { window_style, button_style, label_style, + group_style, + margin: 0., ..root_ui().default_skin() } @@ -116,6 +121,7 @@ async fn main() { root_ui().push_skin(&skin); loop { clear_background(BLACK); + set_camera(&gamecam); if is_key_pressed(KeyCode::S) { speedup = !speedup; } @@ -136,89 +142,82 @@ async fn main() { WIDTH * 0.5 + th, -HEIGHT * 0.5, screen_width() - WIDTH - 3. * th, - 34., + ui_thick, 2., WHITE, ); draw_rectangle_lines( WIDTH * 0.5 + th, - -HEIGHT * 0.5 + (screen_height() - 3. * th) * 0.5 - 34., + -HEIGHT * 0.5 + (screen_height() - 3. * th) * 0.5 - ui_thick, screen_width() - WIDTH - 3. * th, - 34., + ui_thick, 2., WHITE, ); - set_camera(&maincam); - // draw_circle(0., 0., 20., RED); + set_camera(&netcam); pop.worlds[0].player.draw_brain( screen_width() - WIDTH - 3. * th, (screen_height() - 3. * th) * 0.5, + bias, + ); + set_camera(&statcam); + pop.worlds[0].draw_stats( + screen_width() - WIDTH - 3. * th, + (screen_height() - 7. * th) * 0.5 - 2. * ui_thick, ); - set_camera(&gamecam); + let ui_width = screen_width() - WIDTH - 3. * th + 1.; + let ui_height = (screen_height() - 3. * th) * 0.5; root_ui().window( hash!(), vec2(WIDTH + 2. * th, th), - vec2(screen_width() - WIDTH - 3. * th + 1., 34.), + vec2(ui_width, ui_height), |ui| { - ui.label(None, &format!("Generation: {}", pop.gen)); - ui.same_line(278.); - widgets::Button::new("Load Model").ui(ui); - ui.same_line(0.); - widgets::Button::new("Save Model").ui(ui); - ui.same_line(0.); - if widgets::Button::new(fast).ui(ui) { - speedup = !speedup; - }; - ui.same_line(0.); - if widgets::Button::new(restart).ui(ui) { - pop = Population::new(100); - }; - ui.same_line(0.); - if widgets::Button::new(if paused { play } else { pause }).ui(ui) { - paused = !paused; - }; + widgets::Group::new(hash!(), Vec2::new(ui_width, ui_thick)) + .position(Vec2::new(0., 0.)) + .ui(ui, |ui| { + ui.label(None, &format!("Generation: {}", pop.gen)); + ui.same_line(314.); + widgets::Button::new("Load Model").ui(ui); + ui.same_line(0.); + widgets::Button::new("Save Model").ui(ui); + ui.same_line(0.); + if widgets::Button::new(fast).ui(ui) { + speedup = !speedup; + }; + ui.same_line(0.); + if widgets::Button::new(if paused { play } else { pause }).ui(ui) { + paused = !paused; + }; + }); + widgets::Group::new(hash!(), Vec2::new(ui_width, ui_thick)) + .position(Vec2::new(0., ui_height - ui_thick)) + .ui(ui, |ui| { + ui.label(Some(vec2(0., 2.)), "«Population»"); + widgets::Group::new(hash!(), Vec2::new(100., ui_thick)) + .position(Vec2::new(140., 0.)) + .ui(ui, |ui| { + ui.drag(hash!(), "", Some((2, 500)), &mut size); + }); + ui.same_line(364.); + if widgets::Button::new(if bias { "Hide Bias" } else { "Show Bias" }).ui(ui) + { + bias = !bias; + }; + ui.same_line(0.); + if widgets::Button::new(if !pop.best { "Show Best" } else { "Show All " }) + .ui(ui) + { + pop.best = !pop.best; + }; + ui.same_line(0.); + if widgets::Button::new(restart).ui(ui) { + pop = Population::new(size as usize); + }; + }); }, ); - root_ui().window( - hash!(), - vec2(WIDTH + 2. * th, (screen_height() - th) * 0.5 - 34.), - vec2(screen_width() - WIDTH - 3. * th + 1., 34.), - |ui| { - ui.label(None, &format!("Generation: {}", pop.gen)); - ui.same_line(278.); - widgets::Button::new("Load Model").ui(ui); - ui.same_line(0.); - widgets::Button::new("Save Model").ui(ui); - ui.same_line(0.); - if widgets::Button::new(fast).ui(ui) { - speedup = !speedup; - }; - ui.same_line(0.); - if widgets::Button::new(restart).ui(ui) { - pop = Population::new(100); - }; - ui.same_line(0.); - if widgets::Button::new(if paused { play } else { pause }).ui(ui) { - paused = !paused; - }; - }, - ); - - // set_camera(&maincam); - // draw_texture_ex( - // target.texture, - // 0., - // 0., - // Color::new(1., 1., 1., 0.3), - // DrawTextureParams { - // flip_y: true, - // ..Default::default() - // }, - // ); - // set_camera(&cam); - next_frame().await; } } diff --git a/src/nn.rs b/src/nn.rs index d9a6c5e..9a34b03 100644 --- a/src/nn.rs +++ b/src/nn.rs @@ -81,7 +81,7 @@ impl NN { pub fn feed_forward(&self, inputs: &Vec) -> Vec { // println!("inputs: {:?}", inputs); let mut y = DMatrix::from_vec(inputs.len(), 1, inputs.to_vec()); - for i in 0..self.config.len() - 2 { + for i in 0..self.config.len() - 1 { y = (&self.weights[i] * y.insert_row(self.config[i] - 1, 1.)).map(|x| { match self.activ_func { ActivationFunc::ReLU => x.max(0.), @@ -90,13 +90,10 @@ impl NN { } }); } - let i = self.config.len() - 2; - y = (&self.weights[i] * y.insert_row(self.config[i] - 1, 1.)) - .map(|x| 1. / (1. + (-x).exp())); y.column(0).data.into_slice().to_vec() } - pub fn draw(&self, width: f32, height: f32, inputs: &Vec) { + pub fn draw(&self, width: f32, height: f32, inputs: &Vec, outputs: &Vec, bias: bool) { draw_rectangle_lines(-width * 0.5, -height * 0.5, width, height, 2., WHITE); let width = width * 0.8; @@ -107,14 +104,14 @@ impl NN { for (i, layer) in self .config .iter() - // .take(self.config.len() - 1) - // .map(|x| x - 1) - // .chain(self.config.last().map(|&x| x)) + .take(self.config.len() - 1) + .map(|x| x - if bias { 0 } else { 1 }) + .chain(self.config.last().map(|&x| x)) .enumerate() { p1s = p2s; p2s = Vec::new(); - for neuron in 0..*layer { + for neuron in 0..layer { p2s.push(( i as f32 * width / (self.config.len() - 1) as f32 - width * 0.5, neuron as f32 * vspace - (vspace * (layer - 1) as f32) * 0.5, @@ -122,7 +119,14 @@ impl NN { } for (k, j, p1, p2) in p1s.iter().enumerate().flat_map(|(k, x)| { p2s.iter() - .take(p2s.len() - if i == self.config.len() - 1 { 0 } else { 1 }) + .take( + p2s.len() + - if i == self.config.len() - 1 || !bias { + 0 + } else { + 1 + }, + ) .enumerate() .map(move |(j, y)| (k, j, *x, *y)) }) { @@ -133,18 +137,36 @@ impl NN { p1.1, p2.0, p2.1, - 1., + 1.5, Color::new(1., c, c, weight.abs()), ); } - for p in &p1s { + + let mut inputs = inputs.to_vec(); + inputs.push(1.); + + for (j, p) in p1s.iter().enumerate() { draw_circle(p.0, p.1, 10., WHITE); - draw_circle(p.0, p.1, 9., BLACK); + draw_circle(p.0, p.1, 8., BLACK); + draw_circle( + p.0, + p.1, + 8., + if i == 1 && inputs.len() > 1 { + let c = if inputs[j] < 0. { 0. } else { 1. }; + Color::new(1., c, c, inputs[j].abs()) + } else { + BLACK + }, + ); } } - for p in &p2s { + for (j, p) in p2s.iter().enumerate() { draw_circle(p.0, p.1, 10., WHITE); - draw_circle(p.0, p.1, 9., BLACK); + draw_circle(p.0, p.1, 8., BLACK); + if !outputs.is_empty() { + draw_circle(p.0, p.1, 8., Color::new(1., 1., 1., outputs[j])); + } } draw_rectangle(width * 0.45, height * 0.45, 10., 10., RED); draw_text("-ve", width * 0.45 + 20., height * 0.45 + 10., 20., WHITE); diff --git a/src/player.rs b/src/player.rs index abf357a..23094e8 100644 --- a/src/player.rs +++ b/src/player.rs @@ -14,7 +14,7 @@ pub struct Player { bullets: Vec, asteroid: Option, inputs: Vec, - outputs: Vec, + pub outputs: Vec, // asteroid_data: Vec<(f32, f32, f32)>, last_shot: u8, shot_interval: u8, @@ -38,7 +38,7 @@ impl Player { alive: true, debug: false, shots: 4, - outputs: vec![false; 4], + outputs: vec![0.; 4], ..Default::default() } @@ -120,13 +120,13 @@ impl Player { self.lifespan += 1; self.last_shot += 1; self.acc = 0.; - self.outputs = vec![false; 4]; + self.outputs = vec![0.; 4]; if let Some(ast) = self.asteroid.as_ref() { self.inputs = vec![ - (ast.pos - self.pos).length() / WIDTH, + (ast.pos - self.pos).length() / HEIGHT, self.dir.angle_between(ast.pos - self.pos), - (ast.vel - self.vel).x / 11., - (ast.vel - self.vel).y / 11., + (ast.vel - self.vel).x * 0.6, + (ast.vel - self.vel).y * 0.6, self.rot / TAU as f32, ]; @@ -143,25 +143,22 @@ impl Player { // ); if let Some(brain) = &self.brain { - self.outputs = brain - .feed_forward(&self.inputs) - .iter() - .map(|&x| x > 0.85) - .collect(); + self.outputs = brain.feed_forward(&self.inputs); } } - if is_key_down(KeyCode::Right) && self.debug || self.outputs[0] { + let keys: Vec = self.outputs.iter().map(|&x| x > 0.).collect(); + if is_key_down(KeyCode::Right) && self.debug || keys[0] { self.rot = (self.rot + 0.1 + TAU as f32) % TAU as f32; self.dir = vec2(self.rot.cos(), self.rot.sin()); } - if is_key_down(KeyCode::Left) && self.debug || self.outputs[1] { + if is_key_down(KeyCode::Left) && self.debug || keys[1] { self.rot = (self.rot - 0.1 + TAU as f32) % TAU as f32; self.dir = vec2(self.rot.cos(), self.rot.sin()); } - if is_key_down(KeyCode::Up) && self.debug || self.outputs[2] { + if is_key_down(KeyCode::Up) && self.debug || keys[2] { self.acc = 0.14; } - if is_key_down(KeyCode::Space) && self.debug || self.outputs[3] { + if is_key_down(KeyCode::Space) && self.debug || keys[3] { if self.last_shot > self.shot_interval { self.last_shot = 0; self.shots += 1; @@ -251,9 +248,9 @@ impl Player { } } - pub fn draw_brain(&self, width: f32, height: f32) { + pub fn draw_brain(&self, width: f32, height: f32, bias: bool) { if let Some(brain) = &self.brain { - brain.draw(width, height, &self.inputs); + brain.draw(width, height, &self.inputs, &self.outputs, bias); } } } diff --git a/src/population.rs b/src/population.rs index 66f6fab..2c83786 100644 --- a/src/population.rs +++ b/src/population.rs @@ -6,7 +6,7 @@ use crate::{nn::NN, world::World, HEIGHT, WIDTH}; pub struct Population { size: usize, pub gen: i32, - best: bool, + pub best: bool, pub worlds: Vec, } @@ -85,7 +85,7 @@ impl Population { println!("Fitness: {}", i.fitness); } println!("Gen: {}, Fitness: {}", self.gen, self.worlds[0].fitness); - let mut new_worlds = (0..self.size / 20) + let mut new_worlds = (0..std::cmp::max(1, self.size / 20)) .map(|i| World::simulate(Some(self.worlds[i].see_brain().to_owned()))) .collect::>(); new_worlds[0].set_best(); diff --git a/src/world.rs b/src/world.rs index 844fd45..bbab198 100644 --- a/src/world.rs +++ b/src/world.rs @@ -82,10 +82,10 @@ impl World { } if self.player.check_player_collision(asteroid) { self.over = true; - self.fitness = - (self.score / self.player.shots as f32).powi(2) * self.player.lifespan as f32; } } + self.fitness = + (self.score / self.player.shots as f32).powi(2) * self.player.lifespan as f32; self.player.update(); self.asteroids.append(&mut to_add); self.asteroids.retain(|asteroid| asteroid.alive); @@ -120,4 +120,66 @@ impl World { WHITE, ); } + + pub fn draw_stats(&self, width: f32, height: f32) { + draw_rectangle_lines(-width * 0.5, -height * 0.5, width, height, 2., WHITE); + + let scale = 2.5; + let offset = vec2(-width * 0.3, -height * 0.1); + let p1 = scale * vec2(0., -20.) + offset; + let p2 = scale * vec2(-12.667, 18.) + offset; + let p3 = scale * vec2(12.667, 18.) + offset; + let p4 = scale * vec2(-10., 10.) + offset; + let p5 = scale * vec2(10., 10.) + offset; + let p6 = scale * vec2(0., 25.) + offset; + let p7 = scale * vec2(-6., 10.) + offset; + let p8 = scale * vec2(6., 10.) + offset; + + draw_line(p1.x, p1.y, p2.x, p2.y, 2., WHITE); + draw_line(p1.x, p1.y, p3.x, p3.y, 2., WHITE); + draw_line(p4.x, p4.y, p5.x, p5.y, 2., WHITE); + if self.player.outputs[2] > 0. && (gen_range(0., 1.) < 0.4 || self.over) { + draw_triangle_lines(p6, p7, p8, 2., WHITE); + } + let l1 = scale * vec2(30., 0.) + offset; + let l2 = scale * vec2(25., -5.) + offset; + let l3 = scale * vec2(25., 5.) + offset; + if self.player.outputs[0] > 0. { + draw_line(l1.x, l1.y, l2.x, l2.y, 2., WHITE); + draw_line(l1.x, l1.y, l3.x, l3.y, 2., WHITE); + } + let l1 = -scale * vec2(30., 0.) + offset; + let l2 = -scale * vec2(25., -5.) + offset; + let l3 = -scale * vec2(25., 5.) + offset; + if self.player.outputs[1] > 0. { + draw_line(l1.x, l1.y, l2.x, l2.y, 2., WHITE); + draw_line(l1.x, l1.y, l3.x, l3.y, 2., WHITE); + } + let l1 = -scale * vec2(0., 35.) + offset; + if self.player.outputs[3] > 0. { + draw_circle(l1.x, l1.y, 5., WHITE); + draw_circle(l1.x, l1.y, 3.5, BLACK); + } + draw_text( + if self.over { "DEAD" } else { "ALIVE" }, + -width * 0.5 + 20., + 75., + 24., + if self.over { RED } else { GREEN }, + ); + draw_text( + &format!("Score: {}", self.score), + -width * 0.5 + 20., + 100., + 24., + WHITE, + ); + draw_text( + &format!("Fitness: {:.2}", self.fitness), + -width * 0.5 + 20., + 125., + 24., + WHITE, + ); + } }