From e0f6dfa3ef9d7495015e13eefc19021b68559a4b Mon Sep 17 00:00:00 2001 From: sakex Date: Sun, 2 Jan 2022 21:57:09 +0100 Subject: [PATCH 1/4] Spiking neural networks --- Cargo.lock | 130 +++++++++++++++ Cargo.toml | 5 + benches/benchmark.rs | 2 +- src/neural_network/mod.rs | 3 + src/neural_network/nn.rs | 21 +-- src/neural_network/nn_trait.rs | 24 +++ src/neural_network/spiking/mod.rs | 4 + src/neural_network/spiking/spiking_neuron.rs | 73 +++++++++ src/neural_network/spiking/spiking_nn.rs | 157 +++++++++++++++++++ src/tests.rs | 12 +- src/topology/bias_and_genes.rs | 7 +- src/topology/topology_struct.rs | 4 +- src/train/mod.rs | 8 +- src/train/training.rs | 24 +-- 14 files changed, 440 insertions(+), 34 deletions(-) create mode 100644 src/neural_network/nn_trait.rs create mode 100644 src/neural_network/spiking/mod.rs create mode 100644 src/neural_network/spiking/spiking_neuron.rs create mode 100644 src/neural_network/spiking/spiking_nn.rs diff --git a/Cargo.lock b/Cargo.lock index 5e6cbe2..9fa538e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -188,6 +188,95 @@ version = "1.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e78d4f1cc4ae33bbfc157ed5d5a5ef3bc29227303d595861deb238fcec4e9457" +[[package]] +name = "futures" +version = "0.3.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28560757fe2bb34e79f907794bb6b22ae8b0e5c669b638a1132f2592b19035b4" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-channel" +version = "0.3.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba3dda0b6588335f360afc675d0564c17a77a2bda81ca178a4b6081bd86c7f0b" +dependencies = [ + "futures-core", + "futures-sink", +] + +[[package]] +name = "futures-core" +version = "0.3.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0c8ff0461b82559810cdccfde3215c3f373807f5e5232b71479bff7bb2583d7" + +[[package]] +name = "futures-executor" +version = "0.3.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29d6d2ff5bb10fb95c85b8ce46538a2e5f5e7fdc755623a7d4529ab8a4ed9d2a" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1f9d34af5a1aac6fb380f735fe510746c38067c5bf16c7fd250280503c971b2" + +[[package]] +name = "futures-macro" +version = "0.3.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dbd947adfffb0efc70599b3ddcf7b5597bb5fa9e245eb99f62b3a5f7bb8bd3c" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "futures-sink" +version = "0.3.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3055baccb68d74ff6480350f8d6eb8fcfa3aa11bdc1a1ae3afdd0514617d508" + +[[package]] +name = "futures-task" +version = "0.3.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ee7c6485c30167ce4dfb83ac568a849fe53274c831081476ee13e0dce1aad72" + +[[package]] +name = "futures-util" +version = "0.3.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9b5cf40b47a271f77a8b1bec03ca09044d99d2372c0de244e66430761127164" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-macro", + "futures-sink", + "futures-task", + "memchr", + "pin-project-lite", + "pin-utils", + "slab", +] + [[package]] name = "getrandom" version = "0.2.3" @@ -294,6 +383,7 @@ version = "1.1.0" dependencies = [ "async-trait", "criterion", + "futures", "getrandom", "itertools", "num", @@ -303,6 +393,7 @@ dependencies = [ "rayon", "serde", "serde_json", + "tokio", ] [[package]] @@ -408,6 +499,18 @@ version = "11.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575" +[[package]] +name = "pin-project-lite" +version = "0.2.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e280fbe77cc62c91527259e9442153f4688736748d24660126286329742b4c6c" + +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + [[package]] name = "plotters" version = "0.3.1" @@ -633,6 +736,12 @@ dependencies = [ "serde", ] +[[package]] +name = "slab" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9def91fd1e018fe007022791f865d0ccc9b3a0d5001e01aabb8b40e46000afb5" + [[package]] name = "syn" version = "1.0.83" @@ -663,6 +772,27 @@ dependencies = [ "serde_json", ] +[[package]] +name = "tokio" +version = "1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fbbf1c778ec206785635ce8ad57fe52b3009ae9e0c9f574a728f3049d3e55838" +dependencies = [ + "pin-project-lite", + "tokio-macros", +] + +[[package]] +name = "tokio-macros" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b557f72f448c511a979e2564e55d74e6c4432fc96ff4f6241bc6bded342643b7" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "unicode-width" version = "0.1.9" diff --git a/Cargo.toml b/Cargo.toml index bf22157..626a9b4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,9 @@ repository = "https://github.com/sakex/neat-gru-rust" categories = ["science", "wasm"] keywords = ["neat", "ai", "machine-learning", "genetic", "algorithm"] +[features] +default = [] +snn = ["tokio", "futures"] [lib] crate-type = ["cdylib", "rlib"] @@ -23,6 +26,8 @@ numeric_literals = "0.2.0" rayon = "1.5.1" itertools = "0.10.1" async-trait = "0.1.51" +futures = { version = "0.3", optional = true } +tokio = { version = "1.0", optional = true, features = ["sync", "rt", "macros", "time"] } [dev-dependencies] criterion = "0.3.5" diff --git a/benches/benchmark.rs b/benches/benchmark.rs index b2ea05e..9b30e59 100644 --- a/benches/benchmark.rs +++ b/benches/benchmark.rs @@ -10,7 +10,7 @@ fn benchmark(c: &mut Criterion) { let file_string = &mut "".to_string(); file.read_to_string(file_string).unwrap(); let topology = Topology::from_string(file_string); - let mut network = unsafe { NeuralNetwork::new(&topology) }; + let mut network = unsafe { NeuralNetwork::from_topology(&topology) }; c.bench_function("nn::compute", |b| { b.iter(|| network.compute(black_box(&[0.0, 0.0]))) }); diff --git a/src/neural_network/mod.rs b/src/neural_network/mod.rs index e51d00f..fb7ec70 100644 --- a/src/neural_network/mod.rs +++ b/src/neural_network/mod.rs @@ -4,5 +4,8 @@ mod connection_sigmoid; mod functions; mod neuron; mod nn; +pub mod nn_trait; pub use nn::*; +#[cfg(feature = "snn")] +pub mod spiking; diff --git a/src/neural_network/nn.rs b/src/neural_network/nn.rs index 3c16b22..98cfe1f 100644 --- a/src/neural_network/nn.rs +++ b/src/neural_network/nn.rs @@ -8,6 +8,7 @@ use num::Float; use std::fmt::Display; use super::connection_relu::ConnectionRelu; +use super::nn_trait::NN; #[derive(Debug)] pub struct NeuralNetwork @@ -22,17 +23,11 @@ where unsafe impl Send for NeuralNetwork where T: Float + std::ops::AddAssign + Display + Send {} unsafe impl Sync for NeuralNetwork where T: Float + std::ops::AddAssign + Display + Send {} -impl NeuralNetwork +impl NN for NeuralNetwork where T: Float + std::ops::AddAssign + Display + Send, { - /// Instantiates a new Neural Network from a `Topology` - /// - /// # Safety - /// - /// If the Topology is ill-formed, it will result in pointer overflow. - /// Topologies generated by this crate are guaranteed to be safe. - pub unsafe fn new(topology: &Topology) -> NeuralNetwork { + unsafe fn from_topology(topology: &Topology) -> NeuralNetwork { let layer_count = topology.layers_sizes.len(); let sizes = &topology.layers_sizes; let mut layer_addresses = vec![0; layer_count]; @@ -113,7 +108,12 @@ where net.reset_neurons_value(); net } +} +impl NeuralNetwork +where + T: Float + std::ops::AddAssign + Display + Send, +{ #[inline] fn reset_neurons_value(&mut self) { for (neuron, bias) in self.neurons.iter_mut().zip(self.biases.iter()) { @@ -151,11 +151,6 @@ where neuron.reset_state(); } } - - pub fn from_string(serialized: &str) -> NeuralNetwork { - let top = Topology::from_string(serialized); - unsafe { NeuralNetwork::new(&top) } - } } impl PartialEq for NeuralNetwork diff --git a/src/neural_network/nn_trait.rs b/src/neural_network/nn_trait.rs new file mode 100644 index 0000000..d9705ac --- /dev/null +++ b/src/neural_network/nn_trait.rs @@ -0,0 +1,24 @@ +use std::fmt::Display; + +use num::Float; + +use crate::topology::Topology; + +pub trait NN: Sized +where + T: Float + std::ops::AddAssign + Display + Send, +{ + /// Instantiates a new Neural Network from a `Topology` + /// + /// # Safety + /// + /// If the Topology is ill-formed, it will result in pointer overflow. + /// Topologies generated by this crate are guaranteed to be safe. + unsafe fn from_topology(topology: &Topology) -> Self; + + /// Deserializes a serde serialized Topolgy into a neural network + fn from_string(serialized: &str) -> Self { + let top = Topology::from_string(serialized); + unsafe { Self::from_topology(&top) } + } +} diff --git a/src/neural_network/spiking/mod.rs b/src/neural_network/spiking/mod.rs new file mode 100644 index 0000000..c4af91a --- /dev/null +++ b/src/neural_network/spiking/mod.rs @@ -0,0 +1,4 @@ +mod spiking_neuron; +mod spiking_nn; + +pub use spiking_nn::*; diff --git a/src/neural_network/spiking/spiking_neuron.rs b/src/neural_network/spiking/spiking_neuron.rs new file mode 100644 index 0000000..3a68ef4 --- /dev/null +++ b/src/neural_network/spiking/spiking_neuron.rs @@ -0,0 +1,73 @@ +use futures::future::join_all; +use num::Float; +use std::fmt::Display; +use std::time::Duration; +use tokio::sync::mpsc::{channel, Receiver, Sender}; + +pub struct SpikingNeuron +where + T: Float + std::ops::AddAssign + Display + Send + Sync + 'static, +{ + /// Genetic can have threshold one > threshold 2 or the opposite + pub threshold_one: T, + pub threshold_two: T, + pub decay: T, + pub input: Receiver, + pub outputs: Vec<(T, Sender)>, +} + +impl SpikingNeuron +where + T: Float + std::ops::AddAssign + Display + Send + Sync + 'static, +{ + pub fn new() -> (Self, Sender) { + let (sdr, rcv) = channel(1); + ( + Self { + threshold_one: T::zero(), + threshold_two: T::zero(), + decay: T::zero(), + input: rcv, + outputs: Vec::new(), + }, + sdr, + ) + } + + pub fn spawn_task(self) { + let Self { + threshold_one, + threshold_two, + decay, + mut input, + mut outputs, + } = self; + tokio::spawn(async move { + let threshold_up = threshold_one.max(threshold_two); + let threshold_down = threshold_one.min(threshold_two); + let mid_point = (threshold_up - threshold_down) / T::from(2.0).unwrap(); + let mut activation: T = mid_point; + loop { + tokio::select! { + input = input.recv() => { + if let Some(input) = input { + activation += input; + if activation >= threshold_up || activation <= threshold_down { + let futures: Vec<_> = outputs.iter_mut().map( |(weight, sdr)| { + sdr.send(*weight * activation) + }).collect(); + join_all(futures).await; + activation = mid_point; + } + } else { + break; + } + } + _ = tokio::time::sleep(Duration::from_secs(1)) => { + activation = activation - (activation - mid_point) * decay; + } + } + } + }); + } +} diff --git a/src/neural_network/spiking/spiking_nn.rs b/src/neural_network/spiking/spiking_nn.rs new file mode 100644 index 0000000..0931a01 --- /dev/null +++ b/src/neural_network/spiking/spiking_nn.rs @@ -0,0 +1,157 @@ +use futures::future::select_all; +use num::Float; +use std::fmt::Display; +use tokio::sync::mpsc::{channel, Receiver, Sender}; + +use crate::{ + neural_network::nn_trait::NN, + topology::{bias::Bias, Topology}, +}; + +use super::spiking_neuron::SpikingNeuron; + +/// A spiking Neural Network +/// +/// Initialize with `from_topology` +/// +/// You can send input witht the `send` function and subscribe asynchronously to output using `recv` +pub struct SpikingNeuralNetwork +where + T: Float + std::ops::AddAssign + Display + Send + Sync + 'static, +{ + input_channels: Vec>, + output_channels: Vec>, +} + +impl NN for SpikingNeuralNetwork +where + T: Float + std::ops::AddAssign + Display + Send + Sync + 'static, +{ + /// Instantiates a new Neural Network from a `Topology` + /// + /// # Safety + /// + /// If the Topology is ill-formed, it will result in pointer overflow. + /// Topologies generated by this crate are guaranteed to be safe. + unsafe fn from_topology(topology: &Topology) -> SpikingNeuralNetwork { + let layer_count = topology.layers_sizes.len(); + let sizes = &topology.layers_sizes; + let mut layer_addresses = vec![0; layer_count]; + let mut neurons_count: usize = 0; + for i in 0..layer_count { + layer_addresses[i] = neurons_count; + neurons_count += sizes[i] as usize; + } + let output_size = *sizes.last().unwrap() as usize; + let mut neurons: Vec> = Vec::with_capacity(neurons_count); + let mut senders: Vec> = Vec::with_capacity(neurons_count); + let mut biases: Vec> = Vec::with_capacity(neurons_count); + for _ in 0..neurons_count { + let (neuron, sender) = SpikingNeuron::new(); + neurons.push(neuron); + senders.push(sender); + } + for _ in 0..neurons_count { + biases.push(Bias::new_zero()); + } + + for (point, gene_and_bias) in topology.genes_point.iter() { + if gene_and_bias.genes.is_empty() + || gene_and_bias + .genes + .iter() + .all(|gene| gene.borrow().disabled) + { + continue; + } + let neuron_index = layer_addresses[point.layer as usize] + point.index as usize; + let input_neuron = &mut neurons[neuron_index]; + biases[neuron_index] = gene_and_bias.bias.clone(); + for gene_rc in &gene_and_bias.genes { + let gene = gene_rc.borrow(); + if gene.disabled { + continue; + } + let output = &gene.output; + let index = layer_addresses[output.layer as usize] + output.index as usize; + let output_channel = &senders[index]; + input_neuron + .outputs + .push((gene.input_weight, output_channel.clone())); + } + } + + let base = output_size as isize - neurons_count as isize; + + for it in (neurons_count - output_size) as isize..neurons_count as isize { + biases[it as usize] = topology.output_bias[(it + base) as usize].clone(); + } + + neurons + .iter_mut() + .zip(biases.iter()) + .for_each(|(neuron, bias)| { + neuron.threshold_one = bias.bias_input; + neuron.threshold_two = bias.bias_reset; + neuron.decay = bias.bias_update; + }); + + let input_size = *sizes.first().unwrap() as usize; + let input_channels = senders.into_iter().take(input_size).collect(); + + let output_channels: Vec> = neurons + .iter_mut() + .skip(neurons_count - output_size) + .map(|neuron| { + let (sdr, rcv) = channel(1); + neuron.outputs.push((T::one(), sdr)); + rcv + }) + .collect(); + + let net = SpikingNeuralNetwork { + input_channels, + output_channels, + }; + + for neuron in neurons { + neuron.spawn_task(); + } + + net + } +} + +impl SpikingNeuralNetwork +where + T: Float + std::ops::AddAssign + Display + Send + Sync + 'static, +{ + /// Sends input to the neural network + /// + /// Activates one neuron at a time: + /// + /// # Params + /// + /// `index`: Index of the neuron to activate + /// + /// `value`: Activation value for the neuron + pub async fn send(&mut self, index: usize, value: T) -> Option<()> { + self.input_channels[index].send(value).await.ok() + } + + /// Asynchronously subscribe to input from SNN + /// + /// Spiking Neural Networks produce output asynchronously, therefore we have to listen to their output to get actions to execute + /// + /// # Returns an option containing a tuple of `(index, value)`, if the option is None it means the Neural Network is shut down + pub async fn recv(&mut self) -> Option<(usize, T)> { + let futures: Vec<_> = self + .output_channels + .iter_mut() + .map(|rcv| rcv.recv()) + .map(Box::pin) + .collect(); + let (value, index, _) = select_all(futures).await; + value.map(|v| (index, v)) + } +} diff --git a/src/tests.rs b/src/tests.rs index 9eb0b41..e886994 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -21,7 +21,7 @@ pub fn test_import_network() { .expect("Something went wrong reading the topology_test.json"); let top = Topology::from_string(&serialized); - let cloned: NeuralNetwork = unsafe { NeuralNetwork::new(&top) }; + let cloned: NeuralNetwork = unsafe { NeuralNetwork::from_topology(&top) }; let mut net = NeuralNetwork::from_string(&serialized); assert_eq!(net, cloned); @@ -109,9 +109,10 @@ impl Game for TestGame { assert_eq!(*top, top_cp); let as_str = top.to_string(); - let network = unsafe { NeuralNetwork::new(top) }; + let network = unsafe { NeuralNetwork::from_topology(top) }; let top2 = Topology::from_string(&*as_str); - let network_from_string: NeuralNetwork = unsafe { NeuralNetwork::new(&top2) }; + let network_from_string: NeuralNetwork = + unsafe { NeuralNetwork::from_topology(&top2) }; if network != network_from_string { println!("{:?}, {:?}", top.layers_sizes, top2.layers_sizes); println!("{}", as_str); @@ -197,9 +198,10 @@ impl Game for MemoryCount { let top_cp = top.clone(); assert_eq!(*top, top_cp); let as_str = top.to_string(); - let network = unsafe { NeuralNetwork::new(top) }; + let network = unsafe { NeuralNetwork::from_topology(top) }; let top2 = Topology::from_string(&*as_str); - let network_from_string: NeuralNetwork = unsafe { NeuralNetwork::new(&top2) }; + let network_from_string: NeuralNetwork = + unsafe { NeuralNetwork::from_topology(&top2) }; if network != network_from_string { println!("{}", as_str); section!(); diff --git a/src/topology/bias_and_genes.rs b/src/topology/bias_and_genes.rs index 508bffc..8664baa 100644 --- a/src/topology/bias_and_genes.rs +++ b/src/topology/bias_and_genes.rs @@ -1,8 +1,13 @@ +use std::{cell::RefCell, rc::Rc}; + use crate::topology::bias::Bias; -use crate::topology::GeneSmrtPtr; use num::Float; use serde::{Deserialize, Serialize}; +use super::gene::Gene; + +pub type GeneSmrtPtr = Rc>>; + #[derive(Clone, Deserialize, Serialize, Debug)] pub struct BiasAndGenes where diff --git a/src/topology/topology_struct.rs b/src/topology/topology_struct.rs index 84b105d..65a08d5 100644 --- a/src/topology/topology_struct.rs +++ b/src/topology/topology_struct.rs @@ -16,9 +16,9 @@ use std::fmt::{Display, Formatter}; use std::rc::Rc; use std::sync::{Arc, Mutex}; -const NORMAL_STDDEV: f64 = 0.04; +use super::bias_and_genes::GeneSmrtPtr; -pub type GeneSmrtPtr = Rc>>; +const NORMAL_STDDEV: f64 = 0.04; #[derive(Deserialize, Serialize, Debug)] pub struct Topology diff --git a/src/train/mod.rs b/src/train/mod.rs index d597df2..ec502e1 100644 --- a/src/train/mod.rs +++ b/src/train/mod.rs @@ -1,6 +1,12 @@ +#[cfg(feature = "snn")] +use crate::neural_network::spiking::SpikingNeuralNetwork; +use crate::neural_network::NeuralNetwork; + pub mod error; pub mod evolution_number; mod species; mod training; -pub use training::*; +pub type Train<'a, T, F> = training::Train<'a, T, F, NeuralNetwork>; +#[cfg(feature = "snn")] +pub type TrainSnn<'a, T, F> = training::Train<'a, T, F, SpikingNeuralNetwork>; diff --git a/src/train/training.rs b/src/train/training.rs index ff410fd..5d84f3e 100644 --- a/src/train/training.rs +++ b/src/train/training.rs @@ -1,6 +1,7 @@ use crate::game::{Game, GameAsync}; #[cfg(target_arch = "wasm32")] use crate::instant_wasm_replacement::Instant; +use crate::neural_network::nn_trait::NN; use crate::neural_network::NeuralNetwork; use crate::section; use crate::topology::mutation_probabilities::MutationProbabilities; @@ -44,14 +45,15 @@ macro_rules! cond_iter_mut { }}; } -pub type TrainAccessCallback<'a, T, F> = Box)>; +pub type TrainAccessCallback<'a, T, F, N> = Box)>; /// The train struct is used to train a Neural Network on a simulation with the NEAT algorithm -pub struct Train<'a, T, F> +pub struct Train<'a, T, F, N> where F: 'a + Float + Sum + Display + std::ops::AddAssign + std::ops::SubAssign + Send + Sync, T: Game, &'a [F]: rayon::iter::IntoParallelIterator, + N: NN, { pub simulation: &'a mut T, iterations_: usize, @@ -72,14 +74,15 @@ where best_historical_score: F, no_progress_counter: usize, proba: MutationProbabilities, - access_train_object_fn: Option>, + access_train_object_fn: Option>, } -impl<'a, T, F> Train<'a, T, F> +impl<'a, T, F, N> Train<'a, T, F, N> where T: Game, F: 'a + Float + Sum + Display + std::ops::AddAssign + std::ops::SubAssign + Send + Sync, &'a [F]: rayon::iter::IntoParallelIterator, + N: NN, { /// Creates a Train instance /// @@ -134,7 +137,7 @@ where /// } /// ``` #[inline] - pub fn new(simulation: &'a mut T) -> Train<'a, T, F> { + pub fn new(simulation: &'a mut T) -> Train<'a, T, F, N> { let iterations_: usize = 1000; let max_individuals_: usize = 100; let inputs_ = None; @@ -314,7 +317,7 @@ where #[inline] pub fn access_train_object( &mut self, - callback: Box)>, + callback: Box)>, ) -> &mut Self { self.access_train_object_fn = Some(callback); self @@ -392,7 +395,7 @@ where .map(|top_rc| { let lock = top_rc.lock().unwrap(); let top = &*lock; - unsafe { NeuralNetwork::new(top) } + unsafe { NeuralNetwork::from_topology(top) } }) .collect(); println!( @@ -407,9 +410,7 @@ where cond_iter_mut!(self.topologies_) .zip(cond_iter!(results)) .for_each(|(topology, result)| { - if result.is_nan() { - panic!("NaN result"); - } + assert!(!result.is_nan(), "NaN result"); topology.lock().unwrap().set_last_result(*result); }) } @@ -647,11 +648,12 @@ where } } -impl<'a, T, F> Train<'a, T, F> +impl<'a, T, F, N> Train<'a, T, F, N> where T: GameAsync, F: 'a + Float + Sum + Display + std::ops::AddAssign + std::ops::SubAssign + Send + Sync, &'a [F]: rayon::iter::IntoParallelIterator, + N: NN, { pub async fn start_async(&mut self) -> Result<(), TrainingError> { let inputs = self.inputs_.ok_or(TrainingError::NoInput)?; From 4a7383658a72f1af5e69b446565a5e4db5270b9e Mon Sep 17 00:00:00 2001 From: Bene <37740907+Nereuxofficial@users.noreply.github.com> Date: Sun, 9 Jan 2022 15:03:44 +0100 Subject: [PATCH 2/4] Added imports in order for the tests to work --- benches/benchmark.rs | 1 + src/neural_network/mod.rs | 2 +- src/tests.rs | 1 + 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/benches/benchmark.rs b/benches/benchmark.rs index 9b30e59..4cda290 100644 --- a/benches/benchmark.rs +++ b/benches/benchmark.rs @@ -4,6 +4,7 @@ use neat_gru::neural_network::nn::NeuralNetwork; use neat_gru::topology::Topology; use std::fs::File; use std::io::Read; +use neat_gru::neural_network::nn_trait::NN; fn benchmark(c: &mut Criterion) { let mut file = File::open("snakes_benchmark.json").expect("Can't open snakes_benchmark.json"); diff --git a/src/neural_network/mod.rs b/src/neural_network/mod.rs index fb7ec70..264144b 100644 --- a/src/neural_network/mod.rs +++ b/src/neural_network/mod.rs @@ -3,7 +3,7 @@ mod connection_relu; mod connection_sigmoid; mod functions; mod neuron; -mod nn; +pub mod nn; pub mod nn_trait; pub use nn::*; diff --git a/src/tests.rs b/src/tests.rs index e886994..f01ce76 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -5,6 +5,7 @@ use crate::train::Train; use crate::{game::Game, section}; use rand::{thread_rng, Rng}; use std::fs; +use crate::neural_network::nn_trait::NN; macro_rules! check_output { ($output: expr, $as_str: expr, $index: expr) => { From 13cd23a20c3411b7e1fc10adbe02ff39bd155204 Mon Sep 17 00:00:00 2001 From: Bene <37740907+Nereuxofficial@users.noreply.github.com> Date: Sun, 9 Jan 2022 15:26:50 +0100 Subject: [PATCH 3/4] Fixed a typo --- src/neural_network/spiking/spiking_nn.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/neural_network/spiking/spiking_nn.rs b/src/neural_network/spiking/spiking_nn.rs index 0931a01..99a4617 100644 --- a/src/neural_network/spiking/spiking_nn.rs +++ b/src/neural_network/spiking/spiking_nn.rs @@ -14,7 +14,7 @@ use super::spiking_neuron::SpikingNeuron; /// /// Initialize with `from_topology` /// -/// You can send input witht the `send` function and subscribe asynchronously to output using `recv` +/// You can send input with the `send` function and subscribe asynchronously to output using `recv` pub struct SpikingNeuralNetwork where T: Float + std::ops::AddAssign + Display + Send + Sync + 'static, From 9e6151c997b7be63abc51e38b183d09861d7ef77 Mon Sep 17 00:00:00 2001 From: Bene <37740907+Nereuxofficial@users.noreply.github.com> Date: Sun, 9 Jan 2022 16:47:02 +0100 Subject: [PATCH 4/4] Updated the README, version bump --- Cargo.lock | 2 +- Cargo.toml | 2 +- README.md | 13 +++++++------ 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9fa538e..7b59eb7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -379,7 +379,7 @@ dependencies = [ [[package]] name = "neat-gru" -version = "1.1.0" +version = "1.1.1" dependencies = [ "async-trait", "criterion", diff --git a/Cargo.toml b/Cargo.toml index 626a9b4..a83e83c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "neat-gru" -version = "1.1.0" +version = "1.1.1" authors = ["sakex "] edition = "2018" description = "NEAT algorithm with GRU gates" diff --git a/README.md b/README.md index 8b77acc..0b15403 100644 --- a/README.md +++ b/README.md @@ -10,20 +10,21 @@ ## Examples [XOR](examples/example.rs) -[Snake](examples/snake-cli) - - -Right now this is the only working example. You can run it via: -``` +```bash cargo run --example example ``` +[Snake](examples/snake-cli) + +```bash +cargo run --example snake-cli +``` ## How to use In `Cargo.toml`: ``` [dependencies] -neat-gru = 1.0.0" +neat-gru = 1.1.0" ``` Create a struct that implements the `Game` trait ```rust