diff --git a/crates/circuits/src/acc_blake3.rs b/crates/circuits/src/acc_blake3.rs new file mode 100644 index 000000000..f833a2861 --- /dev/null +++ b/crates/circuits/src/acc_blake3.rs @@ -0,0 +1,221 @@ +use binius_core::oracle::{OracleId, ShiftVariant}; +use binius_field::{ + as_packed_field::PackScalar, underlier::UnderlierType, BinaryField1b, TowerField, +}; +use bytemuck::Pod; + +use crate::{arithmetic, arithmetic::Flags, builder::ConstraintSystemBuilder}; + +type F1 = BinaryField1b; + +// Gadget that performs two u32 variables XOR and then rotates the result +fn xor_rotate_right( + builder: &mut ConstraintSystemBuilder, + name: impl ToString, + log_size: usize, + a: OracleId, + b: OracleId, + rotate_right_offset: u32, +) -> OracleId +where + U: PackScalar + PackScalar + Pod, + F: TowerField, +{ + assert!(rotate_right_offset <= 32); + + builder.push_namespace(name); + + let xor = builder + .add_linear_combination("xor", log_size, [(a, F::ONE), (b, F::ONE)]) + .unwrap(); + + let rotate = builder + .add_shifted( + "rotate", + xor, + 32 - rotate_right_offset as usize, + crate::sha256::LOG_U32_BITS, + ShiftVariant::CircularLeft, + ) + .unwrap(); + + if let Some(witness) = builder.witness() { + let a_value = witness.get::(a).unwrap().as_slice::(); + let b_value = witness.get::(b).unwrap().as_slice::(); + + let mut xor_witness = witness.new_column::(xor); + let xor_value = xor_witness.as_mut_slice::(); + + for (idx, v) in xor_value.iter_mut().enumerate() { + *v = a_value[idx] ^ b_value[idx]; + } + + let mut rotate_witness = witness.new_column::(rotate); + let rotate_value = rotate_witness.as_mut_slice::(); + for (idx, v) in rotate_value.iter_mut().enumerate() { + *v = xor_value[idx].rotate_right(rotate_right_offset); + } + } + + builder.pop_namespace(); + + rotate +} + +#[allow(clippy::too_many_arguments)] +pub fn blake3_g( + builder: &mut ConstraintSystemBuilder, + a_in: OracleId, + b_in: OracleId, + c_in: OracleId, + d_in: OracleId, + mx: OracleId, + my: OracleId, + log_size: usize, +) -> Result<[OracleId; 4], anyhow::Error> +where + U: UnderlierType + Pod + PackScalar + PackScalar, + F: TowerField, +{ + builder.push_namespace("blake3_g"); + + let a1 = arithmetic::u32::add3(builder, "a_in + b_in + mx", a_in, b_in, mx, Flags::Unchecked)?; + + let d1 = xor_rotate_right(builder, "(d_in ^ a1).rotate_right(16)", log_size, d_in, a1, 16u32); + + let c1 = arithmetic::u32::add(builder, "c_in + d1", c_in, d1, Flags::Unchecked)?; + + let b1 = xor_rotate_right(builder, "(b_in ^ c1).rotate_right(12)", log_size, b_in, c1, 12u32); + + let a2 = + arithmetic::u32::add3(builder, "a1 + b1 + my_in", a1, b1, my, Flags::Unchecked).unwrap(); + + let d2 = xor_rotate_right(builder, "(d1 ^ a2).rotate_right(8)", log_size, d1, a2, 8u32); + + let c2 = arithmetic::u32::add(builder, "c1 + d2", c1, d2, Flags::Unchecked)?; + + let b2 = xor_rotate_right(builder, "(b1 ^ c2).rotate_right(7)", log_size, b1, c2, 7u32); + + builder.pop_namespace(); + + Ok([a2, b2, c2, d2]) +} + +#[cfg(test)] +mod tests { + use binius_core::constraint_system::validate::validate_witness; + use binius_field::{arch::OptimalUnderlier, BinaryField128b, BinaryField1b}; + use binius_maybe_rayon::prelude::*; + + use crate::{ + acc_blake3::blake3_g, + builder::ConstraintSystemBuilder, + unconstrained::{unconstrained, variables_u32}, + }; + + type U = OptimalUnderlier; + type F128 = BinaryField128b; + type F1 = BinaryField1b; + + // The Blake3 mixing function, G, which mixes either a column or a diagonal. + // https://github.com/BLAKE3-team/BLAKE3/blob/master/reference_impl/reference_impl.rs + const fn g( + a_in: u32, + b_in: u32, + c_in: u32, + d_in: u32, + mx: u32, + my: u32, + ) -> (u32, u32, u32, u32) { + let a1 = a_in.wrapping_add(b_in).wrapping_add(mx); + let d1 = (d_in ^ a1).rotate_right(16); + let c1 = c_in.wrapping_add(d1); + let b1 = (b_in ^ c1).rotate_right(12); + + let a2 = a1.wrapping_add(b1).wrapping_add(my); + let d2 = (d1 ^ a2).rotate_right(8); + let c2 = c1.wrapping_add(d2); + let b2 = (b1 ^ c2).rotate_right(7); + + (a2, b2, c2, d2) + } + + #[test] + fn test_vector() { + // Let's use some fixed data input to check that our in-circuit computation + // produces same output as out-of-circuit one + let a = 0xaaaaaaaau32; + let b = 0xbbbbbbbbu32; + let c = 0xccccccccu32; + let d = 0xddddddddu32; + let mx = 0xffff00ffu32; + let my = 0xff00ffffu32; + + let (expected_0, expected_1, expected_2, expected_3) = g(a, b, c, d, mx, my); + + let log_size = 8usize; + let size = 1 << log_size; + + let allocator = bumpalo::Bump::new(); + let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + + let a_in = + variables_u32::(&mut builder, "a", log_size, vec![a; size]).unwrap(); + let b_in = + variables_u32::(&mut builder, "b", log_size, vec![b; size]).unwrap(); + let c_in = + variables_u32::(&mut builder, "c", log_size, vec![c; size]).unwrap(); + let d_in = + variables_u32::(&mut builder, "d", log_size, vec![d; size]).unwrap(); + let mx_in = + variables_u32::(&mut builder, "mx", log_size, vec![mx; size]).unwrap(); + let my_in = + variables_u32::(&mut builder, "my", log_size, vec![my; size]).unwrap(); + + let output = + blake3_g(&mut builder, a_in, b_in, c_in, d_in, mx_in, my_in, log_size).unwrap(); + + if let Some(witness) = builder.witness() { + ( + witness.get::(output[0]).unwrap().as_slice::(), + witness.get::(output[1]).unwrap().as_slice::(), + witness.get::(output[2]).unwrap().as_slice::(), + witness.get::(output[3]).unwrap().as_slice::(), + ) + .into_par_iter() + .for_each(|(actual_0, actual_1, actual_2, actual_3)| { + assert_eq!(*actual_0, expected_0); + assert_eq!(*actual_1, expected_1); + assert_eq!(*actual_2, expected_2); + assert_eq!(*actual_3, expected_3); + }); + } + + let witness = builder.take_witness().unwrap(); + let constraints_system = builder.build().unwrap(); + + validate_witness(&constraints_system, &[], &witness).unwrap(); + } + + #[test] + fn test_random_input() { + let log_size = 8usize; + + let allocator = bumpalo::Bump::new(); + let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + + let a_in = unconstrained::(&mut builder, "a", log_size).unwrap(); + let b_in = unconstrained::(&mut builder, "b", log_size).unwrap(); + let c_in = unconstrained::(&mut builder, "c", log_size).unwrap(); + let d_in = unconstrained::(&mut builder, "d", log_size).unwrap(); + let mx_in = unconstrained::(&mut builder, "mx", log_size).unwrap(); + let my_in = unconstrained::(&mut builder, "my", log_size).unwrap(); + + blake3_g(&mut builder, a_in, b_in, c_in, d_in, mx_in, my_in, log_size).unwrap(); + + let witness = builder.take_witness().unwrap(); + let constraints_system = builder.build().unwrap(); + + validate_witness(&constraints_system, &[], &witness).unwrap(); + } +} diff --git a/crates/circuits/src/arithmetic/u32.rs b/crates/circuits/src/arithmetic/u32.rs index 88a0c9a70..d428aa2d3 100644 --- a/crates/circuits/src/arithmetic/u32.rs +++ b/crates/circuits/src/arithmetic/u32.rs @@ -151,6 +151,69 @@ where Ok(zout) } +// Gadget that adds three u32 at once +pub fn add3( + builder: &mut ConstraintSystemBuilder, + name: impl ToString, + xin: OracleId, + yin: OracleId, + zin: OracleId, + flags: super::Flags, +) -> Result +where + U: PackScalar + PackScalar + Pod, + F: TowerField, +{ + builder.push_namespace(name); + let log_rows = builder.log_rows([xin, yin, zin])?; + let left = builder.add_linear_combination( + "left", + log_rows, + [(xin, F::ONE), (yin, F::ONE), (zin, F::ONE)], + )?; + let right = builder.add_committed("right", log_rows, BinaryField1b::TOWER_LEVEL); + + if let Some(witness) = builder.witness() { + let x_vals = witness.get::(xin)?.as_slice::(); + let y_vals = witness.get::(yin)?.as_slice::(); + let z_vals = witness.get::(zin)?.as_slice::(); + + let mut left_values = witness.new_column::(left); + let mut right_values = witness.new_column::(right); + + // In order to reduce our task to a simpler two integers addition (that we have gadget for) we use a trick from + // https://stackoverflow.com/questions/26228262/how-does-this-function-sum-3-integers-using-only-bit-wise-operators + (x_vals, y_vals, z_vals, left_values.as_mut_slice(), right_values.as_mut_slice()) + .into_par_iter() + .for_each(|(x, y, z, left, right)| { + *left = (*x ^ *y) ^ *z; + *right = (*x) & (*y) | (*x) & (*z) | (*y & *z); + }); + } + + // right << 1 + let right_shifted = shl(builder, "right_shifted", right, 1)?; + + builder.assert_zero( + "left", + [xin, yin, zin, left], + arith_expr!([x, y, z, left] = x + y + z - left).convert_field(), + ); + + // We apply following rule: a OR b = a XOR b XOR (a AND B) to the expression of 'right' column defined above. + builder.assert_zero( + "right", + [xin, yin, zin, right], + arith_expr!( + [x, y, z, right] = x * (y + z) + y * z * (1 + x * (1 + (y + z + x * y * z))) - right + ) + .convert_field(), + ); + + builder.pop_namespace(); + add(builder, "add3 -> add2", left, right_shifted, flags) +} + pub fn sub( builder: &mut ConstraintSystemBuilder, name: impl ToString, diff --git a/crates/circuits/src/lib.rs b/crates/circuits/src/lib.rs index c76857c67..197af0fc4 100644 --- a/crates/circuits/src/lib.rs +++ b/crates/circuits/src/lib.rs @@ -9,6 +9,7 @@ #![feature(array_try_map, array_try_from_fn)] #![allow(clippy::module_inception)] +pub mod acc_blake3; pub mod arithmetic; pub mod bitwise; pub mod builder; diff --git a/crates/circuits/src/sha256.rs b/crates/circuits/src/sha256.rs index d2e8fe6e9..8e47c3247 100644 --- a/crates/circuits/src/sha256.rs +++ b/crates/circuits/src/sha256.rs @@ -16,7 +16,7 @@ use itertools::izip; use crate::{arithmetic, builder::ConstraintSystemBuilder}; -const LOG_U32_BITS: usize = checked_log_2(32); +pub const LOG_U32_BITS: usize = checked_log_2(32); type B1 = BinaryField1b; diff --git a/crates/circuits/src/unconstrained.rs b/crates/circuits/src/unconstrained.rs index a798ec5f1..811ee9404 100644 --- a/crates/circuits/src/unconstrained.rs +++ b/crates/circuits/src/unconstrained.rs @@ -33,3 +33,30 @@ where Ok(rng) } + +pub fn variables_u32( + builder: &mut ConstraintSystemBuilder, + name: impl ToString, + log_size: usize, + value: Vec, +) -> Result +where + U: UnderlierType + Pod + PackScalar + PackScalar, + F: TowerField + ExtensionField, + FS: TowerField, +{ + let rng = builder.add_committed(name, log_size, FS::TOWER_LEVEL); + + if let Some(witness) = builder.witness() { + witness + .new_column::(rng) + .as_mut_slice::() + .into_par_iter() + .zip(value.into_par_iter()) + .for_each(|(data, value)| { + *data = value; + }); + } + + Ok(rng) +}