diff --git a/Cargo.toml b/Cargo.toml
index 85b4c93..86cb2a6 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -34,6 +34,10 @@ criterion = "0.3.5"
 name = "benchmark"
 harness = false
 
+[[bench]]
+name = "math_functions"
+harness = false
+
 [target.'cfg(target_arch = "wasm32")'.dependencies]
 getrandom = { version = "0.2.1", features = ["js"] }
 
diff --git a/benches/benchmark.rs b/benches/benchmark.rs
index ebe028f..deac3f3 100644
--- a/benches/benchmark.rs
+++ b/benches/benchmark.rs
@@ -15,5 +15,6 @@ fn benchmark(c: &mut Criterion) {
         b.iter(|| network.compute(black_box(&[0.0, 0.0])))
     });
 }
+
 criterion_group!(benches, benchmark);
 criterion_main!(benches);
diff --git a/benches/math_functions.rs b/benches/math_functions.rs
new file mode 100644
index 0000000..64e730b
--- /dev/null
+++ b/benches/math_functions.rs
@@ -0,0 +1,105 @@
+//! Contains benchmarks of the functions stored in neural_network::functions.
+use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
+use neat_gru::neural_network::functions::*;
+
+extern crate neat_gru;
+
+fn bench_sigmoid(c: &mut Criterion) {
+    let size: f32 = 0.3518392;
+    let mut group = c.benchmark_group("Sigmoid Function");
+    for size in [
+        size * 0.0,
+        size,
+        size * 2.0,
+        size * 4.0,
+        size * 6.0,
+        size * 8.0,
+        size * 10.0,
+        size * 12.0,
+        size * 14.0,
+    ]
+    .iter()
+    {
+        group.bench_with_input(BenchmarkId::from_parameter(size), size, |b, size| {
+            b.iter(|| fast_sigmoid(*size))
+        });
+    }
+    group.finish();
+}
+
+fn bench_tanh(c: &mut Criterion) {
+    let size: f32 = 0.3518392;
+    let mut group = c.benchmark_group("tanh Function");
+    for size in [
+        size * 0.0,
+        size,
+        size * 2.0,
+        size * 4.0,
+        size * 6.0,
+        size * 8.0,
+        size * 10.0,
+        size * 12.0,
+        size * 14.0,
+    ]
+    .iter()
+    {
+        group.bench_with_input(BenchmarkId::from_parameter(size), size, |b, size| {
+            b.iter(|| fast_tanh(*size))
+        });
+    }
+    group.finish();
+}
+
+fn bench_relu(c: &mut Criterion) {
+    let size: f32 = 0.3518392;
+    let mut group = c.benchmark_group("relu Function");
+    for size in [
+        size * 0.0,
+        size,
+        size * 2.0,
+        size * 4.0,
+        size * 6.0,
+        size * 8.0,
+        size * 10.0,
+        size * 12.0,
+        size * 14.0,
+    ]
+    .iter()
+    {
+        group.bench_with_input(BenchmarkId::from_parameter(size), size, |b, size| {
+            b.iter(|| re_lu(*size))
+        });
+    }
+    group.finish();
+}
+
+fn comparison(c: &mut Criterion) {
+    let size: f32 = 0.3518392;
+    let mut group = c.benchmark_group("relu vs sigmoid");
+    for size in [
+        size * 0.0,
+        size,
+        size * 2.0,
+        size * 4.0,
+        size * 6.0,
+        size * 8.0,
+        size * 10.0,
+        size * 12.0,
+        size * 14.0,
+    ]
+    .iter()
+    {
+        group.bench_with_input(BenchmarkId::new("Sigmoid", size), size,
+            |b, size| b.iter(|| fast_sigmoid(*size)));
+        group.bench_with_input(BenchmarkId::new("Relu", size), size,
+            |b, size| b.iter(|| re_lu(*size)));
+    }
+    group.finish();
+}
+
+criterion_group! {
+    name = benches;
+    config = Criterion::default();
+    targets = bench_tanh, bench_sigmoid, bench_relu, comparison
+}
+criterion_main!(benches);
diff --git a/src/neural_network/functions.rs b/src/neural_network/functions.rs
index 00aa9ba..77660ec 100644
--- a/src/neural_network/functions.rs
+++ b/src/neural_network/functions.rs
@@ -19,3 +19,8 @@ pub fn fast_tanh<T: Float>(x: T) -> T {
     let b = 135135 + x2 * (62370 + x2 * (3150 + x2 * 28));
     a / b
 }
+
+#[inline]
+pub fn re_lu<T: Float>(x: T) -> T {
+    x.max(T::zero())
+}
\ No newline at end of file
diff --git a/src/neural_network/mod.rs b/src/neural_network/mod.rs
index e51d00f..8e37ba2 100644
--- a/src/neural_network/mod.rs
+++ b/src/neural_network/mod.rs
@@ -1,7 +1,7 @@
 mod connection_gru;
 mod connection_relu;
 mod connection_sigmoid;
-mod functions;
+pub mod functions;
 mod neuron;
 mod nn;
 
diff --git a/src/neural_network/neuron.rs b/src/neural_network/neuron.rs
index 3b37ebf..64480f7 100644
--- a/src/neural_network/neuron.rs
+++ b/src/neural_network/neuron.rs
@@ -1,7 +1,7 @@
 use crate::neural_network::connection_gru::ConnectionGru;
 use crate::neural_network::connection_relu::ConnectionRelu;
 use crate::neural_network::connection_sigmoid::ConnectionSigmoid;
-use crate::neural_network::functions::{fast_sigmoid, fast_tanh};
+use crate::neural_network::functions::{fast_sigmoid, fast_tanh, re_lu};
 use crate::topology::bias::Bias;
 use crate::utils::floats_almost_equal;
 use num::Float;
@@ -95,8 +95,8 @@ where
     #[replace_numeric_literals(T::from(literal).unwrap())]
     #[inline]
     pub fn get_value(&mut self) -> T {
-        let update_gate = fast_sigmoid(self.update);
-        let reset_gate = fast_sigmoid(self.reset);
+        let update_gate = re_lu(self.update);
+        let reset_gate = re_lu(self.reset);
         let current_memory = fast_tanh(self.input + self.memory * reset_gate);
 
         let value = update_gate * self.memory + (1 - update_gate) * current_memory;
@@ -107,8 +107,8 @@ where
     #[replace_numeric_literals(T::from(literal).unwrap())]
     #[inline]
     pub fn feed_forward(&mut self) {
-        let update_gate = fast_sigmoid(self.update);
-        let reset_gate = fast_sigmoid(self.reset);
+        let update_gate = re_lu(self.update);
+        let reset_gate = re_lu(self.reset);
         let current_memory = self.input + self.memory * reset_gate;
         let value = update_gate * self.memory + (1 - update_gate) * current_memory;
         for connection in self.connections_gru.iter_mut() {