48 commits
2b43c07
Implemented GUI for snake
Nereuxofficial Aug 8, 2021
d115927
Merge remote-tracking branch 'origin/main' into snake
Nereuxofficial Aug 15, 2021
7de0a29
Removed GUI & Increased Generations & the number of snakes
Nereuxofficial Aug 18, 2021
73a0039
Snakes can no longer rotate forever
Nereuxofficial Aug 18, 2021
1e652d4
Removed unnecessary dependency
Nereuxofficial Aug 18, 2021
0d4e38e
Simplified Snake example
Nereuxofficial Aug 19, 2021
8332522
Added Benchmark library
Nereuxofficial Aug 20, 2021
86f38b3
Remove unused imports
Nereuxofficial Aug 20, 2021
00f0316
Cleanups
Nereuxofficial Aug 20, 2021
ee9861b
Derive Hash instead of implementing it
Nereuxofficial Aug 20, 2021
04a9a22
Cleanups
Nereuxofficial Aug 20, 2021
e559d99
Further Cleanups & Refactoring
Nereuxofficial Aug 21, 2021
5ceaf25
Update README.md
Nereuxofficial Aug 21, 2021
70d905c
Refactoring & Docs
Nereuxofficial Aug 21, 2021
81ef41c
Better function name
Nereuxofficial Aug 21, 2021
39c9bcd
Cleanups
Nereuxofficial Aug 22, 2021
a3f2428
Added comments
Nereuxofficial Aug 22, 2021
0e2c83d
Reworded comments
Nereuxofficial Aug 22, 2021
459c4a3
Removed unnecessary .gitignore
Nereuxofficial Aug 22, 2021
9f5555a
Fancier Badges
Nereuxofficial Aug 23, 2021
2b68ac6
Fixed accidental Ctrl+V
Nereuxofficial Aug 23, 2021
144bb77
Fixed accidental Paste
Nereuxofficial Aug 23, 2021
7208831
Split Github Workflows
Aug 25, 2021
b30052e
Removed unnecessary code
Nereuxofficial Aug 25, 2021
b7541d4
WIP: Fixing serialization
Nereuxofficial Aug 28, 2021
4946362
Revert changes to Topology::to_serde_string
Nereuxofficial Aug 28, 2021
f9bf105
Fixed Serialization and Benchmarks
Nereuxofficial Aug 28, 2021
d756682
Fixed Benchmark and included snakes.json
Nereuxofficial Aug 28, 2021
47f6733
Topology::to_string(&self) now uses to_string_pretty
Nereuxofficial Aug 28, 2021
b0f54b0
Removed duplicate functions & Created separate json for benchmarking
Nereuxofficial Aug 28, 2021
f4a9a2e
Restructured some Snake code
Nereuxofficial Aug 28, 2021
157395e
Snake Example simplified and with par_iter
Nereuxofficial Aug 29, 2021
b041c17
Updated dependencies
Nereuxofficial Aug 29, 2021
7b8e69a
Refactoring
Nereuxofficial Aug 29, 2021
72252be
Refactoring
Nereuxofficial Aug 29, 2021
64d8cf8
Added wasm32 build to Tests
Nereuxofficial Aug 29, 2021
0ef3e2c
Added wasm32 build to Tests
Nereuxofficial Aug 29, 2021
89114e9
Fix Github Actions wasm32 build
Nereuxofficial Aug 29, 2021
f834f57
Added benchmarks for math functions.
Sep 11, 2021
9d44b1a
Merge remote-tracking branch 'sakex/main'
Nereuxofficial Sep 12, 2021
cf0ab0c
WIP: Trying out the relu function
Nereuxofficial Oct 3, 2021
db104fd
Extended benchmarks
Nereuxofficial Oct 3, 2021
8c4f130
Merge remote-tracking branch 'upstream/main'
Nereuxofficial Jan 9, 2022
65d9744
Added relu as a feature to the crate
Nereuxofficial Nov 22, 2022
9431476
Merge branch 'main' into feature/relu
Nereuxofficial Nov 22, 2022
04c52e8
Fixed merge conflicts
Nereuxofficial Nov 22, 2022
6625c8e
PR Changes
Nereuxofficial Dec 1, 2022
a2fe01e
Fixed build errors
Nereuxofficial Dec 11, 2022
405 changes: 187 additions & 218 deletions Cargo.lock

Large diffs are not rendered by default.

23 changes: 16 additions & 7 deletions Cargo.toml
@@ -9,32 +9,41 @@ repository = "https://github.com/sakex/neat-gru-rust"
categories = ["science", "wasm"]
keywords = ["neat", "ai", "machine-learning", "genetic", "algorithm"]

[features]
default = ["tanh"]
relu = []
sigmoid = []
tanh = []

[lib]
crate-type = ["cdylib", "rlib"]

[dependencies]
serde = { version = "1.0.130", features = ["derive", "rc"] }
serde_json = "1.0.67"
serde = { version = "1.0.143", features = ["derive", "rc"] }
serde_json = "1.0.83"
num = "0.4.0"
rand = "0.8.4"
rand_distr = "0.4.1"
rand = "0.8.5"
rand_distr = "0.4.3"
numeric_literals = "0.2.0"
rayon = "1.5.1"
rayon = "1.5.3"
itertools = "0.10.1"
async-trait = "0.1.51"
log = "0.4.17"
tempdir = "0.3.7"

[dev-dependencies]
criterion = "0.3.5"
criterion = "0.4.0"

[[bench]]
name = "benchmark"
harness = false

[[bench]]
name="math_functions"
harness=false

[target.'cfg(target_arch = "wasm32")'.dependencies]
getrandom = { version = "0.2.1", features = ["js"] }
getrandom = { version = "0.2.7", features = ["js"] }

[profile.release]
# Link time optimisation, possibly even with C++, equivalent to G++'s -flto
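Summary of the manifest changes: the new `[features]` table makes the activation function a compile-time choice, with `tanh` enabled by default; dependency versions are bumped (including Criterion 0.3 → 0.4); and a second `[[bench]]` target registers the new math-function benchmarks, with `harness = false` so that Criterion can supply its own `main` instead of the built-in libtest harness.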
3 changes: 2 additions & 1 deletion benches/benchmark.rs
@@ -1,6 +1,6 @@
use criterion::{black_box, criterion_group, criterion_main, Criterion};
extern crate neat_gru;
use neat_gru::neural_network::NeuralNetwork;
use neat_gru::neural_network::nn::NeuralNetwork;
use neat_gru::topology::Topology;
use std::fs::File;
use std::io::Read;
@@ -15,5 +15,6 @@ fn benchmark(c: &mut Criterion) {
        b.iter(|| network.compute(black_box(&[0.0, 0.0])))
    });
}

criterion_group!(benches, benchmark);
criterion_main!(benches);
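With `pub use nn::*;` still re-exported from src/neural_network/mod.rs (below), the old `neat_gru::neural_network::NeuralNetwork` path would keep resolving; the bench simply switches to naming the now-public `nn` module explicitly.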
59 changes: 59 additions & 0 deletions benches/math_functions.rs
@@ -0,0 +1,59 @@
//! Contains benchmarks of the functions stored in neural_network::functions.
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
use neat_gru::neural_network::functions::*;

extern crate neat_gru;
fn input_data() -> impl IntoIterator<Item = f32> {
    let size: f32 = 0.3518392;
    (0..20).map(move |i| size * i as f32)
}

fn bench_sigmoid(c: &mut Criterion) {
    let mut group = c.benchmark_group("Sigmoid Function");
    input_data().into_iter().for_each(|s| {
        group.bench_with_input(BenchmarkId::from_parameter(s), &s, |b, size| {
            b.iter(|| fast_sigmoid(*size))
        });
    });
    group.finish();
}

fn bench_tanh(c: &mut Criterion) {
    let mut group = c.benchmark_group("Tanh Function");
    input_data().into_iter().for_each(|s| {
        group.bench_with_input(BenchmarkId::from_parameter(s), &s, |b, s| {
            b.iter(|| fast_tanh(*s))
        });
    });
    group.finish();
}

fn bench_relu(c: &mut Criterion) {
    let mut group = c.benchmark_group("Relu Function");
    input_data().into_iter().for_each(|s| {
        group.bench_with_input(BenchmarkId::from_parameter(s), &s, |b, s| {
            b.iter(|| re_lu(*s))
        });
    });
    group.finish();
}

fn comparison(c: &mut Criterion) {
    let mut group = c.benchmark_group("Relu vs Sigmoid");
    input_data().into_iter().for_each(|s| {
        group.bench_with_input(BenchmarkId::new("Sigmoid", s), &s, |b, size| {
            b.iter(|| fast_sigmoid(*size))
        });
        group.bench_with_input(BenchmarkId::new("Relu", s), &s, |b, s| b.iter(|| re_lu(*s)));
        group.bench_with_input(BenchmarkId::new("Tanh", s), &s, |b, s| b.iter(|| fast_tanh(*s)));
    });
    group.finish();
}

criterion_group! {
    name = benches;
    config = Criterion::default();
    targets = bench_tanh, bench_sigmoid, bench_relu, comparison
}
criterion_main!(benches);
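Since both `[[bench]]` targets set `harness = false`, these run under Criterion via `cargo bench --bench math_functions`, reporting per-input timings for the 20 samples produced by `input_data`.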
12 changes: 6 additions & 6 deletions examples/example.rs
@@ -20,12 +20,12 @@ impl Player {
        let inputs = Xor::get_inputs();
        // Calculate a score for every input
        let outputs: Vec<f64> = inputs.iter().map(|i| self.net.compute(i)[0]).collect();
        let mut scores: Vec<f64> = vec![];
        for (input, output) in inputs.iter().zip(outputs.iter()) {
            scores.push(compute_score(input, *output));
        }
        // And return the sum of the scores
        scores.iter().sum()
        // Return the sum of the scores
        inputs
            .iter()
            .zip(outputs.iter())
            .map(|(input, output)| compute_score(input, *output))
            .sum()
    }
}

4 changes: 2 additions & 2 deletions src/neural_network/connection_gru.rs
@@ -1,4 +1,4 @@
use crate::neural_network::functions::fast_tanh;
use crate::neural_network::functions::activate;
use crate::neural_network::neuron::Neuron;
use crate::utils::floats_almost_equal;
use num::Float;
@@ -70,7 +70,7 @@ where
    #[inline]
    pub(crate) fn activate(&mut self, value: T) {
        let prev_reset = unsafe { (*self.output).prev_reset };
        self.memory = fast_tanh(
        self.memory = activate(
            self.prev_input * self.input_weight + self.memory_weight * prev_reset * self.memory,
        );
        self.prev_input = value;
Expand Down
25 changes: 23 additions & 2 deletions src/neural_network/functions.rs
@@ -1,14 +1,30 @@
use num::Float;
use numeric_literals::replace_numeric_literals;

#[inline(always)]
#[cfg(feature = "relu")]
pub fn activate<T: Float>(value: T) -> T {
    re_lu(value)
}
#[cfg(feature = "tanh")]
#[inline(always)]
pub fn activate<T: Float>(value: T) -> T {
    fast_tanh(value)
}
#[cfg(feature = "sigmoid")]
#[inline(always)]
pub fn activate<T: Float>(value: T) -> T {
    fast_sigmoid(value)
}

#[inline(always)]
#[replace_numeric_literals(T::from(literal).unwrap())]
#[inline]
pub fn fast_sigmoid<T: Float>(value: T) -> T {
    value / (1 + value.abs())
}

#[inline(always)]
#[replace_numeric_literals(T::from(literal).unwrap())]
#[inline]
pub fn fast_tanh<T: Float>(x: T) -> T {
    if x.abs() >= 4.97 {
        let values = [-1, 1];
@@ -19,3 +19,8 @@ pub fn fast_tanh<T: Float>(x: T) -> T {
    let b = 135135 + x2 * (62370 + x2 * (3150 + x2 * 28));
    a / b
}

#[inline(always)]
pub fn re_lu<T: Float>(x: T) -> T {
    x.max(T::zero())
}
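One caveat with this cfg-based dispatch: Cargo features are additive, so enabling two of them (e.g. `relu` on top of the default `tanh`) produces two `activate` definitions and a confusing "duplicate definitions" error; a consumer opting into `relu` has to depend on the crate with `default-features = false, features = ["relu"]`. A hedged sketch of a guard the crate could add (not part of this diff) to make that failure mode explicit:

```rust
// Hypothetical guard, not in this PR: turn the "duplicate definitions" error
// into a clear message when two activation features are enabled at once.
#[cfg(all(feature = "tanh", feature = "relu"))]
compile_error!("activation features are mutually exclusive: disable default features before enabling `relu`");

#[cfg(all(feature = "tanh", feature = "sigmoid"))]
compile_error!("activation features are mutually exclusive: disable default features before enabling `sigmoid`");

#[cfg(all(feature = "relu", feature = "sigmoid"))]
compile_error!("enable only one of `relu` and `sigmoid`");
```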
5 changes: 3 additions & 2 deletions src/neural_network/mod.rs
@@ -1,8 +1,9 @@
mod connection_gru;
mod connection_relu;
mod connection_sigmoid;
mod functions;
pub(crate) mod functions;
mod neuron;
mod nn;
pub mod nn;
pub mod nn_trait;

pub use nn::*;
14 changes: 7 additions & 7 deletions src/neural_network/neuron.rs
@@ -1,7 +1,7 @@
use crate::neural_network::connection_gru::ConnectionGru;
use crate::neural_network::connection_relu::ConnectionRelu;
use crate::neural_network::connection_sigmoid::ConnectionSigmoid;
use crate::neural_network::functions::{fast_sigmoid, fast_tanh};
use crate::neural_network::functions::activate;
use crate::topology::bias::Bias;
use crate::utils::floats_almost_equal;
use num::Float;
@@ -95,20 +95,20 @@ where
    #[replace_numeric_literals(T::from(literal).unwrap())]
    #[inline]
    pub fn get_value(&mut self) -> T {
        let update_gate = fast_sigmoid(self.update);
        let reset_gate = fast_sigmoid(self.reset);
        let current_memory = fast_tanh(self.input + self.memory * reset_gate);
[Inline review comment from the repository owner]: We should have options for tanh, sigmoid, and relu for all possible activations, as it will bias the output.

        let update_gate = activate(self.update);
        let reset_gate = activate(self.reset);
        let current_memory = activate(self.input + self.memory * reset_gate);
        let value = update_gate * self.memory + (1 - update_gate) * current_memory;

        self.prev_reset = reset_gate;
        fast_tanh(value)
        activate(value)
    }

    #[replace_numeric_literals(T::from(literal).unwrap())]
    #[inline]
    pub fn feed_forward(&mut self) {
        let update_gate = fast_sigmoid(self.update);
        let reset_gate = fast_sigmoid(self.reset);
        let update_gate = activate(self.update);
        let reset_gate = activate(self.reset);
        let current_memory = self.input + self.memory * reset_gate;
        let value = update_gate * self.memory + (1 - update_gate) * current_memory;
        for connection in self.connections_gru.iter_mut() {
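For readers tracing `get_value`: the neuron implements a GRU-style gating step. Writing $z$ for the update gate, $r$ for the reset gate, $m$ for the stored memory, $x$ for the summed input, and $f$ for the feature-selected `activate`, the code above computes:

$$z = f(u),\quad r = f(s),\quad \tilde{m} = f(x + r\,m),\quad v = z\,m + (1 - z)\,\tilde{m},\quad \text{output} = f(v)$$

Before this PR, $z$ and $r$ went through `fast_sigmoid` and $\tilde{m}$ through `fast_tanh`; routing all of them through a single `activate` is what the owner's review comment above is reacting to — with `relu` selected, the gates are no longer squashed into a bounded range, which changes the gating semantics.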
24 changes: 24 additions & 0 deletions src/neural_network/nn_trait.rs
@@ -0,0 +1,24 @@
use std::fmt::Display;

use num::Float;

use crate::topology::Topology;

pub trait NN<T>: Sized
where
    T: Float + std::ops::AddAssign + Display + Send,
{
    /// Instantiates a new Neural Network from a `Topology`
    ///
    /// # Safety
    ///
    /// If the Topology is ill-formed, it will result in pointer overflow.
    /// Topologies generated by this crate are guaranteed to be safe.
    unsafe fn from_topology(topology: &Topology<T>) -> Self;

    /// Deserializes a serde serialized Topology into a neural network
    fn from_string(serialized: &str) -> Self {
        let top = Topology::from_string(serialized);
        unsafe { Self::from_topology(&top) }
    }
}
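A hedged usage sketch of the new trait, assuming `NeuralNetwork` implements `NN` (the impl itself is not shown in this diff); the file name and the input vector are placeholders:

```rust
use neat_gru::neural_network::nn::NeuralNetwork;
use neat_gru::neural_network::nn_trait::NN;
use std::fs;

fn main() {
    // Placeholder path: any serde-serialized Topology string works here.
    let serialized = fs::read_to_string("topology.json").expect("could not read topology file");
    // from_string has a default impl: it parses the Topology, then calls the
    // unsafe from_topology constructor; topologies produced by this crate are
    // documented as safe to pass in.
    let mut net = <NeuralNetwork<f64> as NN<f64>>::from_string(&serialized);
    // compute is assumed to take an input slice and return one value per
    // output neuron, as in the XOR example above.
    let outputs = net.compute(&[0.0, 1.0]);
    println!("{:?}", outputs);
}
```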
6 changes: 2 additions & 4 deletions src/train/training.rs
@@ -781,10 +781,8 @@ where
            spec1
                .adjusted_fitness
                .partial_cmp(&spec2.adjusted_fitness)
                .expect(&*format!(
                    "First: {}, second: {}, variance {}",
                    spec1.adjusted_fitness, spec2.adjusted_fitness, variance
                ))
                .unwrap_or_else(|| panic!("First: {}, second: {}, variance {}",
                    spec1.adjusted_fitness, spec2.adjusted_fitness, variance))
        }
    });
}
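The change itself is the usual fix for clippy's `expect_fun_call` lint: `.expect(&*format!(...))` builds the panic message eagerly on every comparison, even when `partial_cmp` succeeds, whereas `unwrap_or_else(|| panic!(...))` defers the formatting to the failure path (a `None` here means a NaN fitness). A minimal standalone sketch of the pattern, with illustrative names not taken from the crate:

```rust
use std::cmp::Ordering;

// Illustrative function: compare two fitness values that might be NaN,
// formatting the panic message only when the comparison actually fails.
fn compare_fitness(a: f64, b: f64) -> Ordering {
    a.partial_cmp(&b)
        // The closure runs only when partial_cmp returns None (i.e. a NaN),
        // so no string is allocated on the happy path.
        .unwrap_or_else(|| panic!("uncomparable fitness values: {} vs {}", a, b))
}

fn main() {
    assert_eq!(compare_fitness(1.0, 2.0), Ordering::Less);
}
```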