Add shift register memory

This commit is contained in:
Leonora Tindall 2023-04-03 22:50:33 -05:00
parent e63fbf74ba
commit 007eb03050
1 changed files with 20 additions and 2 deletions

View File

@ -11,6 +11,8 @@ const NUM_KEYS: usize = 4;
const INPUTS_PER_ASTEROID: usize = 4; const INPUTS_PER_ASTEROID: usize = 4;
const NUM_ASTEROIDS: usize = 1; const NUM_ASTEROIDS: usize = 1;
const INPUTS_FOR_SHIP: usize = 2; const INPUTS_FOR_SHIP: usize = 2;
const VALUES_PER_MEMORY: usize = 1;
const NUM_MEMORIES: usize = 0;
#[derive(Default)] #[derive(Default)]
pub struct Player { pub struct Player {
pub pos: Vec2, pub pos: Vec2,
@ -29,6 +31,7 @@ pub struct Player {
alive: bool, alive: bool,
pub lifespan: u32, pub lifespan: u32,
pub shots: u32, pub shots: u32,
memory: std::collections::VecDeque<f32>,
} }
impl Player { impl Player {
@ -46,10 +49,16 @@ impl Player {
0, 0,
(INPUTS_PER_ASTEROID * NUM_ASTEROIDS) (INPUTS_PER_ASTEROID * NUM_ASTEROIDS)
+ INPUTS_FOR_SHIP + INPUTS_FOR_SHIP
+ (VALUES_PER_MEMORY * NUM_MEMORIES),
); );
// Number of outputs // Number of outputs
c.push( c.push(
NUM_KEYS NUM_KEYS
+ if NUM_MEMORIES > 0 {
VALUES_PER_MEMORY
} else {
0
},
); );
Some(NN::new(c, mut_rate.unwrap(), activ.unwrap())) Some(NN::new(c, mut_rate.unwrap(), activ.unwrap()))
} }
@ -63,8 +72,9 @@ impl Player {
shot_interval: 18, shot_interval: 18,
alive: true, alive: true,
shots: 4, shots: 4,
// 4 outputs // 4 outputs, 1 for memory
outputs: vec![0.; NUM_KEYS], outputs: vec![0.; NUM_KEYS + VALUES_PER_MEMORY],
memory: vec![0.; VALUES_PER_MEMORY * NUM_MEMORIES].into(),
..Default::default() ..Default::default()
} }
@ -140,10 +150,18 @@ impl Player {
self.inputs.push( self.inputs.push(
(self.shot_interval as f32 - self.last_shot as f32).max(0.) / self.shot_interval as f32, (self.shot_interval as f32 - self.last_shot as f32).max(0.) / self.shot_interval as f32,
); );
// Insert the memories
for memory in &self.memory {
self.inputs.push(memory.min(1.).max(-1.));
}
// Run the brain // Run the brain
if let Some(brain) = &self.brain { if let Some(brain) = &self.brain {
assert_eq!(self.inputs.len(), brain.config[0] - 1); assert_eq!(self.inputs.len(), brain.config[0] - 1);
self.outputs = brain.feed_forward(&self.inputs); self.outputs = brain.feed_forward(&self.inputs);
if NUM_MEMORIES > 0 {
self.memory.push_back(self.outputs[self.outputs.len() - 1]);
self.memory.pop_front();
}
keys = self keys = self
.outputs .outputs
.iter() .iter()