diff --git a/.cargo/config.toml b/.cargo/config.toml index 7acd84b05..6287bb569 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -1,2 +1,5 @@ [build] rustdocflags = ["-Dwarnings", "--html-in-header", "doc/katex-header.html"] + +[alias] +fast_test = "test --tests" diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml deleted file mode 100644 index d2a03abf4..000000000 --- a/.gitlab-ci.yml +++ /dev/null @@ -1,203 +0,0 @@ -workflow: - rules: - - if: $CI_PIPELINE_SOURCE == "merge_request_event" - - if: $CI_COMMIT_BRANCH == 'main' - -variables: - CARGO_HOME: "$CI_PROJECT_DIR/toolchains/cargo" - RUSTUP_HOME: "$CI_PROJECT_DIR/toolchains" - GIT_CLEAN_FLAGS: "-ffdx --exclude toolchains" - FF_TIMESTAMPS: true - - -stages: - - lint - - build - - test - - deploy - -# AMD job configuration template -.job_template_amd: - image: rustlang/rust:nightly - variables: - KUBERNETES_NODE_SELECTOR_INSTANCE_TYPE: "ulvt-node-pool=ulvt-c7i-2xlarge" - KUBERNETES_CPU_REQUEST: "6" - KUBERNETES_MEMORY_REQUEST: "14Gi" - GIT_CLONE_PATH: "$CI_BUILDS_DIR/binius_amd" - tags: - - k8s - -# AMD job configuration template with stable Rust -.job_template_amd_stable: - extends: .test_job_template_amd - variables: - RUST_VERSION: "1.83.0" - before_script: - # workaround for https://github.com/rust-lang/rustup/issues/2886 - - rustup set auto-self-update disable - - rustup toolchain install $RUST_VERSION - -# ARM job configuration template -.job_template_arm: - image: rustlang/rust:nightly - variables: - KUBERNETES_NODE_SELECTOR_INSTANCE_TYPE: "ulvt-node-pool=ulvt-c8g-2xlarge" - KUBERNETES_NODE_SELECTOR_ARCH: 'kubernetes.io/arch=arm64' - KUBERNETES_CPU_REQUEST: "6" - KUBERNETES_MEMORY_REQUEST: "14Gi" - GIT_CLONE_PATH: "$CI_BUILDS_DIR/binius_arm" - before_script: - - if [ "$(uname -m)" != "aarch64" ]; then echo "This job is intended to run on ARM architecture only."; exit 1; fi - tags: - - k8s - -# Linting jobs -copyright-check: - extends: .job_template_amd - stage: lint - script: - - ./scripts/check_copyright_notice.sh - -cargofmt: - extends: .job_template_amd - stage: lint - script: - - cargo fmt --check - -clippy: - extends: .job_template_amd - stage: lint - script: - - cargo clippy --all --all-features --tests --benches --examples -- -D warnings - -# Building jobs - -# TODO: use a docker image with `wasm32-unknown-unknown` target preinstalled -build-debug-wasm: - extends: .job_template_amd - stage: build - script: - - rustup target add wasm32-unknown-unknown - - cargo build --package binius_field --target wasm32-unknown-unknown - artifacts: - paths: - - Cargo.lock - expire_in: 1 day - -build-debug-amd: - extends: .job_template_amd - stage: build - script: - - cargo build --tests --benches --examples - artifacts: - paths: - - Cargo.lock - expire_in: 1 day - -# Build without default features -# This checks if build without `rayon` feature works. -build-debug-amd-no-default-features: - extends: .job_template_amd - stage: build - script: - - cargo build --tests --benches --examples --no-default-features - artifacts: - paths: - - Cargo.lock - expire_in: 1 day - -build-debug-amd-stable: - extends: .job_template_amd_stable - stage: build - script: - - cargo +$RUST_VERSION build --tests --benches --examples -p binius_core --features stable_only - artifacts: - paths: - - Cargo.lock - expire_in: 1 day - -build-debug-arm: - extends: .job_template_arm - stage: build - script: - - cargo build --tests --benches --examples - artifacts: - paths: - - Cargo.lock - expire_in: 1 day - -.test_job_template_amd: - extends: .job_template_amd - dependencies: - - build-debug-amd - -.test_job_template_amd_stable: - extends: .job_template_amd_stable - dependencies: - - build-debug-amd-stable - -.test_job_template_arm: - extends: .job_template_arm - dependencies: - - build-debug-arm - -unit-test-amd-portable: - extends: .test_job_template_amd - script: - - RUSTFLAGS="-C target-cpu=generic" ./scripts/run_tests_and_examples.sh - -unit-test-arm-portable: - extends: .test_job_template_arm - script: - - RUSTFLAGS="-C target-cpu=generic" ./scripts/run_tests_and_examples.sh - -unit-test-single-threaded: - extends: .test_job_template_amd - script: - - RAYON_NUM_THREADS=1 RUSTFLAGS="-C target-cpu=native" ./scripts/run_tests_and_examples.sh - -unit-test-no-default-features: - extends: .test_job_template_amd - script: - - CARGO_EXTRA_FLAGS="--no-default-features" RUSTFLAGS="-C target-cpu=native" ./scripts/run_tests_and_examples.sh - -unit-test-amd: - extends: .test_job_template_amd - script: - - RUSTFLAGS="-C target-cpu=native" ./scripts/run_tests_and_examples.sh - -unit-test-amd-stable: - extends: .test_job_template_amd_stable - script: - - RUSTFLAGS="-C target-cpu=native" CARGO_STABLE=true ./scripts/run_tests_and_examples.sh - -unit-test-arm: - extends: .test_job_template_arm - script: - - RUSTFLAGS="-C target-cpu=native -C target-feature=+aes" ./scripts/run_tests_and_examples.sh - -# Documentation and pages jobs -build-docs: - extends: .job_template_amd - stage: build - script: - - cargo doc --no-deps - artifacts: - paths: - - target/doc - expire_in: 1 week - -pages: - extends: .job_template_amd - stage: deploy - dependencies: - - build-docs - script: - - mv target/doc public - - echo "/ /binius_core 302" > public/_redirects - artifacts: - paths: - - public - only: - refs: - - main # Deploy for every push to the main branch, for now diff --git a/CODEOWNERS b/CODEOWNERS new file mode 100644 index 000000000..61d772514 --- /dev/null +++ b/CODEOWNERS @@ -0,0 +1,5 @@ +# Code owners for the entire repository +* @jimpo-ulvt @onesk + +# Code owners for the .github path +/.github/ @IrreducibleOSS/Infrastructure diff --git a/Cargo.toml b/Cargo.toml index 9cf5f0969..2ee58dfc2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -100,6 +100,7 @@ generic-array = "0.14.7" getset = "0.1.2" groestl_crypto = { package = "groestl", version = "0.10.1" } hex-literal = "0.4.1" +inventory = "0.3.19" itertools = "0.13.0" lazy_static = "1.5.0" paste = "1.0.15" @@ -112,13 +113,13 @@ seq-macro = "0.3.5" sha2 = "0.10.8" stackalloc = "1.2.1" subtle = "2.5.0" -syn = { version = "2.0.60", features = ["full"] } +syn = { version = "2.0.98", features = ["extra-traits"] } thiserror = "2.0.3" thread_local = "1.1.7" tiny-keccak = { version = "2.0.2", features = ["keccak"] } trait-set = "0.3.0" tracing = "0.1.38" -tracing-profile = "0.9.0" +tracing-profile = "0.10.1" transpose = "0.2.2" [profile.release] @@ -137,4 +138,4 @@ opt-level = 1 debug = true debug-assertions = true overflow-checks = true -lto = false +lto = "off" diff --git a/crates/circuits/src/arithmetic/u32.rs b/crates/circuits/src/arithmetic/u32.rs index 88a0c9a70..0dcdd1ae7 100644 --- a/crates/circuits/src/arithmetic/u32.rs +++ b/crates/circuits/src/arithmetic/u32.rs @@ -1,25 +1,17 @@ // Copyright 2024-2025 Irreducible Inc. use binius_core::oracle::{OracleId, ProjectionVariant, ShiftVariant}; -use binius_field::{ - as_packed_field::PackScalar, packed::set_packed_slice, BinaryField1b, BinaryField32b, - ExtensionField, Field, TowerField, -}; +use binius_field::{packed::set_packed_slice, BinaryField1b, BinaryField32b, Field, TowerField}; use binius_macros::arith_expr; use binius_maybe_rayon::prelude::*; -use bytemuck::Pod; use crate::{builder::ConstraintSystemBuilder, transparent}; -pub fn packed( - builder: &mut ConstraintSystemBuilder, +pub fn packed( + builder: &mut ConstraintSystemBuilder, name: impl ToString, input: OracleId, -) -> Result -where - U: PackScalar + PackScalar + PackScalar + Pod, - F: TowerField + ExtensionField, -{ +) -> Result { let packed = builder.add_packed(name, input, 5)?; if let Some(witness) = builder.witness() { witness.set( @@ -32,17 +24,13 @@ where Ok(packed) } -pub fn mul_const( - builder: &mut ConstraintSystemBuilder, +pub fn mul_const( + builder: &mut ConstraintSystemBuilder, name: impl ToString, input: OracleId, value: u32, flags: super::Flags, -) -> Result -where - U: PackScalar + PackScalar + Pod, - F: TowerField, -{ +) -> Result { if value == 0 { let log_rows = builder.log_rows([input])?; return transparent::constant(builder, name, log_rows, BinaryField1b::ZERO); @@ -85,17 +73,13 @@ where Ok(result) } -pub fn add( - builder: &mut ConstraintSystemBuilder, +pub fn add( + builder: &mut ConstraintSystemBuilder, name: impl ToString, xin: OracleId, yin: OracleId, flags: super::Flags, -) -> Result -where - U: PackScalar + PackScalar + Pod, - F: TowerField, -{ +) -> Result { builder.push_namespace(name); let log_rows = builder.log_rows([xin, yin])?; let cout = builder.add_committed("cout", log_rows, BinaryField1b::TOWER_LEVEL); @@ -151,17 +135,13 @@ where Ok(zout) } -pub fn sub( - builder: &mut ConstraintSystemBuilder, +pub fn sub( + builder: &mut ConstraintSystemBuilder, name: impl ToString, zin: OracleId, yin: OracleId, flags: super::Flags, -) -> Result -where - U: PackScalar + PackScalar + Pod, - F: TowerField, -{ +) -> Result { builder.push_namespace(name); let log_rows = builder.log_rows([zin, yin])?; let cout = builder.add_committed("cout", log_rows, BinaryField1b::TOWER_LEVEL); @@ -218,16 +198,12 @@ where Ok(xout) } -pub fn half( - builder: &mut ConstraintSystemBuilder, +pub fn half( + builder: &mut ConstraintSystemBuilder, name: impl ToString, input: OracleId, flags: super::Flags, -) -> Result -where - U: PackScalar + PackScalar + Pod, - F: TowerField, -{ +) -> Result { if matches!(flags, super::Flags::Checked) { // Assert that the number is even let lsb = select_bit(builder, "lsb", input, 0)?; @@ -236,23 +212,24 @@ where shr(builder, name, input, 1) } -pub fn shl( - builder: &mut ConstraintSystemBuilder, +pub fn shl( + builder: &mut ConstraintSystemBuilder, name: impl ToString, input: OracleId, offset: usize, -) -> Result -where - U: PackScalar + PackScalar + Pod, - F: TowerField, -{ +) -> Result { if offset == 0 { return Ok(input); } let shifted = builder.add_shifted(name, input, offset, 5, ShiftVariant::LogicalLeft)?; if let Some(witness) = builder.witness() { - (witness.new_column(shifted).as_mut_slice::(), witness.get(input)?.as_slice::()) + ( + witness + .new_column::(shifted) + .as_mut_slice::(), + witness.get::(input)?.as_slice::(), + ) .into_par_iter() .for_each(|(shifted, input)| *shifted = *input << offset); } @@ -260,23 +237,24 @@ where Ok(shifted) } -pub fn shr( - builder: &mut ConstraintSystemBuilder, +pub fn shr( + builder: &mut ConstraintSystemBuilder, name: impl ToString, input: OracleId, offset: usize, -) -> Result -where - U: PackScalar + PackScalar + Pod, - F: TowerField, -{ +) -> Result { if offset == 0 { return Ok(input); } let shifted = builder.add_shifted(name, input, offset, 5, ShiftVariant::LogicalRight)?; if let Some(witness) = builder.witness() { - (witness.new_column(shifted).as_mut_slice::(), witness.get(input)?.as_slice::()) + ( + witness + .new_column::(shifted) + .as_mut_slice::(), + witness.get::(input)?.as_slice::(), + ) .into_par_iter() .for_each(|(shifted, input)| *shifted = *input >> offset); } @@ -284,16 +262,12 @@ where Ok(shifted) } -pub fn select_bit( - builder: &mut ConstraintSystemBuilder, +pub fn select_bit( + builder: &mut ConstraintSystemBuilder, name: impl ToString, input: OracleId, index: usize, -) -> Result -where - U: PackScalar + PackScalar + Pod, - F: TowerField, -{ +) -> Result { let log_rows = builder.log_rows([input])?; anyhow::ensure!(log_rows >= 5, "Polynomial must have n_vars >= 5. Got {log_rows}"); anyhow::ensure!(index < 32, "Only index values between 0 and 32 are allowed. Got {index}"); @@ -304,7 +278,7 @@ where if let Some(witness) = builder.witness() { let mut bits = witness.new_column::(bits); let bits = bits.packed(); - let input = witness.get(input)?.as_slice::(); + let input = witness.get::(input)?.as_slice::(); input.iter().enumerate().for_each(|(i, &val)| { let value = match (val >> index) & 1 { 0 => BinaryField1b::ZERO, @@ -317,16 +291,12 @@ where Ok(bits) } -pub fn constant( - builder: &mut ConstraintSystemBuilder, +pub fn constant( + builder: &mut ConstraintSystemBuilder, name: impl ToString, log_count: usize, value: u32, -) -> Result -where - U: PackScalar + PackScalar + PackScalar + Pod, - F: TowerField + ExtensionField, -{ +) -> Result { builder.push_namespace(name); // This would not need to be committed if we had `builder.add_unpacked(..)` let output = builder.add_committed("output", log_count + 5, BinaryField1b::TOWER_LEVEL); @@ -360,50 +330,47 @@ where #[cfg(test)] mod tests { - use binius_core::constraint_system::validate::validate_witness; - use binius_field::{arch::OptimalUnderlier, BinaryField128b, BinaryField1b, TowerField}; + use binius_field::{BinaryField1b, TowerField}; - use crate::{arithmetic, builder::ConstraintSystemBuilder, unconstrained::unconstrained}; - - type U = OptimalUnderlier; - type F = BinaryField128b; + use crate::{arithmetic, builder::test_utils::test_circuit, unconstrained::unconstrained}; #[test] fn test_mul_const() { - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); - - let a = builder.add_committed("a", 5, BinaryField1b::TOWER_LEVEL); - if let Some(witness) = builder.witness() { - witness - .new_column::(a) - .as_mut_slice::() - .iter_mut() - .for_each(|v| *v = 0b01000000_00000000_00000000_00000000u32); - } - - let _c = arithmetic::u32::mul_const(&mut builder, "mul3", a, 3, arithmetic::Flags::Checked) - .unwrap(); + test_circuit(|builder| { + let a = builder.add_committed("a", 5, BinaryField1b::TOWER_LEVEL); + if let Some(witness) = builder.witness() { + witness + .new_column::(a) + .as_mut_slice::() + .iter_mut() + .for_each(|v| *v = 0b01000000_00000000_00000000_00000000u32); + } + let _c = arithmetic::u32::mul_const(builder, "mul3", a, 3, arithmetic::Flags::Checked)?; + Ok(vec![]) + }) + .unwrap(); + } - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); + #[test] + fn test_add() { + test_circuit(|builder| { + let log_size = 14; + let a = unconstrained::(builder, "a", log_size)?; + let b = unconstrained::(builder, "b", log_size)?; + let _c = arithmetic::u32::add(builder, "u32add", a, b, arithmetic::Flags::Unchecked)?; + Ok(vec![]) + }) + .unwrap(); } #[test] fn test_sub() { - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); - - let a = unconstrained::(&mut builder, "a", 7).unwrap(); - let b = unconstrained::(&mut builder, "a", 7).unwrap(); - let _c = - arithmetic::u32::sub(&mut builder, "c", a, b, arithmetic::Flags::Unchecked).unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); + test_circuit(|builder| { + let a = unconstrained::(builder, "a", 7).unwrap(); + let b = unconstrained::(builder, "a", 7).unwrap(); + let _c = arithmetic::u32::sub(builder, "c", a, b, arithmetic::Flags::Unchecked)?; + Ok(vec![]) + }) + .unwrap(); } } diff --git a/crates/circuits/src/bitwise.rs b/crates/circuits/src/bitwise.rs index 33e2d3a11..42d7cbd77 100644 --- a/crates/circuits/src/bitwise.rs +++ b/crates/circuits/src/bitwise.rs @@ -1,25 +1,18 @@ // Copyright 2024-2025 Irreducible Inc. use binius_core::oracle::OracleId; -use binius_field::{ - as_packed_field::PackScalar, underlier::UnderlierType, BinaryField1b, TowerField, -}; +use binius_field::{BinaryField1b, Field, TowerField}; use binius_macros::arith_expr; use binius_maybe_rayon::prelude::*; -use bytemuck::Pod; use crate::builder::ConstraintSystemBuilder; -pub fn and( - builder: &mut ConstraintSystemBuilder, +pub fn and( + builder: &mut ConstraintSystemBuilder, name: impl ToString, xin: OracleId, yin: OracleId, -) -> Result -where - U: UnderlierType + Pod + PackScalar + PackScalar, - F: TowerField, -{ +) -> Result { builder.push_namespace(name); let log_rows = builder.log_rows([xin, yin])?; let zout = builder.add_committed("zout", log_rows, BinaryField1b::TOWER_LEVEL); @@ -45,19 +38,16 @@ where Ok(zout) } -pub fn xor( - builder: &mut ConstraintSystemBuilder, +pub fn xor( + builder: &mut ConstraintSystemBuilder, name: impl ToString, xin: OracleId, yin: OracleId, -) -> Result -where - U: UnderlierType + Pod + PackScalar + PackScalar, - F: TowerField, -{ +) -> Result { builder.push_namespace(name); let log_rows = builder.log_rows([xin, yin])?; - let zout = builder.add_linear_combination("zout", log_rows, [(xin, F::ONE), (yin, F::ONE)])?; + let zout = + builder.add_linear_combination("zout", log_rows, [(xin, Field::ONE), (yin, Field::ONE)])?; if let Some(witness) = builder.witness() { ( witness.get::(xin)?.as_slice::(), @@ -75,16 +65,12 @@ where Ok(zout) } -pub fn or( - builder: &mut ConstraintSystemBuilder, +pub fn or( + builder: &mut ConstraintSystemBuilder, name: impl ToString, xin: OracleId, yin: OracleId, -) -> Result -where - U: UnderlierType + Pod + PackScalar + PackScalar, - F: TowerField, -{ +) -> Result { builder.push_namespace(name); let log_rows = builder.log_rows([xin, yin])?; let zout = builder.add_committed("zout", log_rows, BinaryField1b::TOWER_LEVEL); @@ -109,3 +95,24 @@ where builder.pop_namespace(); Ok(zout) } + +#[cfg(test)] +mod tests { + use binius_field::BinaryField1b; + + use crate::{builder::test_utils::test_circuit, unconstrained::unconstrained}; + + #[test] + fn test_bitwise() { + test_circuit(|builder| { + let log_size = 6; + let a = unconstrained::(builder, "a", log_size)?; + let b = unconstrained::(builder, "b", log_size)?; + let _and = super::and(builder, "and", a, b)?; + let _xor = super::xor(builder, "xor", a, b)?; + let _or = super::or(builder, "or", a, b)?; + Ok(vec![]) + }) + .unwrap(); + } +} diff --git a/crates/circuits/src/blake3.rs b/crates/circuits/src/blake3.rs new file mode 100644 index 000000000..4d5325587 --- /dev/null +++ b/crates/circuits/src/blake3.rs @@ -0,0 +1,209 @@ +// Copyright 2024-2025 Irreducible Inc. + +use binius_core::oracle::{OracleId, ShiftVariant}; +use binius_field::{BinaryField1b, Field}; +use binius_utils::checked_arithmetics::checked_log_2; + +use crate::{ + arithmetic, + arithmetic::Flags, + builder::{types::F, ConstraintSystemBuilder}, +}; + +type F1 = BinaryField1b; +const LOG_U32_BITS: usize = checked_log_2(32); + +// Gadget that performs two u32 variables XOR and then rotates the result +fn xor_rotate_right( + builder: &mut ConstraintSystemBuilder, + name: impl ToString, + log_size: usize, + a: OracleId, + b: OracleId, + rotate_right_offset: u32, +) -> Result { + assert!(rotate_right_offset <= 32); + + builder.push_namespace(name); + + let xor = builder + .add_linear_combination("xor", log_size, [(a, F::ONE), (b, F::ONE)]) + .unwrap(); + + let rotate = builder.add_shifted( + "rotate", + xor, + 32 - rotate_right_offset as usize, + LOG_U32_BITS, + ShiftVariant::CircularLeft, + )?; + + if let Some(witness) = builder.witness() { + let a_value = witness.get::(a)?.as_slice::(); + let b_value = witness.get::(b)?.as_slice::(); + + let mut xor_witness = witness.new_column::(xor); + let xor_value = xor_witness.as_mut_slice::(); + + for (idx, v) in xor_value.iter_mut().enumerate() { + *v = a_value[idx] ^ b_value[idx]; + } + + let mut rotate_witness = witness.new_column::(rotate); + let rotate_value = rotate_witness.as_mut_slice::(); + for (idx, v) in rotate_value.iter_mut().enumerate() { + *v = xor_value[idx].rotate_right(rotate_right_offset); + } + } + + builder.pop_namespace(); + + Ok(rotate) +} + +#[allow(clippy::too_many_arguments)] +pub fn blake3_g( + builder: &mut ConstraintSystemBuilder, + name: impl ToString, + a_in: OracleId, + b_in: OracleId, + c_in: OracleId, + d_in: OracleId, + mx: OracleId, + my: OracleId, + log_size: usize, +) -> Result<[OracleId; 4], anyhow::Error> { + builder.push_namespace(name); + + let ab = arithmetic::u32::add(builder, "a_in + b_in", a_in, b_in, Flags::Unchecked)?; + let a1 = arithmetic::u32::add(builder, "a_in + b_in + mx", ab, mx, Flags::Unchecked)?; + + let d1 = xor_rotate_right(builder, "(d_in ^ a1).rotate_right(16)", log_size, d_in, a1, 16u32)?; + + let c1 = arithmetic::u32::add(builder, "c_in + d1", c_in, d1, Flags::Unchecked)?; + + let b1 = xor_rotate_right(builder, "(b_in ^ c1).rotate_right(12)", log_size, b_in, c1, 12u32)?; + + let a1b1 = arithmetic::u32::add(builder, "a1 + b1", a1, b1, Flags::Unchecked)?; + let a2 = arithmetic::u32::add(builder, "a1 + b1 + my_in", a1b1, my, Flags::Unchecked)?; + + let d2 = xor_rotate_right(builder, "(d1 ^ a2).rotate_right(8)", log_size, d1, a2, 8u32)?; + + let c2 = arithmetic::u32::add(builder, "c1 + d2", c1, d2, Flags::Unchecked)?; + + let b2 = xor_rotate_right(builder, "(b1 ^ c2).rotate_right(7)", log_size, b1, c2, 7u32)?; + + builder.pop_namespace(); + + Ok([a2, b2, c2, d2]) +} + +#[cfg(test)] +mod tests { + use binius_core::constraint_system::validate::validate_witness; + use binius_field::BinaryField1b; + use binius_maybe_rayon::prelude::*; + + use crate::{ + blake3::blake3_g, + builder::ConstraintSystemBuilder, + unconstrained::{fixed_u32, unconstrained}, + }; + + type F1 = BinaryField1b; + + const LOG_SIZE: usize = 5; + + // The Blake3 mixing function, G, which mixes either a column or a diagonal. + // https://github.com/BLAKE3-team/BLAKE3/blob/master/reference_impl/reference_impl.rs + const fn g( + a_in: u32, + b_in: u32, + c_in: u32, + d_in: u32, + mx: u32, + my: u32, + ) -> (u32, u32, u32, u32) { + let a1 = a_in.wrapping_add(b_in).wrapping_add(mx); + let d1 = (d_in ^ a1).rotate_right(16); + let c1 = c_in.wrapping_add(d1); + let b1 = (b_in ^ c1).rotate_right(12); + + let a2 = a1.wrapping_add(b1).wrapping_add(my); + let d2 = (d1 ^ a2).rotate_right(8); + let c2 = c1.wrapping_add(d2); + let b2 = (b1 ^ c2).rotate_right(7); + + (a2, b2, c2, d2) + } + + #[test] + fn test_vector() { + // Let's use some fixed data input to check that our in-circuit computation + // produces same output as out-of-circuit one + let a = 0xaaaaaaaau32; + let b = 0xbbbbbbbbu32; + let c = 0xccccccccu32; + let d = 0xddddddddu32; + let mx = 0xffff00ffu32; + let my = 0xff00ffffu32; + + let (expected_0, expected_1, expected_2, expected_3) = g(a, b, c, d, mx, my); + + let size = 1 << LOG_SIZE; + + let allocator = bumpalo::Bump::new(); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); + + let a_in = fixed_u32::(&mut builder, "a", LOG_SIZE, vec![a; size]).unwrap(); + let b_in = fixed_u32::(&mut builder, "b", LOG_SIZE, vec![b; size]).unwrap(); + let c_in = fixed_u32::(&mut builder, "c", LOG_SIZE, vec![c; size]).unwrap(); + let d_in = fixed_u32::(&mut builder, "d", LOG_SIZE, vec![d; size]).unwrap(); + let mx_in = fixed_u32::(&mut builder, "mx", LOG_SIZE, vec![mx; size]).unwrap(); + let my_in = fixed_u32::(&mut builder, "my", LOG_SIZE, vec![my; size]).unwrap(); + + let output = + blake3_g(&mut builder, "g", a_in, b_in, c_in, d_in, mx_in, my_in, LOG_SIZE).unwrap(); + + if let Some(witness) = builder.witness() { + ( + witness.get::(output[0]).unwrap().as_slice::(), + witness.get::(output[1]).unwrap().as_slice::(), + witness.get::(output[2]).unwrap().as_slice::(), + witness.get::(output[3]).unwrap().as_slice::(), + ) + .into_par_iter() + .for_each(|(actual_0, actual_1, actual_2, actual_3)| { + assert_eq!(*actual_0, expected_0); + assert_eq!(*actual_1, expected_1); + assert_eq!(*actual_2, expected_2); + assert_eq!(*actual_3, expected_3); + }); + } + + let witness = builder.take_witness().unwrap(); + let constraints_system = builder.build().unwrap(); + + validate_witness(&constraints_system, &[], &witness).unwrap(); + } + + #[test] + fn test_random_input() { + let allocator = bumpalo::Bump::new(); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); + + let a_in = unconstrained::(&mut builder, "a", LOG_SIZE).unwrap(); + let b_in = unconstrained::(&mut builder, "b", LOG_SIZE).unwrap(); + let c_in = unconstrained::(&mut builder, "c", LOG_SIZE).unwrap(); + let d_in = unconstrained::(&mut builder, "d", LOG_SIZE).unwrap(); + let mx_in = unconstrained::(&mut builder, "mx", LOG_SIZE).unwrap(); + let my_in = unconstrained::(&mut builder, "my", LOG_SIZE).unwrap(); + + blake3_g(&mut builder, "g", a_in, b_in, c_in, d_in, mx_in, my_in, LOG_SIZE).unwrap(); + + let witness = builder.take_witness().unwrap(); + let constraints_system = builder.build().unwrap(); + + validate_witness(&constraints_system, &[], &witness).unwrap(); + } +} diff --git a/crates/circuits/src/builder/constraint_system.rs b/crates/circuits/src/builder/constraint_system.rs index ba8a2c865..e4010cafa 100644 --- a/crates/circuits/src/builder/constraint_system.rs +++ b/crates/circuits/src/builder/constraint_system.rs @@ -16,35 +16,28 @@ use binius_core::{ transparent::step_down::StepDown, witness::MultilinearExtensionIndex, }; -use binius_field::{ - as_packed_field::PackScalar, underlier::UnderlierType, BinaryField1b, TowerField, -}; +use binius_field::{as_packed_field::PackScalar, BinaryField1b}; use binius_math::ArithExpr; use binius_utils::bail; -use crate::builder::witness; +use crate::builder::{ + types::{F, U}, + witness, +}; #[derive(Default)] -pub struct ConstraintSystemBuilder<'arena, U, F> -where - U: UnderlierType + PackScalar, - F: TowerField, -{ +pub struct ConstraintSystemBuilder<'arena> { oracles: Rc>>, constraints: ConstraintSetBuilder, non_zero_oracle_ids: Vec, flushes: Vec, step_down_dedup: HashMap<(usize, usize), OracleId>, - witness: Option>, + witness: Option>, next_channel_id: ChannelId, namespace_path: Vec, } -impl<'arena, U, F> ConstraintSystemBuilder<'arena, U, F> -where - U: UnderlierType + PackScalar, - F: TowerField, -{ +impl<'arena> ConstraintSystemBuilder<'arena> { pub fn new() -> Self { Self::default() } @@ -79,7 +72,7 @@ where }) } - pub fn witness(&mut self) -> Option<&mut witness::Builder<'arena, U, F>> { + pub fn witness(&mut self) -> Option<&mut witness::Builder<'arena>> { self.witness.as_mut() } @@ -354,7 +347,7 @@ where /// /// let log_size = 14; /// - /// let mut builder = ConstraintSystemBuilder::::new(); + /// let mut builder = ConstraintSystemBuilder::new(); /// builder.push_namespace("a"); /// let x = builder.add_committed("x", log_size, BinaryField1b::TOWER_LEVEL); /// builder.push_namespace("b"); diff --git a/crates/circuits/src/builder/mod.rs b/crates/circuits/src/builder/mod.rs index c0b59130b..7ee629f3c 100644 --- a/crates/circuits/src/builder/mod.rs +++ b/crates/circuits/src/builder/mod.rs @@ -1,6 +1,8 @@ // Copyright 2024-2025 Irreducible Inc. pub mod constraint_system; +pub mod test_utils; +pub mod types; pub mod witness; pub use constraint_system::ConstraintSystemBuilder; diff --git a/crates/circuits/src/builder/test_utils.rs b/crates/circuits/src/builder/test_utils.rs new file mode 100644 index 000000000..0aa80ad92 --- /dev/null +++ b/crates/circuits/src/builder/test_utils.rs @@ -0,0 +1,23 @@ +// Copyright 2025 Irreducible Inc. + +use binius_core::constraint_system::{channel::Boundary, validate::validate_witness}; + +use super::{types::F, ConstraintSystemBuilder}; + +pub fn test_circuit( + build_circuit: fn(&mut ConstraintSystemBuilder) -> Result>, anyhow::Error>, +) -> Result<(), anyhow::Error> { + let mut verifier_builder = ConstraintSystemBuilder::new(); + let verifier_boundaries = build_circuit(&mut verifier_builder)?; + let verifier_constraint_system = verifier_builder.build()?; + + let allocator = bumpalo::Bump::new(); + let mut prover_builder = ConstraintSystemBuilder::new_with_witness(&allocator); + let prover_boundaries = build_circuit(&mut prover_builder)?; + let prover_witness = prover_builder.take_witness()?; + let _prover_constraint_system = prover_builder.build()?; + + assert_eq!(verifier_boundaries, prover_boundaries); + validate_witness(&verifier_constraint_system, &verifier_boundaries, &prover_witness)?; + Ok(()) +} diff --git a/crates/circuits/src/builder/types.rs b/crates/circuits/src/builder/types.rs new file mode 100644 index 000000000..ae33e7e4f --- /dev/null +++ b/crates/circuits/src/builder/types.rs @@ -0,0 +1,4 @@ +// Copyright 2025 Irreducible Inc. + +pub type F = binius_field::BinaryField128b; +pub type U = binius_field::arch::OptimalUnderlier; diff --git a/crates/circuits/src/builder/witness.rs b/crates/circuits/src/builder/witness.rs index 2209a36f4..77e83868d 100644 --- a/crates/circuits/src/builder/witness.rs +++ b/crates/circuits/src/builder/witness.rs @@ -10,35 +10,33 @@ use binius_core::{ use binius_field::{ as_packed_field::{PackScalar, PackedType}, underlier::WithUnderlier, - ExtensionField, Field, PackedField, TowerField, + ExtensionField, PackedField, TowerField, }; use binius_math::MultilinearExtension; use binius_utils::bail; use bytemuck::{must_cast_slice, must_cast_slice_mut, Pod}; -pub struct Builder<'arena, U: PackScalar, FW: TowerField> { +use super::types::{F, U}; + +pub struct Builder<'arena> { bump: &'arena bumpalo::Bump, - oracles: Rc>>, + oracles: Rc>>, #[allow(clippy::type_complexity)] - entries: Rc>>>>, + entries: Rc>>>>, } -struct WitnessBuilderEntry<'arena, U: PackScalar, FW: Field> { - witness: Result>, binius_math::Error>, +struct WitnessBuilderEntry<'arena> { + witness: Result>, binius_math::Error>, tower_level: usize, data: &'arena [U], } -impl<'arena, U, FW> Builder<'arena, U, FW> -where - U: PackScalar, - FW: TowerField, -{ +impl<'arena> Builder<'arena> { pub fn new( allocator: &'arena bumpalo::Bump, - oracles: Rc>>, + oracles: Rc>>, ) -> Self { Self { bump: allocator, @@ -47,10 +45,10 @@ where } } - pub fn new_column(&self, id: OracleId) -> EntryBuilder<'arena, U, FW, FS> + pub fn new_column(&self, id: OracleId) -> EntryBuilder<'arena, FS> where U: PackScalar, - FW: ExtensionField, + F: ExtensionField, { let oracles = self.oracles.borrow(); let log_rows = oracles.n_vars(id); @@ -69,10 +67,10 @@ where &self, id: OracleId, default: FS, - ) -> EntryBuilder<'arena, U, FW, FS> + ) -> EntryBuilder<'arena, FS> where U: PackScalar, - FW: ExtensionField, + F: ExtensionField, { let oracles = self.oracles.borrow(); let log_rows = oracles.n_vars(id); @@ -88,10 +86,11 @@ where } } - pub fn get(&self, id: OracleId) -> Result, Error> + pub fn get(&self, id: OracleId) -> Result, Error> where + FS: TowerField, U: PackScalar, - FW: ExtensionField, + F: ExtensionField, { let entries = self.entries.borrow(); let oracles = self.oracles.borrow(); @@ -122,11 +121,11 @@ where pub fn set( &self, id: OracleId, - entry: WitnessEntry<'arena, U, FS>, + entry: WitnessEntry<'arena, FS>, ) -> Result<(), Error> where U: PackScalar, - FW: ExtensionField, + F: ExtensionField, { let oracles = self.oracles.borrow(); if !oracles.is_valid_oracle_id(id) { @@ -145,7 +144,7 @@ where Ok(()) } - pub fn build(self) -> Result, Error> { + pub fn build(self) -> Result, Error> { let mut result = MultilinearExtensionIndex::new(); let entries = Rc::into_inner(self.entries) .ok_or_else(|| anyhow!("Failed to build. There are still entries refs. Make sure there are no pending column insertions."))? @@ -160,26 +159,37 @@ where } #[derive(Debug, Clone, Copy)] -pub struct WitnessEntry<'arena, U: PackScalar, FS: TowerField> { +pub struct WitnessEntry<'arena, FS: TowerField> +where + U: PackScalar, +{ data: &'arena [U], log_rows: usize, _marker: PhantomData, } -impl<'arena, U: PackScalar, FS: TowerField> WitnessEntry<'arena, U, FS> { +impl<'arena, FS: TowerField> WitnessEntry<'arena, FS> +where + U: PackScalar, +{ #[inline] pub fn packed(&self) -> &'arena [PackedType] { WithUnderlier::from_underliers_ref(self.data) } - pub const fn repacked(&self) -> WitnessEntry<'arena, U, FW> + #[inline] + pub const fn as_slice(&self) -> &'arena [T] { + must_cast_slice(self.data) + } + + pub const fn repacked(&self) -> WitnessEntry<'arena, FE> where - FW: TowerField + ExtensionField, - U: PackScalar, + FE: TowerField + ExtensionField, + U: PackScalar, { WitnessEntry { data: self.data, - log_rows: self.log_rows - >::LOG_DEGREE, + log_rows: self.log_rows - >::LOG_DEGREE, _marker: PhantomData, } } @@ -189,38 +199,36 @@ impl<'arena, U: PackScalar, FS: TowerField> WitnessEntry<'arena, U, FS> { } } -impl<'arena, U: PackScalar + Pod, FS: TowerField> WitnessEntry<'arena, U, FS> { - #[inline] - pub const fn as_slice(&self) -> &'arena [T] { - must_cast_slice(self.data) - } -} - -pub struct EntryBuilder<'arena, U, FW, FS> +pub struct EntryBuilder<'arena, FS> where - U: PackScalar + PackScalar, FS: TowerField, - FW: TowerField + ExtensionField, + U: PackScalar, + F: ExtensionField, { _marker: PhantomData, #[allow(clippy::type_complexity)] - entries: Rc>>>>, + entries: Rc>>>>, id: OracleId, log_rows: usize, data: Option<&'arena mut [U]>, } -impl EntryBuilder<'_, U, FW, FS> +impl EntryBuilder<'_, FS> where - U: PackScalar + PackScalar, FS: TowerField, - FW: TowerField + ExtensionField, + U: PackScalar, + F: ExtensionField, { #[inline] pub fn packed(&mut self) -> &mut [PackedType] { PackedType::::from_underliers_ref_mut(self.underliers()) } + #[inline] + pub fn as_mut_slice(&mut self) -> &mut [T] { + must_cast_slice_mut(self.underliers()) + } + #[inline] fn underliers(&mut self) -> &mut [U] { self.data @@ -229,23 +237,11 @@ where } } -impl EntryBuilder<'_, U, FW, FS> -where - U: PackScalar + PackScalar + Pod, - FS: TowerField, - FW: TowerField + ExtensionField, -{ - #[inline] - pub fn as_mut_slice(&mut self) -> &mut [T] { - must_cast_slice_mut(self.underliers()) - } -} - -impl Drop for EntryBuilder<'_, U, FW, FS> +impl Drop for EntryBuilder<'_, FS> where - U: PackScalar + PackScalar, FS: TowerField, - FW: TowerField + ExtensionField, + U: PackScalar, + F: ExtensionField, { fn drop(&mut self) { let data = Option::take(&mut self.data).expect("data is always Some until this point"); diff --git a/crates/circuits/src/collatz.rs b/crates/circuits/src/collatz.rs index df1f2d052..b76fdf2fd 100644 --- a/crates/circuits/src/collatz.rs +++ b/crates/circuits/src/collatz.rs @@ -10,7 +10,14 @@ use binius_field::{ use binius_macros::arith_expr; use bytemuck::Pod; -use crate::{arithmetic, builder::ConstraintSystemBuilder, transparent}; +use crate::{ + arithmetic, + builder::{ + types::{F, U}, + ConstraintSystemBuilder, + }, + transparent, +}; pub type Advice = (usize, usize); @@ -37,9 +44,9 @@ impl Collatz { (self.evens.len(), self.odds.len()) } - pub fn build( + pub fn build( self, - builder: &mut ConstraintSystemBuilder, + builder: &mut ConstraintSystemBuilder, advice: Advice, ) -> Result>, anyhow::Error> where @@ -58,16 +65,12 @@ impl Collatz { Ok(boundaries) } - fn even( + fn even( &self, - builder: &mut ConstraintSystemBuilder, + builder: &mut ConstraintSystemBuilder, channel: ChannelId, count: usize, - ) -> Result<(), anyhow::Error> - where - U: PackScalar + PackScalar + PackScalar + Pod, - F: TowerField + ExtensionField, - { + ) -> Result<(), anyhow::Error> { let log_1b_rows = 5 + binius_utils::checked_arithmetics::log2_ceil_usize(count); let even = builder.add_committed("even", log_1b_rows, BinaryField1b::TOWER_LEVEL); if let Some(witness) = builder.witness() { @@ -90,16 +93,12 @@ impl Collatz { Ok(()) } - fn odd( + fn odd( &self, - builder: &mut ConstraintSystemBuilder, + builder: &mut ConstraintSystemBuilder, channel: ChannelId, count: usize, - ) -> Result<(), anyhow::Error> - where - U: PackScalar + PackScalar + PackScalar + Pod, - F: TowerField + ExtensionField, - { + ) -> Result<(), anyhow::Error> { let log_32b_rows = binius_utils::checked_arithmetics::log2_ceil_usize(count); let log_1b_rows = 5 + log_32b_rows; @@ -136,10 +135,7 @@ impl Collatz { Ok(()) } - fn get_boundaries(&self, channel_id: usize) -> Vec> - where - F: TowerField + From, - { + fn get_boundaries(&self, channel_id: usize) -> Vec> { vec![ Boundary { channel_id, @@ -179,15 +175,11 @@ pub fn collatz_orbit(x0: u32) -> Vec { res } -pub fn ensure_odd( - builder: &mut ConstraintSystemBuilder, +pub fn ensure_odd( + builder: &mut ConstraintSystemBuilder, input: OracleId, count: usize, -) -> Result<(), anyhow::Error> -where - U: PackScalar + PackScalar + Pod, - F: TowerField, -{ +) -> Result<(), anyhow::Error> { let log_32b_rows = builder.log_rows([input])? - 5; let lsb = arithmetic::u32::select_bit(builder, "lsb", input, 0)?; let selector = transparent::step_down(builder, "count", log_32b_rows, count)?; @@ -201,28 +193,17 @@ where #[cfg(test)] mod tests { - use binius_core::constraint_system::validate::validate_witness; - use binius_field::{arch::OptimalUnderlier, BinaryField128b}; - - use crate::{builder::ConstraintSystemBuilder, collatz::Collatz}; + use crate::{builder::test_utils::test_circuit, collatz::Collatz}; #[test] fn test_collatz() { - let allocator = bumpalo::Bump::new(); - let mut builder = - ConstraintSystemBuilder::::new_with_witness( - &allocator, - ); - - let x0 = 9999999; - - let mut collatz = Collatz::new(x0); - let advice = collatz.init_prover(); - - let boundaries = collatz.build(&mut builder, advice).unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); + test_circuit(|builder| { + let x0 = 9999999; + let mut collatz = Collatz::new(x0); + let advice = collatz.init_prover(); + let boundaries = collatz.build(builder, advice)?; + Ok(boundaries) + }) + .unwrap(); } } diff --git a/crates/circuits/src/groestl.rs b/crates/circuits/src/groestl.rs index d1ab36d39..598e6de16 100644 --- a/crates/circuits/src/groestl.rs +++ b/crates/circuits/src/groestl.rs @@ -378,3 +378,28 @@ fn s_box(x: AESTowerField8b) -> AESTowerField8b { let idx = u8::from(x) as usize; AESTowerField8b::from(S_BOX[idx]) } + +#[cfg(test)] +mod tests { + use binius_core::constraint_system::validate::validate_witness; + use binius_field::{arch::OptimalUnderlier, AESTowerField16b}; + + use super::groestl_p_permutation; + use crate::builder::ConstraintSystemBuilder; + + #[test] + fn test_groestl() { + let allocator = bumpalo::Bump::new(); + let mut builder = + ConstraintSystemBuilder::::new_with_witness( + &allocator, + ); + let log_size = 9; + let _state_out = groestl_p_permutation(&mut builder, log_size).unwrap(); + + let witness = builder.take_witness().unwrap(); + let constraint_system = builder.build().unwrap(); + let boundaries = vec![]; + validate_witness(&constraint_system, &boundaries, &witness).unwrap(); + } +} diff --git a/crates/circuits/src/keccakf.rs b/crates/circuits/src/keccakf.rs index a73c5cea8..89d11d560 100644 --- a/crates/circuits/src/keccakf.rs +++ b/crates/circuits/src/keccakf.rs @@ -8,14 +8,19 @@ use binius_core::{ transparent::multilinear_extension::MultilinearExtensionTransparent, }; use binius_field::{ - as_packed_field::{PackScalar, PackedType}, - underlier::{UnderlierType, WithUnderlier}, - BinaryField1b, BinaryField64b, ExtensionField, PackedField, TowerField, + as_packed_field::PackedType, underlier::WithUnderlier, BinaryField1b, BinaryField64b, Field, + PackedField, TowerField, }; use binius_macros::arith_expr; use bytemuck::{pod_collect_to_vec, Pod}; -use crate::{builder::ConstraintSystemBuilder, transparent::step_down}; +use crate::{ + builder::{ + types::{F, U}, + ConstraintSystemBuilder, + }, + transparent::step_down, +}; #[derive(Default, Clone, Copy)] pub struct KeccakfState(pub [u64; STATE_SIZE]); @@ -25,15 +30,11 @@ pub struct KeccakfOracles { pub output: [OracleId; STATE_SIZE], } -pub fn keccakf( - builder: &mut ConstraintSystemBuilder, - input_witness: Option>, +pub fn keccakf( + builder: &mut ConstraintSystemBuilder, + input_witness: &Option>, log_size: usize, -) -> Result -where - U: UnderlierType + Pod + PackScalar + PackScalar + PackScalar, - F: TowerField + ExtensionField, -{ +) -> Result { let internal_log_size = log_size + LOG_BIT_ROWS_PER_PERMUTATION; let round_consts_single: [OracleId; ROUNDS_PER_STATE_ROW] = array::try_from_fn(|round_within_row| { @@ -124,7 +125,7 @@ where builder.add_projected( "output", packed_state_out[xy], - vec![F::ONE; LOG_STATE_ROWS_PER_PERMUTATION], + vec![Field::ONE; LOG_STATE_ROWS_PER_PERMUTATION], ProjectionVariant::FirstVars, ) })?; @@ -135,7 +136,7 @@ where "c", internal_log_size, array::from_fn::<_, 5, _>(|offset| { - (state[round_within_row][x + 5 * offset], F::ONE) + (state[round_within_row][x + 5 * offset], Field::ONE) }), ) }) @@ -159,8 +160,8 @@ where "d", internal_log_size, [ - (c[round_within_row][(x + 4) % 5], F::ONE), - (c_shift[round_within_row][(x + 1) % 5], F::ONE), + (c[round_within_row][(x + 4) % 5], Field::ONE), + (c_shift[round_within_row][(x + 1) % 5], Field::ONE), ], ) }) @@ -174,8 +175,8 @@ where format!("a_theta[{xy}]"), internal_log_size, [ - (state[round_within_row][xy], F::ONE), - (d[round_within_row][x], F::ONE), + (state[round_within_row][xy], Field::ONE), + (d[round_within_row][x], Field::ONE), ], ) }) @@ -504,3 +505,23 @@ const KECCAKF_RC: [u64; ROUNDS_PER_PERMUTATION] = [ 0x0000000080000001, 0x8000000080008008, ]; + +#[cfg(test)] +mod tests { + use rand::{rngs::StdRng, Rng, SeedableRng}; + + use super::{keccakf, KeccakfState}; + use crate::builder::test_utils::test_circuit; + + #[test] + fn test_keccakf() { + test_circuit(|builder| { + let log_size = 5; + let mut rng = StdRng::seed_from_u64(0); + let input_states = vec![KeccakfState(rng.gen())]; + let _state_out = keccakf(builder, &Some(input_states), log_size)?; + Ok(vec![]) + }) + .unwrap(); + } +} diff --git a/crates/circuits/src/lasso/batch.rs b/crates/circuits/src/lasso/batch.rs index 8fb6d0697..52c98e297 100644 --- a/crates/circuits/src/lasso/batch.rs +++ b/crates/circuits/src/lasso/batch.rs @@ -4,12 +4,15 @@ use anyhow::Ok; use binius_core::oracle::OracleId; use binius_field::{ as_packed_field::{PackScalar, PackedType}, - BinaryField1b, ExtensionField, PackedFieldIndexable, TowerField, + ExtensionField, PackedFieldIndexable, TowerField, }; use itertools::Itertools; use super::lasso::lasso; -use crate::builder::ConstraintSystemBuilder; +use crate::builder::{ + types::{F, U}, + ConstraintSystemBuilder, +}; pub struct LookupBatch { lookup_us: Vec>, u_to_t_mappings: Vec>, @@ -48,19 +51,16 @@ impl LookupBatch { self.lookup_col_lens.push(lookup_u_col_len); } - pub fn execute( - mut self, - builder: &mut ConstraintSystemBuilder, - ) -> Result<(), anyhow::Error> + pub fn execute(mut self, builder: &mut ConstraintSystemBuilder) -> Result<(), anyhow::Error> where - U: PackScalar + PackScalar + PackScalar, - PackedType: PackedFieldIndexable, FC: TowerField, - F: ExtensionField + TowerField, + U: PackScalar, + F: ExtensionField, + PackedType: PackedFieldIndexable, { let channel = builder.add_channel(); - lasso::<_, _, FC>( + lasso::( builder, "batched lasso", &self.lookup_col_lens, diff --git a/crates/circuits/src/lasso/big_integer_ops/byte_sliced_add.rs b/crates/circuits/src/lasso/big_integer_ops/byte_sliced_add.rs index 43349e09c..302d2384b 100644 --- a/crates/circuits/src/lasso/big_integer_ops/byte_sliced_add.rs +++ b/crates/circuits/src/lasso/big_integer_ops/byte_sliced_add.rs @@ -3,14 +3,7 @@ use alloy_primitives::U512; use anyhow::Result; use binius_core::oracle::OracleId; -use binius_field::{ - as_packed_field::{PackScalar, PackedType}, - tower_levels::TowerLevel, - underlier::UnderlierType, - BinaryField, BinaryField16b, BinaryField1b, BinaryField32b, BinaryField8b, ExtensionField, - PackedFieldIndexable, TowerField, -}; -use bytemuck::Pod; +use binius_field::{tower_levels::TowerLevel, BinaryField1b, BinaryField8b}; use crate::{ builder::ConstraintSystemBuilder, @@ -19,32 +12,16 @@ use crate::{ type B1 = BinaryField1b; type B8 = BinaryField8b; -type B16 = BinaryField16b; -type B32 = BinaryField32b; -pub fn byte_sliced_add>( - builder: &mut ConstraintSystemBuilder, +pub fn byte_sliced_add: Sized>>( + builder: &mut ConstraintSystemBuilder, name: impl ToString + Clone, - x_in: &Level::Data, - y_in: &Level::Data, + x_in: &Level::Data, + y_in: &Level::Data, carry_in: OracleId, log_size: usize, lookup_batch_add: &mut LookupBatch, -) -> Result<(OracleId, Level::Data), anyhow::Error> -where - U: Pod - + UnderlierType - + PackScalar - + PackScalar - + PackScalar - + PackScalar - + PackScalar, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - F: TowerField + BinaryField + ExtensionField + ExtensionField + ExtensionField, - Level::Data: Sized, -{ +) -> Result<(OracleId, Level::Data), anyhow::Error> { if Level::WIDTH == 1 { let (carry_out, sum) = u8add(builder, lookup_batch_add, name, x_in[0], y_in[0], carry_in, log_size)?; @@ -58,7 +35,7 @@ where let (lower_half_x, upper_half_x) = Level::split(x_in); let (lower_half_y, upper_half_y) = Level::split(y_in); - let (internal_carry, lower_sum) = byte_sliced_add::<_, _, Level::Base>( + let (internal_carry, lower_sum) = byte_sliced_add::( builder, format!("lower sum {}b", Level::Base::WIDTH), lower_half_x, @@ -68,7 +45,7 @@ where lookup_batch_add, )?; - let (carry_out, upper_sum) = byte_sliced_add::<_, _, Level::Base>( + let (carry_out, upper_sum) = byte_sliced_add::( builder, format!("upper sum {}b", Level::Base::WIDTH), upper_half_x, diff --git a/crates/circuits/src/lasso/big_integer_ops/byte_sliced_add_carryfree.rs b/crates/circuits/src/lasso/big_integer_ops/byte_sliced_add_carryfree.rs index 881aa1b92..ab7b864c2 100644 --- a/crates/circuits/src/lasso/big_integer_ops/byte_sliced_add_carryfree.rs +++ b/crates/circuits/src/lasso/big_integer_ops/byte_sliced_add_carryfree.rs @@ -3,14 +3,7 @@ use alloy_primitives::U512; use anyhow::Result; use binius_core::oracle::OracleId; -use binius_field::{ - as_packed_field::{PackScalar, PackedType}, - tower_levels::TowerLevel, - underlier::UnderlierType, - BinaryField, BinaryField16b, BinaryField1b, BinaryField32b, BinaryField8b, ExtensionField, - PackedFieldIndexable, TowerField, -}; -use bytemuck::Pod; +use binius_field::{tower_levels::TowerLevel, BinaryField1b, BinaryField8b}; use super::byte_sliced_add; use crate::{ @@ -20,34 +13,18 @@ use crate::{ type B1 = BinaryField1b; type B8 = BinaryField8b; -type B16 = BinaryField16b; -type B32 = BinaryField32b; #[allow(clippy::too_many_arguments)] -pub fn byte_sliced_add_carryfree>( - builder: &mut ConstraintSystemBuilder, +pub fn byte_sliced_add_carryfree: Sized>>( + builder: &mut ConstraintSystemBuilder, name: impl ToString, - x_in: &Level::Data, - y_in: &Level::Data, + x_in: &Level::Data, + y_in: &Level::Data, carry_in: OracleId, log_size: usize, lookup_batch_add: &mut LookupBatch, lookup_batch_add_carryfree: &mut LookupBatch, -) -> Result -where - U: Pod - + UnderlierType - + PackScalar - + PackScalar - + PackScalar - + PackScalar - + PackScalar, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - F: TowerField + BinaryField + ExtensionField + ExtensionField + ExtensionField, - Level::Data: Sized, -{ +) -> Result, anyhow::Error> { if Level::WIDTH == 1 { let sum = u8add_carryfree( builder, @@ -68,7 +45,7 @@ where let (lower_half_x, upper_half_x) = Level::split(x_in); let (lower_half_y, upper_half_y) = Level::split(y_in); - let (internal_carry, lower_sum) = byte_sliced_add::<_, _, Level::Base>( + let (internal_carry, lower_sum) = byte_sliced_add::( builder, format!("lower sum {}b", Level::Base::WIDTH), lower_half_x, @@ -78,7 +55,7 @@ where lookup_batch_add, )?; - let upper_sum = byte_sliced_add_carryfree::<_, _, Level::Base>( + let upper_sum = byte_sliced_add_carryfree::( builder, format!("upper sum {}b", Level::Base::WIDTH), upper_half_x, diff --git a/crates/circuits/src/lasso/big_integer_ops/byte_sliced_double_conditional_increment.rs b/crates/circuits/src/lasso/big_integer_ops/byte_sliced_double_conditional_increment.rs index b14baa3e1..9000c4572 100644 --- a/crates/circuits/src/lasso/big_integer_ops/byte_sliced_double_conditional_increment.rs +++ b/crates/circuits/src/lasso/big_integer_ops/byte_sliced_double_conditional_increment.rs @@ -3,14 +3,7 @@ use alloy_primitives::U512; use anyhow::Result; use binius_core::oracle::OracleId; -use binius_field::{ - as_packed_field::{PackScalar, PackedType}, - tower_levels::TowerLevel, - underlier::UnderlierType, - BinaryField, BinaryField16b, BinaryField1b, BinaryField32b, BinaryField8b, ExtensionField, - PackedFieldIndexable, TowerField, -}; -use bytemuck::Pod; +use binius_field::{tower_levels::TowerLevel, BinaryField1b, BinaryField8b}; use crate::{ builder::ConstraintSystemBuilder, @@ -19,34 +12,18 @@ use crate::{ type B1 = BinaryField1b; type B8 = BinaryField8b; -type B16 = BinaryField16b; -type B32 = BinaryField32b; #[allow(clippy::too_many_arguments)] -pub fn byte_sliced_double_conditional_increment>( - builder: &mut ConstraintSystemBuilder, +pub fn byte_sliced_double_conditional_increment: Sized>>( + builder: &mut ConstraintSystemBuilder, name: impl ToString, - x_in: &Level::Data, + x_in: &Level::Data, first_carry_in: OracleId, second_carry_in: OracleId, log_size: usize, zero_oracle_carry: usize, lookup_batch_dci: &mut LookupBatch, -) -> Result<(OracleId, Level::Data), anyhow::Error> -where - U: Pod - + UnderlierType - + PackScalar - + PackScalar - + PackScalar - + PackScalar - + PackScalar, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - F: TowerField + BinaryField + ExtensionField + ExtensionField + ExtensionField, - Level::Data: Sized, -{ +) -> Result<(OracleId, Level::Data), anyhow::Error> { if Level::WIDTH == 1 { let (carry_out, sum) = u8_double_conditional_increment( builder, @@ -66,7 +43,7 @@ where let (lower_half_x, upper_half_x) = Level::split(x_in); - let (internal_carry, lower_sum) = byte_sliced_double_conditional_increment::<_, _, Level::Base>( + let (internal_carry, lower_sum) = byte_sliced_double_conditional_increment::( builder, format!("lower sum {}b", Level::Base::WIDTH), lower_half_x, @@ -77,7 +54,7 @@ where lookup_batch_dci, )?; - let (carry_out, upper_sum) = byte_sliced_double_conditional_increment::<_, _, Level::Base>( + let (carry_out, upper_sum) = byte_sliced_double_conditional_increment::( builder, format!("upper sum {}b", Level::Base::WIDTH), upper_half_x, diff --git a/crates/circuits/src/lasso/big_integer_ops/byte_sliced_modular_mul.rs b/crates/circuits/src/lasso/big_integer_ops/byte_sliced_modular_mul.rs index ead52cccf..d480de5c2 100644 --- a/crates/circuits/src/lasso/big_integer_ops/byte_sliced_modular_mul.rs +++ b/crates/circuits/src/lasso/big_integer_ops/byte_sliced_modular_mul.rs @@ -4,59 +4,32 @@ use alloy_primitives::U512; use anyhow::Result; use binius_core::{oracle::OracleId, transparent::constant::Constant}; use binius_field::{ - as_packed_field::{PackScalar, PackedType}, - tower_levels::TowerLevel, - underlier::{UnderlierType, WithUnderlier}, - BinaryField, BinaryField16b, BinaryField1b, BinaryField32b, BinaryField8b, ExtensionField, - PackedFieldIndexable, TowerField, + tower_levels::TowerLevel, underlier::WithUnderlier, BinaryField32b, BinaryField8b, TowerField, }; use binius_macros::arith_expr; -use bytemuck::Pod; use super::{byte_sliced_add_carryfree, byte_sliced_mul}; use crate::{ - builder::ConstraintSystemBuilder, + builder::{types::F, ConstraintSystemBuilder}, lasso::{ batch::LookupBatch, lookups::u8_arithmetic::{add_carryfree_lookup, add_lookup, dci_lookup, mul_lookup}, }, }; -type B1 = BinaryField1b; type B8 = BinaryField8b; -type B16 = BinaryField16b; -type B32 = BinaryField32b; #[allow(clippy::too_many_arguments)] -pub fn byte_sliced_modular_mul< - U, - F, - LevelIn: TowerLevel, - LevelOut: TowerLevel, ->( - builder: &mut ConstraintSystemBuilder, +pub fn byte_sliced_modular_mul>( + builder: &mut ConstraintSystemBuilder, name: impl ToString, - mult_a: &LevelIn::Data, - mult_b: &LevelIn::Data, + mult_a: &LevelIn::Data, + mult_b: &LevelIn::Data, modulus_input: &[u8], log_size: usize, zero_byte_oracle: OracleId, zero_carry_oracle: OracleId, -) -> Result -where - U: Pod - + UnderlierType - + PackScalar - + PackScalar - + PackScalar - + PackScalar - + PackScalar, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - F: TowerField + BinaryField + ExtensionField + ExtensionField + ExtensionField, - ::Underlier: From, -{ +) -> Result, anyhow::Error> { builder.push_namespace(name); let lookup_t_mul = mul_lookup(builder, "mul table")?; @@ -88,12 +61,14 @@ where "modulus", Constant::new( log_size, - F::from_underlier(>::into(modulus_input[byte_idx])), + ::from_underlier(::Underlier, + >>::into(modulus_input[byte_idx])), ), )?; } - let ab = byte_sliced_mul::<_, _, LevelIn, LevelOut>( + let ab = byte_sliced_mul::( builder, "ab", mult_a, @@ -166,7 +141,7 @@ where } } - let qm = byte_sliced_mul::<_, _, LevelIn, LevelOut>( + let qm = byte_sliced_mul::( builder, "qm", "ient, @@ -183,7 +158,7 @@ where repeating_zero[byte_idx] = zero_byte_oracle; } - let qm_plus_r = byte_sliced_add_carryfree::<_, _, LevelOut>( + let qm_plus_r = byte_sliced_add_carryfree::( builder, "hi*lo", &qm, @@ -194,12 +169,12 @@ where &mut lookup_batch_add_carryfree, )?; - lookup_batch_mul.execute::<_, _, BinaryField32b>(builder)?; - lookup_batch_add.execute::<_, _, BinaryField32b>(builder)?; - lookup_batch_add_carryfree.execute::<_, _, BinaryField32b>(builder)?; + lookup_batch_mul.execute::(builder)?; + lookup_batch_add.execute::(builder)?; + lookup_batch_add_carryfree.execute::(builder)?; if LevelIn::WIDTH != 1 { - lookup_batch_dci.execute::<_, _, BinaryField32b>(builder)?; + lookup_batch_dci.execute::(builder)?; } let consistency = arith_expr!([x, y] = x - y); diff --git a/crates/circuits/src/lasso/big_integer_ops/byte_sliced_mul.rs b/crates/circuits/src/lasso/big_integer_ops/byte_sliced_mul.rs index 03742d347..de386d7ad 100644 --- a/crates/circuits/src/lasso/big_integer_ops/byte_sliced_mul.rs +++ b/crates/circuits/src/lasso/big_integer_ops/byte_sliced_mul.rs @@ -3,14 +3,7 @@ use alloy_primitives::U512; use anyhow::Result; use binius_core::oracle::OracleId; -use binius_field::{ - as_packed_field::{PackScalar, PackedType}, - tower_levels::TowerLevel, - underlier::UnderlierType, - BinaryField, BinaryField16b, BinaryField1b, BinaryField32b, BinaryField8b, ExtensionField, - PackedFieldIndexable, TowerField, -}; -use bytemuck::Pod; +use binius_field::{tower_levels::TowerLevel, BinaryField8b}; use super::{byte_sliced_add, byte_sliced_double_conditional_increment}; use crate::{ @@ -18,41 +11,20 @@ use crate::{ lasso::{batch::LookupBatch, u8mul::u8mul_bytesliced}, }; -type B1 = BinaryField1b; type B8 = BinaryField8b; -type B16 = BinaryField16b; -type B32 = BinaryField32b; #[allow(clippy::too_many_arguments)] -pub fn byte_sliced_mul< - U, - F, - LevelIn: TowerLevel, - LevelOut: TowerLevel, ->( - builder: &mut ConstraintSystemBuilder, +pub fn byte_sliced_mul>( + builder: &mut ConstraintSystemBuilder, name: impl ToString, - mult_a: &LevelIn::Data, - mult_b: &LevelIn::Data, + mult_a: &LevelIn::Data, + mult_b: &LevelIn::Data, log_size: usize, zero_carry_oracle: OracleId, lookup_batch_mul: &mut LookupBatch, lookup_batch_add: &mut LookupBatch, lookup_batch_dci: &mut LookupBatch, -) -> Result -where - U: Pod - + UnderlierType - + PackScalar - + PackScalar - + PackScalar - + PackScalar - + PackScalar, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - F: TowerField + BinaryField + ExtensionField + ExtensionField + ExtensionField, -{ +) -> Result, anyhow::Error> { if LevelIn::WIDTH == 1 { let result_of_u8mul = u8mul_bytesliced( builder, @@ -77,7 +49,7 @@ where let (mult_a_low, mult_a_high) = LevelIn::split(mult_a); let (mult_b_low, mult_b_high) = LevelIn::split(mult_b); - let a_lo_b_lo = byte_sliced_mul::<_, _, LevelIn::Base, LevelOut::Base>( + let a_lo_b_lo = byte_sliced_mul::( builder, format!("lo*lo {}b", LevelIn::Base::WIDTH), mult_a_low, @@ -88,7 +60,7 @@ where lookup_batch_add, lookup_batch_dci, )?; - let a_lo_b_hi = byte_sliced_mul::<_, _, LevelIn::Base, LevelOut::Base>( + let a_lo_b_hi = byte_sliced_mul::( builder, format!("lo*hi {}b", LevelIn::Base::WIDTH), mult_a_low, @@ -99,7 +71,7 @@ where lookup_batch_add, lookup_batch_dci, )?; - let a_hi_b_lo = byte_sliced_mul::<_, _, LevelIn::Base, LevelOut::Base>( + let a_hi_b_lo = byte_sliced_mul::( builder, format!("hi*lo {}b", LevelIn::Base::WIDTH), mult_a_high, @@ -110,7 +82,7 @@ where lookup_batch_add, lookup_batch_dci, )?; - let a_hi_b_hi = byte_sliced_mul::<_, _, LevelIn::Base, LevelOut::Base>( + let a_hi_b_hi = byte_sliced_mul::( builder, format!("hi*hi {}b", LevelIn::Base::WIDTH), mult_a_high, @@ -122,7 +94,7 @@ where lookup_batch_dci, )?; - let (karatsuba_carry_for_high_chunk, karatsuba_term) = byte_sliced_add::<_, _, LevelIn>( + let (karatsuba_carry_for_high_chunk, karatsuba_term) = byte_sliced_add::( builder, format!("karastsuba addition {}b", LevelIn::WIDTH), &a_lo_b_hi, @@ -135,7 +107,7 @@ where let (a_lo_b_lo_lower_half, a_lo_b_lo_upper_half) = LevelIn::split(&a_lo_b_lo); let (a_hi_b_hi_lower_half, a_hi_b_hi_upper_half) = LevelIn::split(&a_hi_b_hi); - let (additional_carry_for_high_chunk, final_middle_chunk) = byte_sliced_add::<_, _, LevelIn>( + let (additional_carry_for_high_chunk, final_middle_chunk) = byte_sliced_add::( builder, format!("post kartsuba middle term addition {}b", LevelIn::WIDTH), &karatsuba_term, @@ -145,7 +117,7 @@ where lookup_batch_add, )?; - let (_, final_high_chunk) = byte_sliced_double_conditional_increment::<_, _, LevelIn::Base>( + let (_, final_high_chunk) = byte_sliced_double_conditional_increment::( builder, format!("high chunk DCI {}b", LevelIn::Base::WIDTH), a_hi_b_hi_upper_half, diff --git a/crates/circuits/src/lasso/big_integer_ops/byte_sliced_test_utils.rs b/crates/circuits/src/lasso/big_integer_ops/byte_sliced_test_utils.rs index 1ceaae046..411a5c151 100644 --- a/crates/circuits/src/lasso/big_integer_ops/byte_sliced_test_utils.rs +++ b/crates/circuits/src/lasso/big_integer_ops/byte_sliced_test_utils.rs @@ -3,19 +3,18 @@ use std::{array, fmt::Debug}; use alloy_primitives::U512; -use binius_core::{constraint_system::validate::validate_witness, oracle::OracleId}; +use binius_core::oracle::OracleId; use binius_field::{ - arch::OptimalUnderlier, tower_levels::TowerLevel, BinaryField128b, BinaryField1b, - BinaryField32b, BinaryField8b, Field, TowerField, + tower_levels::TowerLevel, BinaryField1b, BinaryField32b, BinaryField8b, Field, TowerField, }; -use rand::{rngs::ThreadRng, thread_rng, Rng}; +use rand::{rngs::StdRng, thread_rng, Rng, SeedableRng}; use super::{ byte_sliced_add, byte_sliced_add_carryfree, byte_sliced_double_conditional_increment, byte_sliced_modular_mul, byte_sliced_mul, }; use crate::{ - builder::ConstraintSystemBuilder, + builder::test_utils::test_circuit, lasso::{ batch::LookupBatch, lookups::u8_arithmetic::{add_carryfree_lookup, add_lookup, dci_lookup, mul_lookup}, @@ -27,297 +26,230 @@ use crate::{ type B8 = BinaryField8b; type B32 = BinaryField32b; -pub fn random_u512(rng: &mut ThreadRng) -> U512 { +pub fn random_u512(rng: &mut impl Rng) -> U512 { let limbs = array::from_fn(|_| rng.gen()); U512::from_limbs(limbs) } pub fn test_bytesliced_add() where - TL: TowerLevel, + TL: TowerLevel, { - type U = OptimalUnderlier; - type F = BinaryField128b; - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); - let log_size = 14; - - let x_in = array::from_fn(|_| { - unconstrained::<_, _, BinaryField8b>(&mut builder, "x", log_size).unwrap() - }); - let y_in = array::from_fn(|_| { - unconstrained::<_, _, BinaryField8b>(&mut builder, "y", log_size).unwrap() - }); - let c_in = unconstrained::<_, _, BinaryField1b>(&mut builder, "cin first", log_size).unwrap(); - - let lookup_t_add = add_lookup(&mut builder, "add table").unwrap(); - - let mut lookup_batch_add = LookupBatch::new([lookup_t_add]); - let _sum_and_cout = byte_sliced_add::<_, _, TL>( - &mut builder, - "lasso_bytesliced_add", - &x_in, - &y_in, - c_in, - log_size, - &mut lookup_batch_add, - ) + test_circuit(|builder| { + let log_size = 14; + let x_in = TL::from_fn(|_| unconstrained::(builder, "x", log_size).unwrap()); + let y_in = TL::from_fn(|_| unconstrained::(builder, "y", log_size).unwrap()); + let c_in = unconstrained::(builder, "cin first", log_size)?; + let lookup_t_add = add_lookup(builder, "add table")?; + let mut lookup_batch_add = LookupBatch::new([lookup_t_add]); + let _sum_and_cout = byte_sliced_add::( + builder, + "lasso_bytesliced_add", + &x_in, + &y_in, + c_in, + log_size, + &mut lookup_batch_add, + )?; + lookup_batch_add.execute::(builder)?; + Ok(vec![]) + }) .unwrap(); - - lookup_batch_add.execute::<_, _, B32>(&mut builder).unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); } pub fn test_bytesliced_add_carryfree() where - TL: TowerLevel, + TL: TowerLevel, { - type U = OptimalUnderlier; - type F = BinaryField128b; - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); - let log_size = 14; - let x_in = array::from_fn(|_| builder.add_committed("x", log_size, BinaryField8b::TOWER_LEVEL)); - let y_in = array::from_fn(|_| builder.add_committed("y", log_size, BinaryField8b::TOWER_LEVEL)); - let c_in = builder.add_committed("c", log_size, BinaryField1b::TOWER_LEVEL); - - if let Some(witness) = builder.witness() { - let mut x_in: [_; WIDTH] = - array::from_fn(|byte_idx| witness.new_column::(x_in[byte_idx])); - let mut y_in: [_; WIDTH] = - array::from_fn(|byte_idx| witness.new_column::(y_in[byte_idx])); - let mut c_in = witness.new_column::(c_in); - - let x_in_bytes_u8: [_; WIDTH] = x_in.each_mut().map(|col| col.as_mut_slice::()); - let y_in_bytes_u8: [_; WIDTH] = y_in.each_mut().map(|col| col.as_mut_slice::()); - let c_in_u8 = c_in.as_mut_slice::(); - - for row_idx in 0..1 << log_size { - let mut rng = thread_rng(); - let input_bitmask = (U512::from(1u8) << (8 * WIDTH)) - U512::from(1u8); - let mut x = random_u512(&mut rng); - x &= input_bitmask; - let mut y = random_u512(&mut rng); - y &= input_bitmask; - - let mut c: bool = rng.gen(); - - while (x + y + U512::from(c)) > input_bitmask { - x = random_u512(&mut rng); + test_circuit(|builder| { + let log_size = 14; + let x_in = + TL::from_fn(|_| builder.add_committed("x", log_size, BinaryField8b::TOWER_LEVEL)); + let y_in = + TL::from_fn(|_| builder.add_committed("y", log_size, BinaryField8b::TOWER_LEVEL)); + let c_in = builder.add_committed("c", log_size, BinaryField1b::TOWER_LEVEL); + + if let Some(witness) = builder.witness() { + let mut x_in: [_; WIDTH] = + array::from_fn(|byte_idx| witness.new_column::(x_in[byte_idx])); + let mut y_in: [_; WIDTH] = + array::from_fn(|byte_idx| witness.new_column::(y_in[byte_idx])); + let mut c_in = witness.new_column::(c_in); + + let x_in_bytes_u8: [_; WIDTH] = x_in.each_mut().map(|col| col.as_mut_slice::()); + let y_in_bytes_u8: [_; WIDTH] = y_in.each_mut().map(|col| col.as_mut_slice::()); + let c_in_u8 = c_in.as_mut_slice::(); + + for row_idx in 0..1 << log_size { + let mut rng = thread_rng(); + let input_bitmask = (U512::from(1u8) << (8 * WIDTH)) - U512::from(1u8); + let mut x = random_u512(&mut rng); x &= input_bitmask; - y = random_u512(&mut rng); + let mut y = random_u512(&mut rng); y &= input_bitmask; - c = rng.gen(); - } - for byte_idx in 0..WIDTH { - x_in_bytes_u8[byte_idx][row_idx] = x.byte(byte_idx); + let mut c: bool = rng.gen(); - y_in_bytes_u8[byte_idx][row_idx] = y.byte(byte_idx); - } + while (x + y + U512::from(c)) > input_bitmask { + x = random_u512(&mut rng); + x &= input_bitmask; + y = random_u512(&mut rng); + y &= input_bitmask; + c = rng.gen(); + } - c_in_u8[row_idx / 8] |= (c as u8) << (row_idx % 8); - } - } + for byte_idx in 0..WIDTH { + x_in_bytes_u8[byte_idx][row_idx] = x.byte(byte_idx); - let lookup_t_add = add_lookup(&mut builder, "add table").unwrap(); - let lookup_t_add_carryfree = add_carryfree_lookup(&mut builder, "add table").unwrap(); + y_in_bytes_u8[byte_idx][row_idx] = y.byte(byte_idx); + } - let mut lookup_batch_add = LookupBatch::new([lookup_t_add]); - let mut lookup_batch_add_carryfree = LookupBatch::new([lookup_t_add_carryfree]); + c_in_u8[row_idx / 8] |= (c as u8) << (row_idx % 8); + } + } - let _sum_and_cout = byte_sliced_add_carryfree::<_, _, TL>( - &mut builder, - "lasso_bytesliced_add_carryfree", - &x_in, - &y_in, - c_in, - log_size, - &mut lookup_batch_add, - &mut lookup_batch_add_carryfree, - ) + let lookup_t_add = add_lookup(builder, "add table")?; + let lookup_t_add_carryfree = add_carryfree_lookup(builder, "add table")?; + + let mut lookup_batch_add = LookupBatch::new([lookup_t_add]); + let mut lookup_batch_add_carryfree = LookupBatch::new([lookup_t_add_carryfree]); + + let _sum_and_cout = byte_sliced_add_carryfree::( + builder, + "lasso_bytesliced_add_carryfree", + &x_in, + &y_in, + c_in, + log_size, + &mut lookup_batch_add, + &mut lookup_batch_add_carryfree, + )?; + + lookup_batch_add.execute::(builder)?; + lookup_batch_add_carryfree.execute::(builder)?; + Ok(vec![]) + }) .unwrap(); - - lookup_batch_add.execute::<_, _, B32>(&mut builder).unwrap(); - lookup_batch_add_carryfree - .execute::<_, _, B32>(&mut builder) - .unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); } pub fn test_bytesliced_double_conditional_increment() where - TL: TowerLevel, + TL: TowerLevel, { - type U = OptimalUnderlier; - type F = BinaryField128b; - - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); - let log_size = 14; - - let x_in = array::from_fn(|_| { - unconstrained::<_, _, BinaryField8b>(&mut builder, "x", log_size).unwrap() - }); - - let first_c_in = - unconstrained::<_, _, BinaryField1b>(&mut builder, "cin first", log_size).unwrap(); - - let second_c_in = - unconstrained::<_, _, BinaryField1b>(&mut builder, "cin second", log_size).unwrap(); - - let zero_oracle_carry = - transparent::constant(&mut builder, "zero carry", log_size, BinaryField1b::ZERO).unwrap(); - let lookup_t_dci = dci_lookup(&mut builder, "add table").unwrap(); - - let mut lookup_batch_dci = LookupBatch::new([lookup_t_dci]); - - let _sum_and_cout = byte_sliced_double_conditional_increment::<_, _, TL>( - &mut builder, - "lasso_bytesliced_DCI", - &x_in, - first_c_in, - second_c_in, - log_size, - zero_oracle_carry, - &mut lookup_batch_dci, - ) + test_circuit(|builder| { + let log_size = 14; + let x_in = TL::from_fn(|_| unconstrained::(builder, "x", log_size).unwrap()); + let first_c_in = unconstrained::(builder, "cin first", log_size)?; + let second_c_in = unconstrained::(builder, "cin second", log_size)?; + let zero_oracle_carry = + transparent::constant(builder, "zero carry", log_size, BinaryField1b::ZERO)?; + let lookup_t_dci = dci_lookup(builder, "add table")?; + let mut lookup_batch_dci = LookupBatch::new([lookup_t_dci]); + let _sum_and_cout = byte_sliced_double_conditional_increment::( + builder, + "lasso_bytesliced_DCI", + &x_in, + first_c_in, + second_c_in, + log_size, + zero_oracle_carry, + &mut lookup_batch_dci, + )?; + lookup_batch_dci.execute::(builder)?; + Ok(vec![]) + }) .unwrap(); - - lookup_batch_dci.execute::<_, _, B32>(&mut builder).unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); } pub fn test_bytesliced_mul() where - TL: TowerLevel, - TL::Base: TowerLevel, + TL: TowerLevel, { - type U = OptimalUnderlier; - type F = BinaryField128b; - - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); - let log_size = 14; - - let mult_a = array::from_fn(|_| { - unconstrained::<_, _, BinaryField8b>(&mut builder, "a", log_size).unwrap() - }); - let mult_b = array::from_fn(|_| { - unconstrained::<_, _, BinaryField8b>(&mut builder, "b", log_size).unwrap() - }); - - let zero_oracle_carry = - transparent::constant(&mut builder, "zero carry", log_size, BinaryField1b::ZERO).unwrap(); - - let lookup_t_mul = mul_lookup(&mut builder, "mul lookup").unwrap(); - let lookup_t_add = add_lookup(&mut builder, "add lookup").unwrap(); - let lookup_t_dci = dci_lookup(&mut builder, "dci lookup").unwrap(); - - let mut lookup_batch_mul = LookupBatch::new([lookup_t_mul]); - let mut lookup_batch_add = LookupBatch::new([lookup_t_add]); - let mut lookup_batch_dci = LookupBatch::new([lookup_t_dci]); - - let _sum_and_cout = byte_sliced_mul::<_, _, TL::Base, TL>( - &mut builder, - "lasso_bytesliced_mul", - &mult_a, - &mult_b, - log_size, - zero_oracle_carry, - &mut lookup_batch_mul, - &mut lookup_batch_add, - &mut lookup_batch_dci, - ) + test_circuit(|builder| { + let log_size = 14; + let mult_a = + TL::Base::from_fn(|_| unconstrained::(builder, "a", log_size).unwrap()); + let mult_b = + TL::Base::from_fn(|_| unconstrained::(builder, "b", log_size).unwrap()); + let zero_oracle_carry = + transparent::constant(builder, "zero carry", log_size, BinaryField1b::ZERO)?; + let lookup_t_mul = mul_lookup(builder, "mul lookup")?; + let lookup_t_add = add_lookup(builder, "add lookup")?; + let lookup_t_dci = dci_lookup(builder, "dci lookup")?; + let mut lookup_batch_mul = LookupBatch::new([lookup_t_mul]); + let mut lookup_batch_add = LookupBatch::new([lookup_t_add]); + let mut lookup_batch_dci = LookupBatch::new([lookup_t_dci]); + let _sum_and_cout = byte_sliced_mul::( + builder, + "lasso_bytesliced_mul", + &mult_a, + &mult_b, + log_size, + zero_oracle_carry, + &mut lookup_batch_mul, + &mut lookup_batch_add, + &mut lookup_batch_dci, + )?; + Ok(vec![]) + }) .unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); } pub fn test_bytesliced_modular_mul() where - TL: TowerLevel, - TL::Base: TowerLevel, - >::Data: Debug, + TL: TowerLevel: Debug>, + TL::Base: TowerLevel = [OracleId; WIDTH]>, { - type U = OptimalUnderlier; - type F = BinaryField128b; - - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); - let log_size = 14; - - let mut rng = thread_rng(); - - let mult_a = builder.add_committed_multiple::("a", log_size, B8::TOWER_LEVEL); - let mult_b = builder.add_committed_multiple::("b", log_size, B8::TOWER_LEVEL); - - let input_bitmask = (U512::from(1u8) << (8 * WIDTH)) - U512::from(1u8); + test_circuit(|builder| { + let log_size = 14; + let mut rng = thread_rng(); + let mult_a = builder.add_committed_multiple::("a", log_size, B8::TOWER_LEVEL); + let mult_b = builder.add_committed_multiple::("b", log_size, B8::TOWER_LEVEL); + let input_bitmask = (U512::from(1u8) << (8 * WIDTH)) - U512::from(1u8); + let modulus = + (random_u512(&mut StdRng::from_seed([42; 32])) % input_bitmask) + U512::from(1u8); - let modulus = (random_u512(&mut rng) % input_bitmask) + U512::from(1u8); + if let Some(witness) = builder.witness() { + let mut mult_a: [_; WIDTH] = + array::from_fn(|byte_idx| witness.new_column::(mult_a[byte_idx])); - if let Some(witness) = builder.witness() { - let mut mult_a: [_; WIDTH] = - array::from_fn(|byte_idx| witness.new_column::(mult_a[byte_idx])); + let mult_a_u8 = mult_a.each_mut().map(|col| col.as_mut_slice::()); - let mult_a_u8 = mult_a.each_mut().map(|col| col.as_mut_slice::()); + let mut mult_b: [_; WIDTH] = + array::from_fn(|byte_idx| witness.new_column::(mult_b[byte_idx])); - let mut mult_b: [_; WIDTH] = - array::from_fn(|byte_idx| witness.new_column::(mult_b[byte_idx])); + let mult_b_u8 = mult_b.each_mut().map(|col| col.as_mut_slice::()); - let mult_b_u8 = mult_b.each_mut().map(|col| col.as_mut_slice::()); + for row_idx in 0..1 << log_size { + let mut a = random_u512(&mut rng); + let mut b = random_u512(&mut rng); - for row_idx in 0..1 << log_size { - let mut a = random_u512(&mut rng); - let mut b = random_u512(&mut rng); + a %= modulus; + b %= modulus; - a %= modulus; - b %= modulus; - - for byte_idx in 0..WIDTH { - mult_a_u8[byte_idx][row_idx] = a.byte(byte_idx); - mult_b_u8[byte_idx][row_idx] = b.byte(byte_idx); + for byte_idx in 0..WIDTH { + mult_a_u8[byte_idx][row_idx] = a.byte(byte_idx); + mult_b_u8[byte_idx][row_idx] = b.byte(byte_idx); + } } } - } - - let modulus_input: [_; WIDTH] = array::from_fn(|byte_idx| modulus.byte(byte_idx)); - - let zero_oracle_byte = - transparent::constant(&mut builder, "zero carry", log_size, BinaryField8b::ZERO).unwrap(); - let zero_oracle_carry = - transparent::constant(&mut builder, "zero carry", log_size, BinaryField1b::ZERO).unwrap(); - - let _modded_product = byte_sliced_modular_mul::<_, _, TL::Base, TL>( - &mut builder, - "lasso_bytesliced_mul", - &mult_a, - &mult_b, - &modulus_input, - log_size, - zero_oracle_byte, - zero_oracle_carry, - ) + let modulus_input: [_; WIDTH] = array::from_fn(|byte_idx| modulus.byte(byte_idx)); + let zero_oracle_byte = + transparent::constant(builder, "zero carry", log_size, BinaryField8b::ZERO)?; + let zero_oracle_carry = + transparent::constant(builder, "zero carry", log_size, BinaryField1b::ZERO)?; + let _modded_product = byte_sliced_modular_mul::( + builder, + "lasso_bytesliced_mul", + &mult_a, + &mult_b, + &modulus_input, + log_size, + zero_oracle_byte, + zero_oracle_carry, + )?; + Ok(vec![]) + }) .unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); } diff --git a/crates/circuits/src/lasso/big_integer_ops/mod.rs b/crates/circuits/src/lasso/big_integer_ops/mod.rs index 3046af9be..4b41b524f 100644 --- a/crates/circuits/src/lasso/big_integer_ops/mod.rs +++ b/crates/circuits/src/lasso/big_integer_ops/mod.rs @@ -12,3 +12,56 @@ pub use byte_sliced_add_carryfree::byte_sliced_add_carryfree; pub use byte_sliced_double_conditional_increment::byte_sliced_double_conditional_increment; pub use byte_sliced_modular_mul::byte_sliced_modular_mul; pub use byte_sliced_mul::byte_sliced_mul; + +#[cfg(test)] +mod tests { + use binius_field::tower_levels::{ + TowerLevel1, TowerLevel16, TowerLevel2, TowerLevel4, TowerLevel8, + }; + + use super::byte_sliced_test_utils::{ + test_bytesliced_add, test_bytesliced_add_carryfree, + test_bytesliced_double_conditional_increment, test_bytesliced_modular_mul, + test_bytesliced_mul, + }; + + #[test] + fn test_lasso_add_bytesliced() { + test_bytesliced_add::<1, TowerLevel1>(); + test_bytesliced_add::<2, TowerLevel2>(); + test_bytesliced_add::<4, TowerLevel4>(); + test_bytesliced_add::<8, TowerLevel8>(); + } + + #[test] + fn test_lasso_mul_bytesliced() { + test_bytesliced_mul::<1, TowerLevel2>(); + test_bytesliced_mul::<2, TowerLevel4>(); + test_bytesliced_mul::<4, TowerLevel8>(); + test_bytesliced_mul::<8, TowerLevel16>(); + } + + #[test] + fn test_lasso_modular_mul_bytesliced() { + test_bytesliced_modular_mul::<1, TowerLevel2>(); + test_bytesliced_modular_mul::<2, TowerLevel4>(); + test_bytesliced_modular_mul::<4, TowerLevel8>(); + test_bytesliced_modular_mul::<8, TowerLevel16>(); + } + + #[test] + fn test_lasso_bytesliced_double_conditional_increment() { + test_bytesliced_double_conditional_increment::<1, TowerLevel1>(); + test_bytesliced_double_conditional_increment::<2, TowerLevel2>(); + test_bytesliced_double_conditional_increment::<4, TowerLevel4>(); + test_bytesliced_double_conditional_increment::<8, TowerLevel8>(); + } + + #[test] + fn test_lasso_bytesliced_add_carryfree() { + test_bytesliced_add_carryfree::<1, TowerLevel1>(); + test_bytesliced_add_carryfree::<2, TowerLevel2>(); + test_bytesliced_add_carryfree::<4, TowerLevel4>(); + test_bytesliced_add_carryfree::<8, TowerLevel8>(); + } +} diff --git a/crates/circuits/src/lasso/lasso.rs b/crates/circuits/src/lasso/lasso.rs index 1ecde278a..d93e2d484 100644 --- a/crates/circuits/src/lasso/lasso.rs +++ b/crates/circuits/src/lasso/lasso.rs @@ -4,15 +4,20 @@ use anyhow::{ensure, Error, Result}; use binius_core::{constraint_system::channel::ChannelId, oracle::OracleId}; use binius_field::{ as_packed_field::{PackScalar, PackedType}, - underlier::UnderlierType, - BinaryField1b, ExtensionField, PackedFieldIndexable, TowerField, + ExtensionField, Field, PackedFieldIndexable, TowerField, }; use itertools::{izip, Itertools}; -use crate::{builder::ConstraintSystemBuilder, transparent}; +use crate::{ + builder::{ + types::{F, U}, + ConstraintSystemBuilder, + }, + transparent, +}; -pub fn lasso( - builder: &mut ConstraintSystemBuilder, +pub fn lasso( + builder: &mut ConstraintSystemBuilder, name: impl ToString, n_lookups: &[usize], u_to_t_mappings: &[impl AsRef<[usize]>], @@ -21,10 +26,10 @@ pub fn lasso( channel: ChannelId, ) -> Result<()> where - U: UnderlierType + PackScalar + PackScalar + PackScalar, - F: TowerField + ExtensionField + From, - PackedType: PackedFieldIndexable, FC: TowerField, + U: PackScalar, + F: ExtensionField + From, + PackedType: PackedFieldIndexable, { if n_lookups.len() != lookups_u.len() { Err(anyhow::Error::msg("n_vars and lookups_u must be of the same length"))?; @@ -55,7 +60,7 @@ where } let t_log_rows = builder.log_rows(lookup_t.as_ref().iter().copied())?; - let lookup_o = transparent::constant(builder, "lookup_o", t_log_rows, F::ONE)?; + let lookup_o = transparent::constant(builder, "lookup_o", t_log_rows, Field::ONE)?; let lookup_f = builder.add_committed("lookup_f", t_log_rows, FC::TOWER_LEVEL); let lookups_r = u_log_rows .iter() diff --git a/crates/circuits/src/lasso/lookups/u8_arithmetic.rs b/crates/circuits/src/lasso/lookups/u8_arithmetic.rs index 5fdbe0c71..bb9a4e6c3 100644 --- a/crates/circuits/src/lasso/lookups/u8_arithmetic.rs +++ b/crates/circuits/src/lasso/lookups/u8_arithmetic.rs @@ -2,34 +2,19 @@ use anyhow::Result; use binius_core::oracle::OracleId; -use binius_field::{ - as_packed_field::{PackScalar, PackedType}, - underlier::UnderlierType, - BinaryField, BinaryField16b, BinaryField32b, BinaryField8b, ExtensionField, - PackedFieldIndexable, TowerField, -}; -use bytemuck::Pod; +use binius_field::{BinaryField32b, TowerField}; use crate::builder::ConstraintSystemBuilder; -type B8 = BinaryField8b; -type B16 = BinaryField16b; type B32 = BinaryField32b; const T_LOG_SIZE_MUL: usize = 16; const T_LOG_SIZE_ADD: usize = 17; const T_LOG_SIZE_DCI: usize = 10; -pub fn mul_lookup( - builder: &mut ConstraintSystemBuilder, +pub fn mul_lookup( + builder: &mut ConstraintSystemBuilder, name: impl ToString + Clone, -) -> Result -where - U: Pod + UnderlierType + PackScalar + PackScalar + PackScalar + PackScalar, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - F: TowerField + BinaryField + ExtensionField + ExtensionField + ExtensionField, -{ +) -> Result { builder.push_namespace(name); let lookup_t = builder.add_committed("lookup_t", T_LOG_SIZE_MUL, B32::TOWER_LEVEL); @@ -53,17 +38,10 @@ where Ok(lookup_t) } -pub fn add_lookup( - builder: &mut ConstraintSystemBuilder, +pub fn add_lookup( + builder: &mut ConstraintSystemBuilder, name: impl ToString + Clone, -) -> Result -where - U: Pod + UnderlierType + PackScalar + PackScalar + PackScalar + PackScalar, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - F: TowerField + BinaryField + ExtensionField + ExtensionField + ExtensionField, -{ +) -> Result { builder.push_namespace(name); let lookup_t = builder.add_committed("lookup_t", T_LOG_SIZE_ADD, B32::TOWER_LEVEL); @@ -95,17 +73,10 @@ where Ok(lookup_t) } -pub fn add_carryfree_lookup( - builder: &mut ConstraintSystemBuilder, +pub fn add_carryfree_lookup( + builder: &mut ConstraintSystemBuilder, name: impl ToString + Clone, -) -> Result -where - U: Pod + UnderlierType + PackScalar + PackScalar + PackScalar + PackScalar, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - F: TowerField + BinaryField + ExtensionField + ExtensionField + ExtensionField, -{ +) -> Result { builder.push_namespace(name); let lookup_t = builder.add_committed("lookup_t", T_LOG_SIZE_ADD, B32::TOWER_LEVEL); @@ -139,17 +110,10 @@ where Ok(lookup_t) } -pub fn dci_lookup( - builder: &mut ConstraintSystemBuilder, +pub fn dci_lookup( + builder: &mut ConstraintSystemBuilder, name: impl ToString + Clone, -) -> Result -where - U: Pod + UnderlierType + PackScalar + PackScalar + PackScalar + PackScalar, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - F: TowerField + BinaryField + ExtensionField + ExtensionField + ExtensionField, -{ +) -> Result { builder.push_namespace(name); let lookup_t = builder.add_committed("lookup_t", T_LOG_SIZE_DCI, B32::TOWER_LEVEL); @@ -182,3 +146,125 @@ where builder.pop_namespace(); Ok(lookup_t) } + +#[cfg(test)] +mod tests { + use binius_field::{BinaryField1b, BinaryField32b, BinaryField8b}; + + use crate::{ + builder::test_utils::test_circuit, + lasso::{self, batch::LookupBatch}, + unconstrained::unconstrained, + }; + + #[test] + fn test_lasso_u8add_carryfree_rejects_carry() { + // TODO: Make this test 100% certain to pass instead of 2^14 bits of security from randomness + test_circuit(|builder| { + let log_size = 14; + let x_in = unconstrained::(builder, "x", log_size)?; + let y_in = unconstrained::(builder, "y", log_size)?; + let c_in = unconstrained::(builder, "c", log_size)?; + + let lookup_t = super::add_carryfree_lookup(builder, "add cf table")?; + let mut lookup_batch = LookupBatch::new([lookup_t]); + let _sum_and_cout = lasso::u8add_carryfree( + builder, + &mut lookup_batch, + "lasso_u8add", + x_in, + y_in, + c_in, + log_size, + )?; + lookup_batch.execute::(builder)?; + Ok(vec![]) + }) + .expect_err("Rejected overflowing add"); + } + + #[test] + fn test_lasso_u8mul() { + test_circuit(|builder| { + let log_size = 10; + + let mult_a = unconstrained::(builder, "mult_a", log_size)?; + let mult_b = unconstrained::(builder, "mult_b", log_size)?; + + let mul_lookup_table = super::mul_lookup(builder, "mul table")?; + + let mut lookup_batch = LookupBatch::new([mul_lookup_table]); + + let _product = lasso::u8mul( + builder, + &mut lookup_batch, + "lasso_u8mul", + mult_a, + mult_b, + 1 << log_size, + )?; + + lookup_batch.execute::(builder)?; + Ok(vec![]) + }) + .unwrap(); + } + + #[test] + fn test_lasso_batched_u8mul() { + test_circuit(|builder| { + let log_size = 10; + let mul_lookup_table = super::mul_lookup(builder, "mul table")?; + + let mut lookup_batch = LookupBatch::new([mul_lookup_table]); + + for _ in 0..10 { + let mult_a = unconstrained::(builder, "mult_a", log_size)?; + let mult_b = unconstrained::(builder, "mult_b", log_size)?; + + let _product = lasso::u8mul( + builder, + &mut lookup_batch, + "lasso_u8mul", + mult_a, + mult_b, + 1 << log_size, + )?; + } + + lookup_batch.execute::(builder)?; + Ok(vec![]) + }) + .unwrap(); + } + + #[test] + fn test_lasso_batched_u8mul_rejects() { + test_circuit(|builder| { + let log_size = 10; + + // We try to feed in the add table instead + let mul_lookup_table = super::add_lookup(builder, "mul table")?; + + let mut lookup_batch = LookupBatch::new([mul_lookup_table]); + + // TODO?: Make this test fail 100% of the time, even though its almost impossible with rng + for _ in 0..10 { + let mult_a = unconstrained::(builder, "mult_a", log_size)?; + let mult_b = unconstrained::(builder, "mult_b", log_size)?; + let _product = lasso::u8mul( + builder, + &mut lookup_batch, + "lasso_u8mul", + mult_a, + mult_b, + 1 << log_size, + )?; + } + + lookup_batch.execute::(builder)?; + Ok(vec![]) + }) + .expect_err("Channels should be unbalanced"); + } +} diff --git a/crates/circuits/src/lasso/sha256.rs b/crates/circuits/src/lasso/sha256.rs index c8677acb6..1b8f3c3dd 100644 --- a/crates/circuits/src/lasso/sha256.rs +++ b/crates/circuits/src/lasso/sha256.rs @@ -1,21 +1,19 @@ // Copyright 2024-2025 Irreducible Inc. -use std::marker::PhantomData; - use anyhow::Result; use binius_core::oracle::OracleId; use binius_field::{ - as_packed_field::{PackScalar, PackedType}, - underlier::UnderlierType, - BinaryField16b, BinaryField1b, BinaryField32b, BinaryField4b, BinaryField8b, ExtensionField, + as_packed_field::PackedType, BinaryField16b, BinaryField1b, BinaryField32b, BinaryField4b, PackedFieldIndexable, TowerField, }; -use bytemuck::Pod; use itertools::izip; use super::{lasso::lasso, u32add::SeveralU32add}; use crate::{ - builder::ConstraintSystemBuilder, + builder::{ + types::{F, U}, + ConstraintSystemBuilder, + }, pack::pack, sha256::{rotate_and_xor, u32const_repeating, RotateRightType, INIT, ROUND_CONSTS_K}, }; @@ -24,36 +22,19 @@ pub const CH_MAJ_T_LOG_SIZE: usize = 12; type B1 = BinaryField1b; type B4 = BinaryField4b; -type B8 = BinaryField8b; type B16 = BinaryField16b; type B32 = BinaryField32b; -struct SeveralBitwise { +struct SeveralBitwise { n_lookups: Vec, lookup_t: OracleId, lookups_u: Vec<[OracleId; 1]>, u_to_t_mappings: Vec>, f: fn(u32, u32, u32) -> u32, - _phantom: PhantomData<(U, F)>, } -impl SeveralBitwise -where - U: UnderlierType - + Pod - + PackScalar - + PackScalar - + PackScalar - + PackScalar - + PackScalar, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - F: TowerField + ExtensionField + ExtensionField + ExtensionField, -{ - pub fn new( - builder: &mut ConstraintSystemBuilder, - f: fn(u32, u32, u32) -> u32, - ) -> Result { +impl SeveralBitwise { + pub fn new(builder: &mut ConstraintSystemBuilder, f: fn(u32, u32, u32) -> u32) -> Result { let lookup_t = builder.add_committed("bitwise lookup_t", CH_MAJ_T_LOG_SIZE, B16::TOWER_LEVEL); @@ -80,13 +61,12 @@ where lookups_u: Vec::new(), u_to_t_mappings: Vec::new(), f, - _phantom: PhantomData, }) } pub fn calculate( &mut self, - builder: &mut ConstraintSystemBuilder, + builder: &mut ConstraintSystemBuilder, name: impl ToString, params: [OracleId; 3], ) -> Result { @@ -94,9 +74,9 @@ where let log_size = builder.log_rows(params)?; - let xin_packed = pack::(xin, builder, "xin_packed")?; - let yin_packed = pack::(yin, builder, "yin_packed")?; - let zin_packed = pack::(zin, builder, "zin_packed")?; + let xin_packed = pack::(xin, builder, "xin_packed")?; + let yin_packed = pack::(yin, builder, "yin_packed")?; + let zin_packed = pack::(zin, builder, "zin_packed")?; let res = builder.add_committed(name, log_size, B1::TOWER_LEVEL); @@ -160,12 +140,12 @@ where pub fn finalize( self, - builder: &mut ConstraintSystemBuilder, + builder: &mut ConstraintSystemBuilder, name: impl ToString, ) -> Result<()> { let channel = builder.add_channel(); - lasso::<_, _, B32>( + lasso::( builder, name, &self.n_lookups, @@ -177,29 +157,11 @@ where } } -pub fn sha256( - builder: &mut ConstraintSystemBuilder, +pub fn sha256( + builder: &mut ConstraintSystemBuilder, input: [OracleId; 16], log_size: usize, -) -> Result<[OracleId; 8], anyhow::Error> -where - U: UnderlierType - + Pod - + PackScalar - + PackScalar - + PackScalar - + PackScalar - + PackScalar - + PackScalar, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - F: TowerField - + ExtensionField - + ExtensionField - + ExtensionField - + ExtensionField, -{ +) -> Result<[OracleId; 8], anyhow::Error> { let n_vars = log_size; let mut several_u32_add = SeveralU32add::new(builder)?; @@ -309,3 +271,64 @@ where Ok(output) } + +#[cfg(test)] +mod tests { + use binius_core::oracle::OracleId; + use binius_field::{as_packed_field::PackedType, BinaryField1b, BinaryField8b, TowerField}; + use sha2::{compress256, digest::generic_array::GenericArray}; + + use crate::{ + builder::{test_utils::test_circuit, types::U}, + unconstrained::unconstrained, + }; + + #[test] + fn test_sha256_lasso() { + test_circuit(|builder| { + let log_size = PackedType::::LOG_WIDTH + BinaryField8b::TOWER_LEVEL; + let input: [OracleId; 16] = std::array::from_fn(|i| { + unconstrained::(builder, i, log_size).unwrap() + }); + let state_output = super::sha256(builder, input, log_size).unwrap(); + + if let Some(witness) = builder.witness() { + let input_witneses: [_; 16] = std::array::from_fn(|i| { + witness + .get::(input[i]) + .unwrap() + .as_slice::() + }); + + let output_witneses: [_; 8] = std::array::from_fn(|i| { + witness + .get::(state_output[i]) + .unwrap() + .as_slice::() + }); + + let mut generic_array_input = GenericArray::::default(); + + let n_compressions = input_witneses[0].len(); + + for j in 0..n_compressions { + for i in 0..16 { + for z in 0..4 { + generic_array_input[i * 4 + z] = input_witneses[i][j].to_be_bytes()[z]; + } + } + + let mut output = crate::sha256::INIT; + compress256(&mut output, &[generic_array_input]); + + for i in 0..8 { + assert_eq!(output[i], output_witneses[i][j]); + } + } + } + + Ok(vec![]) + }) + .unwrap(); + } +} diff --git a/crates/circuits/src/lasso/u32add.rs b/crates/circuits/src/lasso/u32add.rs index dca11d724..9b260b8c2 100644 --- a/crates/circuits/src/lasso/u32add.rs +++ b/crates/circuits/src/lasso/u32add.rs @@ -7,14 +7,19 @@ use binius_core::oracle::{OracleId, ShiftVariant}; use binius_field::{ as_packed_field::{PackScalar, PackedType}, packed::set_packed_slice, - underlier::{UnderlierType, U1}, + underlier::U1, BinaryField1b, BinaryField32b, BinaryField8b, ExtensionField, PackedFieldIndexable, TowerField, }; -use bytemuck::Pod; use itertools::izip; use super::lasso::lasso; -use crate::{builder::ConstraintSystemBuilder, pack::pack}; +use crate::{ + builder::{ + types::{F, U}, + ConstraintSystemBuilder, + }, + pack::pack, +}; const ADD_T_LOG_SIZE: usize = 17; @@ -22,32 +27,18 @@ type B1 = BinaryField1b; type B8 = BinaryField8b; type B32 = BinaryField32b; -pub fn u32add( - builder: &mut ConstraintSystemBuilder, +pub fn u32add( + builder: &mut ConstraintSystemBuilder, name: impl ToString + Clone, xin: OracleId, yin: OracleId, ) -> Result where - U: UnderlierType - + Pod - + PackScalar - + PackScalar - + PackScalar - + PackScalar - + PackScalar - + PackScalar, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - B8: ExtensionField + ExtensionField, - F: TowerField - + ExtensionField - + ExtensionField - + ExtensionField - + ExtensionField, FInput: TowerField, FOutput: TowerField, - B32: TowerField, + U: PackScalar + PackScalar, + B8: ExtensionField + ExtensionField, + F: ExtensionField + ExtensionField, { let mut several = SeveralU32add::new(builder)?; let sum = several.u32add::(builder, name.clone(), xin, yin)?; @@ -55,7 +46,7 @@ where Ok(sum) } -pub struct SeveralU32add { +pub struct SeveralU32add { n_lookups: Vec, lookup_t: OracleId, lookups_u: Vec<[OracleId; 1]>, @@ -64,19 +55,8 @@ pub struct SeveralU32add { _phantom: PhantomData<(U, F)>, } -impl SeveralU32add -where - U: UnderlierType - + Pod - + PackScalar - + PackScalar - + PackScalar - + PackScalar, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - F: TowerField + ExtensionField + ExtensionField, -{ - pub fn new(builder: &mut ConstraintSystemBuilder) -> Result { +impl SeveralU32add { + pub fn new(builder: &mut ConstraintSystemBuilder) -> Result { let lookup_t = builder.add_committed("lookup_t", ADD_T_LOG_SIZE, B32::TOWER_LEVEL); if let Some(witness) = builder.witness() { @@ -111,15 +91,15 @@ where pub fn u32add( &mut self, - builder: &mut ConstraintSystemBuilder, + builder: &mut ConstraintSystemBuilder, name: impl ToString, xin: OracleId, yin: OracleId, ) -> Result where - U: PackScalar + PackScalar, FInput: TowerField, FOutput: TowerField, + U: PackScalar + PackScalar, F: ExtensionField + ExtensionField, B8: ExtensionField + ExtensionField, { @@ -143,8 +123,8 @@ where let cin = builder.add_shifted("cin", cout, 1, 2, ShiftVariant::LogicalLeft)?; - let xin_u8 = pack::<_, _, FInput, B8>(xin, builder, "repacked xin")?; - let yin_u8 = pack::<_, _, FInput, B8>(yin, builder, "repacked yin")?; + let xin_u8 = pack::(xin, builder, "repacked xin")?; + let yin_u8 = pack::(yin, builder, "repacked yin")?; let lookup_u = builder.add_linear_combination( "lookup_u", @@ -231,12 +211,12 @@ where pub fn finalize( mut self, - builder: &mut ConstraintSystemBuilder, + builder: &mut ConstraintSystemBuilder, name: impl ToString, ) -> Result<()> { let channel = builder.add_channel(); self.finalized = true; - lasso::<_, _, B32>( + lasso::( builder, name, &self.n_lookups, @@ -248,8 +228,56 @@ where } } -impl Drop for SeveralU32add { +impl Drop for SeveralU32add { fn drop(&mut self) { assert!(self.finalized) } } + +#[cfg(test)] +mod tests { + use binius_field::{BinaryField1b, BinaryField8b}; + + use super::SeveralU32add; + use crate::{builder::test_utils::test_circuit, unconstrained::unconstrained}; + + #[test] + fn test_several_lasso_u32add() { + test_circuit(|builder| { + let mut several_u32_add = SeveralU32add::new(builder).unwrap(); + for log_size in [11, 12, 13] { + // BinaryField8b is used here because we utilize an 8x8x1→8 table + let add_a_u8 = unconstrained::(builder, "add_a", log_size).unwrap(); + let add_b_u8 = unconstrained::(builder, "add_b", log_size).unwrap(); + let _sum = several_u32_add + .u32add::( + builder, + "lasso_u32add", + add_a_u8, + add_b_u8, + ) + .unwrap(); + } + several_u32_add.finalize(builder, "lasso_u32add").unwrap(); + Ok(vec![]) + }) + .unwrap(); + } + + #[test] + fn test_lasso_u32add() { + test_circuit(|builder| { + let log_size = 14; + let add_a = unconstrained::(builder, "add_a", log_size)?; + let add_b = unconstrained::(builder, "add_b", log_size)?; + let _sum = super::u32add::( + builder, + "lasso_u32add", + add_a, + add_b, + )?; + Ok(vec![]) + }) + .unwrap(); + } +} diff --git a/crates/circuits/src/lasso/u8_double_conditional_increment.rs b/crates/circuits/src/lasso/u8_double_conditional_increment.rs index 907ac418c..1f10b443a 100644 --- a/crates/circuits/src/lasso/u8_double_conditional_increment.rs +++ b/crates/circuits/src/lasso/u8_double_conditional_increment.rs @@ -2,44 +2,24 @@ use anyhow::Result; use binius_core::oracle::OracleId; -use binius_field::{ - as_packed_field::{PackScalar, PackedType}, - underlier::UnderlierType, - BinaryField, BinaryField16b, BinaryField1b, BinaryField32b, BinaryField8b, ExtensionField, - PackedFieldIndexable, TowerField, -}; -use bytemuck::Pod; +use binius_field::{BinaryField1b, BinaryField32b, BinaryField8b, TowerField}; use super::batch::LookupBatch; -use crate::builder::ConstraintSystemBuilder; +use crate::builder::{types::F, ConstraintSystemBuilder}; type B1 = BinaryField1b; type B8 = BinaryField8b; -type B16 = BinaryField16b; type B32 = BinaryField32b; -pub fn u8_double_conditional_increment( - builder: &mut ConstraintSystemBuilder, +pub fn u8_double_conditional_increment( + builder: &mut ConstraintSystemBuilder, lookup_batch: &mut LookupBatch, name: impl ToString + Clone, x_in: OracleId, first_carry_in: OracleId, second_carry_in: OracleId, log_size: usize, -) -> Result<(OracleId, OracleId), anyhow::Error> -where - U: Pod - + UnderlierType - + PackScalar - + PackScalar - + PackScalar - + PackScalar - + PackScalar, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - F: TowerField + BinaryField + ExtensionField + ExtensionField + ExtensionField, -{ +) -> Result<(OracleId, OracleId), anyhow::Error> { builder.push_namespace(name); let sum = builder.add_committed("sum", log_size, B8::TOWER_LEVEL); diff --git a/crates/circuits/src/lasso/u8add.rs b/crates/circuits/src/lasso/u8add.rs index fe42995dc..7e3dd0e58 100644 --- a/crates/circuits/src/lasso/u8add.rs +++ b/crates/circuits/src/lasso/u8add.rs @@ -4,44 +4,24 @@ use std::vec; use anyhow::Result; use binius_core::oracle::OracleId; -use binius_field::{ - as_packed_field::{PackScalar, PackedType}, - underlier::UnderlierType, - BinaryField, BinaryField16b, BinaryField1b, BinaryField32b, BinaryField8b, ExtensionField, - PackedFieldIndexable, TowerField, -}; -use bytemuck::Pod; +use binius_field::{BinaryField1b, BinaryField32b, BinaryField8b, TowerField}; use super::batch::LookupBatch; -use crate::builder::ConstraintSystemBuilder; +use crate::builder::{types::F, ConstraintSystemBuilder}; type B1 = BinaryField1b; type B8 = BinaryField8b; -type B16 = BinaryField16b; type B32 = BinaryField32b; -pub fn u8add( - builder: &mut ConstraintSystemBuilder, +pub fn u8add( + builder: &mut ConstraintSystemBuilder, lookup_batch: &mut LookupBatch, name: impl ToString + Clone, x_in: OracleId, y_in: OracleId, carry_in: OracleId, log_size: usize, -) -> Result<(OracleId, OracleId), anyhow::Error> -where - U: Pod - + UnderlierType - + PackScalar - + PackScalar - + PackScalar - + PackScalar - + PackScalar, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - F: TowerField + BinaryField + ExtensionField + ExtensionField + ExtensionField, -{ +) -> Result<(OracleId, OracleId), anyhow::Error> { builder.push_namespace(name); let sum = builder.add_committed("sum", log_size, B8::TOWER_LEVEL); diff --git a/crates/circuits/src/lasso/u8add_carryfree.rs b/crates/circuits/src/lasso/u8add_carryfree.rs index 45bebbd85..fd1959592 100644 --- a/crates/circuits/src/lasso/u8add_carryfree.rs +++ b/crates/circuits/src/lasso/u8add_carryfree.rs @@ -2,44 +2,24 @@ use anyhow::Result; use binius_core::oracle::OracleId; -use binius_field::{ - as_packed_field::{PackScalar, PackedType}, - underlier::UnderlierType, - BinaryField, BinaryField16b, BinaryField1b, BinaryField32b, BinaryField8b, ExtensionField, - PackedFieldIndexable, TowerField, -}; -use bytemuck::Pod; +use binius_field::{BinaryField1b, BinaryField32b, BinaryField8b, TowerField}; use super::batch::LookupBatch; -use crate::builder::ConstraintSystemBuilder; +use crate::builder::{types::F, ConstraintSystemBuilder}; type B1 = BinaryField1b; type B8 = BinaryField8b; -type B16 = BinaryField16b; type B32 = BinaryField32b; -pub fn u8add_carryfree( - builder: &mut ConstraintSystemBuilder, +pub fn u8add_carryfree( + builder: &mut ConstraintSystemBuilder, lookup_batch: &mut LookupBatch, name: impl ToString + Clone, x_in: OracleId, y_in: OracleId, carry_in: OracleId, log_size: usize, -) -> Result -where - U: Pod - + UnderlierType - + PackScalar - + PackScalar - + PackScalar - + PackScalar - + PackScalar, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - F: TowerField + BinaryField + ExtensionField + ExtensionField + ExtensionField, -{ +) -> Result { builder.push_namespace(name); let sum = builder.add_committed("sum", log_size, B8::TOWER_LEVEL); diff --git a/crates/circuits/src/lasso/u8mul.rs b/crates/circuits/src/lasso/u8mul.rs index 589bf7a4e..d0c93ecd1 100644 --- a/crates/circuits/src/lasso/u8mul.rs +++ b/crates/circuits/src/lasso/u8mul.rs @@ -2,37 +2,24 @@ use anyhow::{ensure, Result}; use binius_core::oracle::OracleId; -use binius_field::{ - as_packed_field::{PackScalar, PackedType}, - underlier::UnderlierType, - BinaryField, BinaryField16b, BinaryField32b, BinaryField8b, ExtensionField, - PackedFieldIndexable, TowerField, -}; -use bytemuck::Pod; +use binius_field::{BinaryField16b, BinaryField32b, BinaryField8b, TowerField}; use itertools::izip; use super::batch::LookupBatch; -use crate::builder::ConstraintSystemBuilder; +use crate::builder::{types::F, ConstraintSystemBuilder}; type B8 = BinaryField8b; type B16 = BinaryField16b; type B32 = BinaryField32b; -pub fn u8mul_bytesliced( - builder: &mut ConstraintSystemBuilder, +pub fn u8mul_bytesliced( + builder: &mut ConstraintSystemBuilder, lookup_batch: &mut LookupBatch, name: impl ToString + Clone, mult_a: OracleId, mult_b: OracleId, n_multiplications: usize, -) -> Result<[OracleId; 2], anyhow::Error> -where - U: Pod + UnderlierType + PackScalar + PackScalar + PackScalar + PackScalar, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - F: TowerField + BinaryField + ExtensionField + ExtensionField + ExtensionField, -{ +) -> Result<[OracleId; 2], anyhow::Error> { builder.push_namespace(name); let log_rows = builder.log_rows([mult_a, mult_b])?; let product = builder.add_committed_multiple("product", log_rows, B8::TOWER_LEVEL); @@ -92,21 +79,14 @@ where Ok(product) } -pub fn u8mul( - builder: &mut ConstraintSystemBuilder, +pub fn u8mul( + builder: &mut ConstraintSystemBuilder, lookup_batch: &mut LookupBatch, name: impl ToString + Clone, mult_a: OracleId, mult_b: OracleId, n_multiplications: usize, -) -> Result -where - U: Pod + UnderlierType + PackScalar + PackScalar + PackScalar + PackScalar, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - F: TowerField + BinaryField + ExtensionField + ExtensionField + ExtensionField, -{ +) -> Result { builder.push_namespace(name.clone()); let product_bytesliced = diff --git a/crates/circuits/src/lib.rs b/crates/circuits/src/lib.rs index c76857c67..59ae8195d 100644 --- a/crates/circuits/src/lib.rs +++ b/crates/circuits/src/lib.rs @@ -11,9 +11,9 @@ pub mod arithmetic; pub mod bitwise; +pub mod blake3; pub mod builder; pub mod collatz; -pub mod groestl; pub mod keccakf; pub mod lasso; mod pack; @@ -26,504 +26,32 @@ pub mod vision; #[cfg(test)] mod tests { - use std::array; - use binius_core::{ constraint_system::{ self, channel::{Boundary, FlushDirection}, - validate::validate_witness, }, fiat_shamir::HasherChallenger, - oracle::OracleId, tower::CanonicalTowerFamily, }; use binius_field::{ - arch::OptimalUnderlier, - as_packed_field::PackedType, - tower_levels::{TowerLevel1, TowerLevel16, TowerLevel2, TowerLevel4, TowerLevel8}, - underlier::WithUnderlier, - AESTowerField16b, BinaryField128b, BinaryField1b, BinaryField32b, BinaryField64b, - BinaryField8b, Field, TowerField, + as_packed_field::PackedType, underlier::WithUnderlier, BinaryField8b, Field, }; use binius_hal::make_portable_backend; use binius_hash::compress::Groestl256ByteCompression; use binius_math::DefaultEvaluationDomainFactory; use groestl_crypto::Groestl256; - use rand::{rngs::StdRng, Rng, SeedableRng}; - use sha2::{compress256, digest::generic_array::GenericArray}; - - use crate::{ - arithmetic, bitwise, - builder::ConstraintSystemBuilder, - groestl::groestl_p_permutation, - keccakf::{keccakf, KeccakfState}, - lasso::{ - self, - batch::LookupBatch, - big_integer_ops::byte_sliced_test_utils::{ - test_bytesliced_add, test_bytesliced_add_carryfree, - test_bytesliced_double_conditional_increment, test_bytesliced_modular_mul, - test_bytesliced_mul, - }, - lookups, - u32add::SeveralU32add, - }, - plain_lookup, - sha256::sha256, - u32fib::u32fib, - unconstrained::unconstrained, - vision::vision_permutation, - }; - - type U = OptimalUnderlier; - type F = BinaryField128b; - - #[test] - fn test_lasso_u8add_carryfree_rejects_carry() { - // TODO: Make this test 100% certain to pass instead of 2^14 bits of security from randomness - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); - let log_size = 14; - let x_in = unconstrained::<_, _, BinaryField8b>(&mut builder, "x", log_size).unwrap(); - let y_in = unconstrained::<_, _, BinaryField8b>(&mut builder, "y", log_size).unwrap(); - let c_in = unconstrained::<_, _, BinaryField1b>(&mut builder, "c", log_size).unwrap(); - - let lookup_t = - lookups::u8_arithmetic::add_carryfree_lookup(&mut builder, "add cf table").unwrap(); - let mut lookup_batch = LookupBatch::new([lookup_t]); - let _sum_and_cout = lasso::u8add_carryfree( - &mut builder, - &mut lookup_batch, - "lasso_u8add", - x_in, - y_in, - c_in, - log_size, - ) - .unwrap(); - - lookup_batch - .execute::<_, _, BinaryField32b>(&mut builder) - .unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness) - .expect_err("Rejected overflowing add"); - } - - #[test] - fn test_lasso_add_bytesliced() { - test_bytesliced_add::<1, TowerLevel1>(); - test_bytesliced_add::<2, TowerLevel2>(); - test_bytesliced_add::<4, TowerLevel4>(); - test_bytesliced_add::<8, TowerLevel8>(); - } - - #[test] - fn test_lasso_mul_bytesliced() { - test_bytesliced_mul::<1, TowerLevel2>(); - test_bytesliced_mul::<2, TowerLevel4>(); - test_bytesliced_mul::<4, TowerLevel8>(); - test_bytesliced_mul::<8, TowerLevel16>(); - } - - #[test] - fn test_lasso_modular_mul_bytesliced() { - test_bytesliced_modular_mul::<1, TowerLevel2>(); - test_bytesliced_modular_mul::<2, TowerLevel4>(); - test_bytesliced_modular_mul::<4, TowerLevel8>(); - test_bytesliced_modular_mul::<8, TowerLevel16>(); - } - - #[test] - fn test_lasso_bytesliced_double_conditional_increment() { - test_bytesliced_double_conditional_increment::<1, TowerLevel1>(); - test_bytesliced_double_conditional_increment::<2, TowerLevel2>(); - test_bytesliced_double_conditional_increment::<4, TowerLevel4>(); - test_bytesliced_double_conditional_increment::<8, TowerLevel8>(); - } - - #[test] - fn test_lasso_bytesliced_add_carryfree() { - test_bytesliced_add_carryfree::<1, TowerLevel1>(); - test_bytesliced_add_carryfree::<2, TowerLevel2>(); - test_bytesliced_add_carryfree::<4, TowerLevel4>(); - test_bytesliced_add_carryfree::<8, TowerLevel8>(); - } - - #[test] - fn test_lasso_u8mul() { - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); - let log_size = 10; - - let mult_a = - unconstrained::<_, _, BinaryField8b>(&mut builder, "mult_a", log_size).unwrap(); - let mult_b = - unconstrained::<_, _, BinaryField8b>(&mut builder, "mult_b", log_size).unwrap(); - - let mul_lookup_table = - lookups::u8_arithmetic::mul_lookup(&mut builder, "mul table").unwrap(); - - let mut lookup_batch = LookupBatch::new([mul_lookup_table]); - - let _product = lasso::u8mul( - &mut builder, - &mut lookup_batch, - "lasso_u8mul", - mult_a, - mult_b, - 1 << log_size, - ) - .unwrap(); - - lookup_batch - .execute::<_, _, BinaryField32b>(&mut builder) - .unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); - } - - #[test] - fn test_lasso_batched_u8mul() { - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); - let log_size = 10; - let mul_lookup_table = - lookups::u8_arithmetic::mul_lookup(&mut builder, "mul table").unwrap(); - - let mut lookup_batch = LookupBatch::new([mul_lookup_table]); - - for _ in 0..10 { - let mult_a = - unconstrained::<_, _, BinaryField8b>(&mut builder, "mult_a", log_size).unwrap(); - let mult_b = - unconstrained::<_, _, BinaryField8b>(&mut builder, "mult_b", log_size).unwrap(); - - let _product = lasso::u8mul( - &mut builder, - &mut lookup_batch, - "lasso_u8mul", - mult_a, - mult_b, - 1 << log_size, - ) - .unwrap(); - } - - lookup_batch - .execute::<_, _, BinaryField32b>(&mut builder) - .unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); - } - - #[test] - fn test_lasso_batched_u8mul_rejects() { - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); - let log_size = 10; - - // We try to feed in the add table instead - let mul_lookup_table = - lookups::u8_arithmetic::add_lookup(&mut builder, "mul table").unwrap(); - - let mut lookup_batch = LookupBatch::new([mul_lookup_table]); - - // TODO?: Make this test fail 100% of the time, even though its almost impossible with rng - for _ in 0..10 { - let mult_a = - unconstrained::<_, _, BinaryField8b>(&mut builder, "mult_a", log_size).unwrap(); - let mult_b = - unconstrained::<_, _, BinaryField8b>(&mut builder, "mult_b", log_size).unwrap(); - - let _product = lasso::u8mul( - &mut builder, - &mut lookup_batch, - "lasso_u8mul", - mult_a, - mult_b, - 1 << log_size, - ) - .unwrap(); - } - - lookup_batch - .execute::<_, _, BinaryField32b>(&mut builder) - .unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness) - .expect_err("Channels should be unbalanced"); - } - - #[test] - fn test_several_lasso_u32add() { - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); - - let mut several_u32_add = SeveralU32add::new(&mut builder).unwrap(); - - for log_size in [11, 12, 13] { - // BinaryField8b is used here because we utilize an 8x8x1→8 table - let add_a_u8 = - unconstrained::<_, _, BinaryField8b>(&mut builder, "add_a", log_size).unwrap(); - let add_b_u8 = - unconstrained::<_, _, BinaryField8b>(&mut builder, "add_b", log_size).unwrap(); - let _sum = several_u32_add - .u32add::( - &mut builder, - "lasso_u32add", - add_a_u8, - add_b_u8, - ) - .unwrap(); - } - - several_u32_add - .finalize(&mut builder, "lasso_u32add") - .unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); - } - - #[test] - fn test_lasso_u32add() { - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); - let log_size = 14; - - let add_a = unconstrained::<_, _, BinaryField1b>(&mut builder, "add_a", log_size).unwrap(); - let add_b = unconstrained::<_, _, BinaryField1b>(&mut builder, "add_b", log_size).unwrap(); - let _sum = lasso::u32add::<_, _, BinaryField1b, BinaryField1b>( - &mut builder, - "lasso_u32add", - add_a, - add_b, - ) - .unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); - } - - #[test] - fn test_u32add() { - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); - let log_size = 14; - let a = unconstrained::<_, _, BinaryField1b>(&mut builder, "a", log_size).unwrap(); - let b = unconstrained::<_, _, BinaryField1b>(&mut builder, "b", log_size).unwrap(); - let _c = arithmetic::u32::add(&mut builder, "u32add", a, b, arithmetic::Flags::Unchecked) - .unwrap(); - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); - } - - #[test] - fn test_u32fib() { - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); - let log_size_1b = 14; - let _ = u32fib(&mut builder, "u32fib", log_size_1b).unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); - } - - #[test] - fn test_bitwise() { - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); - let log_size = 6; - let a = unconstrained::<_, _, BinaryField1b>(&mut builder, "a", log_size).unwrap(); - let b = unconstrained::<_, _, BinaryField1b>(&mut builder, "b", log_size).unwrap(); - let _and = bitwise::and(&mut builder, "and", a, b).unwrap(); - let _xor = bitwise::xor(&mut builder, "xor", a, b).unwrap(); - let _or = bitwise::or(&mut builder, "or", a, b).unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); - } - - #[test] - fn test_keccakf() { - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); - let log_size = 5; - - let mut rng = StdRng::seed_from_u64(0); - let input_states = vec![KeccakfState(rng.gen())]; - let _state_out = keccakf(&mut builder, Some(input_states), log_size); - - let witness = builder.take_witness().unwrap(); - - let constraint_system = builder.build().unwrap(); - - let boundaries = vec![]; - - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); - } - - #[test] - fn test_sha256() { - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); - let log_size = PackedType::::LOG_WIDTH; - let input: [OracleId; 16] = array::from_fn(|i| { - unconstrained::<_, _, BinaryField1b>(&mut builder, i, log_size).unwrap() - }); - let state_output = sha256(&mut builder, input, log_size).unwrap(); - - let witness = builder.witness().unwrap(); - - let input_witneses: [_; 16] = - array::from_fn(|i| witness.get(input[i]).unwrap().as_slice::()); - - let output_witneses: [_; 8] = - array::from_fn(|i| witness.get(state_output[i]).unwrap().as_slice::()); - - let mut generic_array_input = GenericArray::::default(); - - let n_compressions = input_witneses[0].len(); - - for j in 0..n_compressions { - for i in 0..16 { - for z in 0..4 { - generic_array_input[i * 4 + z] = input_witneses[i][j].to_be_bytes()[z]; - } - } - - let mut output = crate::sha256::INIT; - compress256(&mut output, &[generic_array_input]); - - for i in 0..8 { - assert_eq!(output[i], output_witneses[i][j]); - } - } - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); - } - - #[test] - fn test_sha256_lasso() { - let allocator = bumpalo::Bump::new(); - let mut builder = - ConstraintSystemBuilder::::new_with_witness(&allocator); - let log_size = PackedType::::LOG_WIDTH + BinaryField8b::TOWER_LEVEL; - let input: [OracleId; 16] = array::from_fn(|i| { - unconstrained::<_, _, BinaryField1b>(&mut builder, i, log_size).unwrap() - }); - let state_output = lasso::sha256(&mut builder, input, log_size).unwrap(); - - let witness = builder.witness().unwrap(); - - let input_witneses: [_; 16] = array::from_fn(|i| { - witness - .get::(input[i]) - .unwrap() - .as_slice::() - }); - - let output_witneses: [_; 8] = array::from_fn(|i| { - witness - .get::(state_output[i]) - .unwrap() - .as_slice::() - }); - - let mut generic_array_input = GenericArray::::default(); - - let n_compressions = input_witneses[0].len(); - - for j in 0..n_compressions { - for i in 0..16 { - for z in 0..4 { - generic_array_input[i * 4 + z] = input_witneses[i][j].to_be_bytes()[z]; - } - } - - let mut output = crate::sha256::INIT; - compress256(&mut output, &[generic_array_input]); - - for i in 0..8 { - assert_eq!(output[i], output_witneses[i][j]); - } - } - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); - } - - #[test] - fn test_groestl() { - let allocator = bumpalo::Bump::new(); - let mut builder = - ConstraintSystemBuilder::::new_with_witness( - &allocator, - ); - let log_size = 9; - let _state_out = groestl_p_permutation(&mut builder, log_size).unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); - } - - #[test] - fn test_vision32b() { - let allocator = bumpalo::Bump::new(); - let mut builder = - ConstraintSystemBuilder::::new_with_witness( - &allocator, - ); - let log_size = 8; - let state_in: [OracleId; 24] = array::from_fn(|i| { - unconstrained::<_, _, BinaryField32b>(&mut builder, format!("p_in[{i}]"), log_size) - .unwrap() - }); - let _state_out = vision_permutation(&mut builder, log_size, state_in).unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); - } + use crate::builder::{ + types::{F, U}, + ConstraintSystemBuilder, + }; #[test] fn test_boundaries() { // Proving Collatz Orbits let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let log_size = PackedType::::LOG_WIDTH + 2; @@ -633,74 +161,4 @@ mod tests { >(&constraint_system, 1, 10, &boundaries, proof) .unwrap(); } - - #[test] - fn test_plain_u8_mul_lookup() { - const MAX_LOG_MULTIPLICITY: usize = 18; - let log_lookup_count = 19; - - let log_inv_rate = 1; - let security_bits = 20; - - let proof = { - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); - - let boundary = plain_lookup::test_plain_lookup::test_u8_mul_lookup::< - _, - _, - MAX_LOG_MULTIPLICITY, - >(&mut builder, log_lookup_count) - .unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - // validating witness with `validate_witness` is too slow for large transparents like the `table` - - let domain_factory = DefaultEvaluationDomainFactory::default(); - let backend = make_portable_backend(); - - constraint_system::prove::< - U, - CanonicalTowerFamily, - _, - Groestl256, - Groestl256ByteCompression, - HasherChallenger, - _, - >( - &constraint_system, - log_inv_rate, - security_bits, - &[boundary], - witness, - &domain_factory, - &backend, - ) - .unwrap() - }; - - // verify - { - let mut builder = ConstraintSystemBuilder::::new(); - - let boundary = plain_lookup::test_plain_lookup::test_u8_mul_lookup::< - _, - _, - MAX_LOG_MULTIPLICITY, - >(&mut builder, log_lookup_count) - .unwrap(); - - let constraint_system = builder.build().unwrap(); - - constraint_system::verify::< - U, - CanonicalTowerFamily, - Groestl256, - Groestl256ByteCompression, - HasherChallenger, - >(&constraint_system, log_inv_rate, security_bits, &[boundary], proof) - .unwrap(); - } - } } diff --git a/crates/circuits/src/pack.rs b/crates/circuits/src/pack.rs index 5d3c8f79e..290a44729 100644 --- a/crates/circuits/src/pack.rs +++ b/crates/circuits/src/pack.rs @@ -2,22 +2,23 @@ use anyhow::Result; use binius_core::oracle::OracleId; -use binius_field::{ - as_packed_field::PackScalar, underlier::UnderlierType, ExtensionField, TowerField, -}; +use binius_field::{as_packed_field::PackScalar, ExtensionField, TowerField}; -use crate::builder::ConstraintSystemBuilder; +use crate::builder::{ + types::{F, U}, + ConstraintSystemBuilder, +}; -pub fn pack( +pub fn pack( oracle_id: OracleId, - builder: &mut ConstraintSystemBuilder, + builder: &mut ConstraintSystemBuilder, name: impl ToString, ) -> Result where - F: TowerField + ExtensionField + ExtensionField, + F: ExtensionField + ExtensionField, FInput: TowerField, FOutput: TowerField + ExtensionField, - U: UnderlierType + PackScalar + PackScalar + PackScalar, + U: PackScalar + PackScalar, { if FInput::TOWER_LEVEL == FOutput::TOWER_LEVEL { return Ok(oracle_id); diff --git a/crates/circuits/src/plain_lookup.rs b/crates/circuits/src/plain_lookup.rs index 0c41cc6ef..afcb77ac7 100644 --- a/crates/circuits/src/plain_lookup.rs +++ b/crates/circuits/src/plain_lookup.rs @@ -1,17 +1,16 @@ // Copyright 2024-2025 Irreducible Inc. -use binius_core::{ - constraint_system::channel::{Boundary, FlushDirection}, - oracle::OracleId, -}; +use binius_core::{constraint_system::channel::FlushDirection, oracle::OracleId}; use binius_field::{ as_packed_field::PackScalar, packed::set_packed_slice, BinaryField1b, ExtensionField, Field, TowerField, }; use bytemuck::Pod; -use itertools::izip; -use crate::builder::ConstraintSystemBuilder; +use crate::builder::{ + types::{F, U}, + ConstraintSystemBuilder, +}; /// Checks values in `lookup_values` to be in `table`. /// @@ -24,45 +23,36 @@ use crate::builder::ConstraintSystemBuilder; /// # Parameters /// - `builder`: a mutable reference to the `ConstraintSystemBuilder`. /// - `table`: an oracle holding the table of valid lookup values. -/// - `table_count`: only the first `table_count` values of `table` are considered valid lookup values. -/// - `balancer_value`: any valid table value, needed for balancing the channel. /// - `lookup_values`: an oracle holding the values to be looked up. /// - `lookup_values_count`: only the first `lookup_values_count` values in `lookup_values` will be looked up. /// -/// # Constraints -/// - no value in `lookup_values` can be looked only less than `1 << LOG_MAX_MULTIPLICITY` times, limiting completeness not soundness. -/// /// # How this Works /// We create a single channel for this lookup. /// We let the prover push all values in `lookup_values`, that is all values to be looked up, into the channel. /// We also must pull valid table values (i.e. values that appear in `table`) from the channel if the channel is to balance. /// By ensuring that only valid table values get pulled from the channel, and observing the channel to balance, we ensure that only valid table values get pushed (by the prover) into the channel. /// Therefore our construction is sound. -/// In order for the construction to be complete, allowing an honest prover to pass, we must pull each table value from the channel with exactly the same multiplicity (duplicate count) that the prover pushed that table value into the channel. +/// In order for the construction to be complete, allowing an honest prover to pass, we must pull each +/// table value from the channel with exactly the same multiplicity (duplicate count) that the prover pushed that table value into the channel. /// To do so, we allow the prover to commit information on the multiplicity of each table value. /// -/// The prover counts the multiplicity of each table value, and commits columns holding the bit-decomposition of the multiplicities. -/// Using these bit columns we create `component` columns the same height as the table, which select the table value where a multiplicity bit is 1 and select `balancer_value` where the bit is 0. -/// Pulling these component columns out of the channel with appropriate multiplicities, we pull out each table value from the channel with the multiplicity requested by the prover. -/// Due to the `balancer_value` appearing in the component columns, however, we will also pull the table value `balancer_value` from the channel many more times than needed. -/// To rectify this we put `balancer_value` in a boundary value and push this boundary value to the channel with a multiplicity that will balance the channel. -/// This boundary value is returned from the gadget. +/// The prover counts the multiplicity of each table value, and creates a bit column for +/// each of the LOG_MAX_MULTIPLICITY bits in the bit-decomposition of the multiplicities. +/// Then we flush the table values LOG_MAX_MULTIPLICITY times, each time using a different bit column as the 'selector' oracle to select which values in the +/// table actually get pushed into the channel flushed. When flushing the table with the i'th bit column as the selector, we flush with multiplicity 1 << i. /// -pub fn plain_lookup( - builder: &mut ConstraintSystemBuilder, +pub fn plain_lookup( + builder: &mut ConstraintSystemBuilder, table: OracleId, - table_count: usize, - balancer_value: FS, lookup_values: OracleId, lookup_values_count: usize, -) -> Result, anyhow::Error> +) -> Result<(), anyhow::Error> where - U: PackScalar + PackScalar + PackScalar + Pod, - F: TowerField + ExtensionField, + U: PackScalar + Pod, + F: ExtensionField, FS: TowerField + Pod, { let n_vars = builder.log_rows([table])?; - debug_assert!(table_count <= 1 << n_vars); let channel = builder.add_channel(); @@ -75,58 +65,29 @@ where let values_slice = witness.get::(lookup_values)?.as_slice::(); multiplicities = Some(count_multiplicities( - &table_slice[0..table_count], + &table_slice[0..1 << n_vars], &values_slice[0..lookup_values_count], false, )?); } - let components: [OracleId; LOG_MAX_MULTIPLICITY] = - get_components::<_, _, FS, LOG_MAX_MULTIPLICITY>( - builder, - table, - table_count, - balancer_value, - multiplicities, - )?; - - components - .into_iter() - .enumerate() - .try_for_each(|(i, component)| { - builder.flush_with_multiplicity( - FlushDirection::Pull, - channel, - table_count, - [component], - 1 << i, - ) - })?; - - let balancer_value_multiplicity = - (((1 << LOG_MAX_MULTIPLICITY) - 1) * table_count - lookup_values_count) as u64; - - let boundary = Boundary { - values: vec![balancer_value.into()], - channel_id: channel, - direction: FlushDirection::Push, - multiplicity: balancer_value_multiplicity, - }; + let bits: [OracleId; LOG_MAX_MULTIPLICITY] = get_bits(builder, table, multiplicities)?; + bits.into_iter().enumerate().try_for_each(|(i, bit)| { + builder.flush_custom(FlushDirection::Pull, channel, bit, [table], 1 << i) + })?; - Ok(boundary) + Ok(()) } -// the `i`'th returned component holds values that are the product of the `table` values and the bits had by taking the `i`'th bit across the multiplicities. -fn get_components( - builder: &mut ConstraintSystemBuilder, +// the `i`'th returned bit column holds the `i`'th multiplicity bit. +fn get_bits( + builder: &mut ConstraintSystemBuilder, table: OracleId, - table_count: usize, - balancer_value: FS, multiplicities: Option>, ) -> Result<[OracleId; LOG_MAX_MULTIPLICITY], anyhow::Error> where - U: PackScalar + PackScalar + PackScalar + Pod, - F: TowerField + ExtensionField, + U: PackScalar, + F: ExtensionField, FS: TowerField + Pod, { let n_vars = builder.log_rows([table])?; @@ -134,13 +95,10 @@ where let bits: [OracleId; LOG_MAX_MULTIPLICITY] = builder .add_committed_multiple::("bits", n_vars, BinaryField1b::TOWER_LEVEL); - let components: [OracleId; LOG_MAX_MULTIPLICITY] = builder - .add_committed_multiple::("components", n_vars, FS::TOWER_LEVEL); - if let Some(witness) = builder.witness() { let multiplicities = multiplicities.ok_or_else(|| anyhow::anyhow!("multiplicities empty for prover"))?; - debug_assert_eq!(table_count, multiplicities.len()); + debug_assert_eq!(1 << n_vars, multiplicities.len()); // check all multiplicities are in range if multiplicities @@ -155,19 +113,13 @@ where // create the columns for the bits let mut bit_cols = bits.map(|bit| witness.new_column::(bit)); let mut packed_bit_cols = bit_cols.each_mut().map(|bit_col| bit_col.packed()); - // create the columns for the components - let mut component_cols = components.map(|component| witness.new_column::(component)); - let mut packed_component_cols = component_cols - .each_mut() - .map(|component_col| component_col.packed()); - - let table_slice = witness.get::(table)?.as_slice::(); - izip!(table_slice, multiplicities).enumerate().for_each( - |(i, (table_val, multiplicity))| { - for j in 0..LOG_MAX_MULTIPLICITY { + multiplicities + .iter() + .enumerate() + .for_each(|(i, multiplicity)| { + (0..LOG_MAX_MULTIPLICITY).for_each(|j| { let bit_set = multiplicity & (1 << j) != 0; - // set the bit value set_packed_slice( packed_bit_cols[j], i, @@ -176,36 +128,11 @@ where false => BinaryField1b::ZERO, }, ); - // set the component value - set_packed_slice( - packed_component_cols[j], - i, - match bit_set { - true => *table_val, - false => balancer_value, - }, - ); - } - }, - ); + }) + }); } - let expression = { - use binius_math::ArithExpr as Expr; - let table = Expr::Var(0); - let bit = Expr::Var(1); - let component = Expr::Var(2); - component - (bit.clone() * table + (Expr::one() - bit) * Expr::Const(balancer_value)) - }; - (0..LOG_MAX_MULTIPLICITY).for_each(|i| { - builder.assert_zero( - format!("lookup_{i}"), - [table, bits[i], components[i]], - expression.convert_field(), - ); - }); - - Ok(components) + Ok(bits) } #[cfg(test)] @@ -242,28 +169,20 @@ pub mod test_plain_lookup { }); } - pub fn test_u8_mul_lookup( - builder: &mut ConstraintSystemBuilder, + pub fn test_u8_mul_lookup( + builder: &mut ConstraintSystemBuilder, log_lookup_count: usize, - ) -> Result, anyhow::Error> - where - U: PackScalar + PackScalar + PackScalar + Pod, - F: TowerField + ExtensionField, - { + ) -> Result<(), anyhow::Error> { let table_values = generate_u8_mul_table(); let table = transparent::make_transparent( builder, "u8_mul_table", bytemuck::cast_slice::<_, BinaryField32b>(&table_values), )?; - let balancer_value = BinaryField32b::new(table_values[99]); // any table value let lookup_values = builder.add_committed("lookup_values", log_lookup_count, BinaryField32b::TOWER_LEVEL); - // reduce these if only some table values are valid - // or only some lookup_values are to be looked up - let table_count = table_values.len(); let lookup_values_count = 1 << log_lookup_count; if let Some(witness) = builder.witness() { @@ -272,16 +191,14 @@ pub mod test_plain_lookup { generate_random_u8_mul_claims(&mut mut_slice[0..lookup_values_count]); } - let boundary = plain_lookup::( + plain_lookup::( builder, table, - table_count, - balancer_value, lookup_values, lookup_values_count, )?; - Ok(boundary) + Ok(()) } } @@ -365,3 +282,83 @@ mod count_multiplicity_tests { assert_eq!(result, vec![1, 2, 3]); } } + +#[cfg(test)] +mod tests { + use binius_core::{fiat_shamir::HasherChallenger, tower::CanonicalTowerFamily}; + use binius_hal::make_portable_backend; + use binius_hash::compress::Groestl256ByteCompression; + use binius_math::DefaultEvaluationDomainFactory; + use groestl_crypto::Groestl256; + + use super::test_plain_lookup; + use crate::builder::ConstraintSystemBuilder; + + #[test] + fn test_plain_u8_mul_lookup() { + const MAX_LOG_MULTIPLICITY: usize = 18; + let log_lookup_count = 19; + + let log_inv_rate = 1; + let security_bits = 20; + + let proof = { + let allocator = bumpalo::Bump::new(); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); + + test_plain_lookup::test_u8_mul_lookup::( + &mut builder, + log_lookup_count, + ) + .unwrap(); + + let witness = builder.take_witness().unwrap(); + let constraint_system = builder.build().unwrap(); + // validating witness with `validate_witness` is too slow for large transparents like the `table` + + let domain_factory = DefaultEvaluationDomainFactory::default(); + let backend = make_portable_backend(); + + binius_core::constraint_system::prove::< + crate::builder::types::U, + CanonicalTowerFamily, + _, + Groestl256, + Groestl256ByteCompression, + HasherChallenger, + _, + >( + &constraint_system, + log_inv_rate, + security_bits, + &[], + witness, + &domain_factory, + &backend, + ) + .unwrap() + }; + + // verify + { + let mut builder = ConstraintSystemBuilder::new(); + + test_plain_lookup::test_u8_mul_lookup::( + &mut builder, + log_lookup_count, + ) + .unwrap(); + + let constraint_system = builder.build().unwrap(); + + binius_core::constraint_system::verify::< + crate::builder::types::U, + CanonicalTowerFamily, + Groestl256, + Groestl256ByteCompression, + HasherChallenger, + >(&constraint_system, log_inv_rate, security_bits, &[], proof) + .unwrap(); + } + } +} diff --git a/crates/circuits/src/sha256.rs b/crates/circuits/src/sha256.rs index d2e8fe6e9..53b4a33b2 100644 --- a/crates/circuits/src/sha256.rs +++ b/crates/circuits/src/sha256.rs @@ -5,16 +5,21 @@ use binius_core::{ transparent::multilinear_extension::MultilinearExtensionTransparent, }; use binius_field::{ - as_packed_field::{PackScalar, PackedType}, - underlier::{UnderlierType, WithUnderlier}, - BinaryField1b, PackedField, TowerField, + as_packed_field::PackedType, underlier::WithUnderlier, BinaryField1b, Field, PackedField, + TowerField, }; use binius_macros::arith_expr; use binius_utils::checked_arithmetics::checked_log_2; use bytemuck::{pod_collect_to_vec, Pod}; use itertools::izip; -use crate::{arithmetic, builder::ConstraintSystemBuilder}; +use crate::{ + arithmetic, + builder::{ + types::{F, U}, + ConstraintSystemBuilder, + }, +}; const LOG_U32_BITS: usize = checked_log_2(32); @@ -41,15 +46,11 @@ pub enum RotateRightType { Logical, } -pub fn rotate_and_xor( +pub fn rotate_and_xor( log_size: usize, - builder: &mut ConstraintSystemBuilder, + builder: &mut ConstraintSystemBuilder, r: &[(OracleId, usize, RotateRightType)], -) -> Result -where - F: TowerField, - U: UnderlierType + Pod + PackScalar + PackScalar, -{ +) -> Result { let shifted_oracle_ids = r .iter() .map(|(oracle_id, shift, t)| { @@ -76,7 +77,7 @@ where let result_oracle_id = builder.add_linear_combination( format!("linear combination of {:?}", shifted_oracle_ids), log_size, - shifted_oracle_ids.iter().map(|s| (*s, F::ONE)), + shifted_oracle_ids.iter().map(|s| (*s, Field::ONE)), )?; if let Some(witness) = builder.witness() { @@ -116,16 +117,12 @@ where .collect() } -pub fn u32const_repeating( +pub fn u32const_repeating( log_size: usize, - builder: &mut ConstraintSystemBuilder, + builder: &mut ConstraintSystemBuilder, x: u32, name: &str, -) -> Result -where - F: TowerField, - U: UnderlierType + Pod + PackScalar + PackScalar, -{ +) -> Result { let brodcasted = vec![x; 1 << (PackedType::::LOG_WIDTH.saturating_sub(LOG_U32_BITS))]; let transparent_id = builder.add_transparent( @@ -152,15 +149,11 @@ where Ok(repeating_id) } -pub fn sha256( - builder: &mut ConstraintSystemBuilder, +pub fn sha256( + builder: &mut ConstraintSystemBuilder, input: [OracleId; 16], log_size: usize, -) -> Result<[OracleId; 8], anyhow::Error> -where - U: UnderlierType + Pod + PackScalar + PackScalar, - F: TowerField, -{ +) -> Result<[OracleId; 8], anyhow::Error> { if log_size < >::LOG_WIDTH { Err(anyhow::Error::msg("log_size too small"))? } @@ -319,3 +312,64 @@ where Ok(output) } + +#[cfg(test)] +mod tests { + use binius_core::oracle::OracleId; + use binius_field::{as_packed_field::PackedType, BinaryField1b}; + use sha2::{compress256, digest::generic_array::GenericArray}; + + use crate::{ + builder::{test_utils::test_circuit, types::U}, + unconstrained::unconstrained, + }; + + #[test] + fn test_sha256() { + test_circuit(|builder| { + let log_size = PackedType::::LOG_WIDTH; + let input: [OracleId; 16] = std::array::from_fn(|i| { + unconstrained::(builder, i, log_size).unwrap() + }); + let state_output = super::sha256(builder, input, log_size).unwrap(); + + if let Some(witness) = builder.witness() { + let input_witneses: [_; 16] = std::array::from_fn(|i| { + witness + .get::(input[i]) + .unwrap() + .as_slice::() + }); + + let output_witneses: [_; 8] = std::array::from_fn(|i| { + witness + .get::(state_output[i]) + .unwrap() + .as_slice::() + }); + + let mut generic_array_input = GenericArray::::default(); + + let n_compressions = input_witneses[0].len(); + + for j in 0..n_compressions { + for i in 0..16 { + for z in 0..4 { + generic_array_input[i * 4 + z] = input_witneses[i][j].to_be_bytes()[z]; + } + } + + let mut output = crate::sha256::INIT; + compress256(&mut output, &[generic_array_input]); + + for i in 0..8 { + assert_eq!(output[i], output_witneses[i][j]); + } + } + } + + Ok(vec![]) + }) + .unwrap(); + } +} diff --git a/crates/circuits/src/transparent.rs b/crates/circuits/src/transparent.rs index 16efed3ac..11e36c223 100644 --- a/crates/circuits/src/transparent.rs +++ b/crates/circuits/src/transparent.rs @@ -3,23 +3,20 @@ use binius_core::{oracle::OracleId, transparent}; use binius_field::{ as_packed_field::{PackScalar, PackedType}, - underlier::UnderlierType, BinaryField1b, ExtensionField, PackedField, TowerField, }; -use bytemuck::Pod; -use crate::builder::ConstraintSystemBuilder; +use crate::builder::{ + types::{F, U}, + ConstraintSystemBuilder, +}; -pub fn step_down( - builder: &mut ConstraintSystemBuilder, +pub fn step_down( + builder: &mut ConstraintSystemBuilder, name: impl ToString, log_size: usize, index: usize, -) -> Result -where - U: UnderlierType + PackScalar + PackScalar + Pod, - F: TowerField, -{ +) -> Result { let step_down = transparent::step_down::StepDown::new(log_size, index)?; let id = builder.add_transparent(name, step_down.clone())?; if let Some(witness) = builder.witness() { @@ -28,16 +25,12 @@ where Ok(id) } -pub fn step_up( - builder: &mut ConstraintSystemBuilder, +pub fn step_up( + builder: &mut ConstraintSystemBuilder, name: impl ToString, log_size: usize, index: usize, -) -> Result -where - U: UnderlierType + PackScalar + PackScalar + Pod, - F: TowerField, -{ +) -> Result { let step_up = transparent::step_up::StepUp::new(log_size, index)?; let id = builder.add_transparent(name, step_up.clone())?; if let Some(witness) = builder.witness() { @@ -46,14 +39,14 @@ where Ok(id) } -pub fn constant( - builder: &mut ConstraintSystemBuilder, +pub fn constant( + builder: &mut ConstraintSystemBuilder, name: impl ToString, log_size: usize, value: FS, ) -> Result where - U: UnderlierType + PackScalar + PackScalar, + U: PackScalar, F: TowerField + ExtensionField, FS: TowerField, { @@ -68,13 +61,13 @@ where Ok(id) } -pub fn make_transparent( - builder: &mut ConstraintSystemBuilder, +pub fn make_transparent( + builder: &mut ConstraintSystemBuilder, name: impl ToString, values: &[FS], ) -> Result where - U: PackScalar + PackScalar, + U: PackScalar, F: TowerField + ExtensionField, FS: TowerField, { diff --git a/crates/circuits/src/u32fib.rs b/crates/circuits/src/u32fib.rs index 1f4197af4..59563fb7c 100644 --- a/crates/circuits/src/u32fib.rs +++ b/crates/circuits/src/u32fib.rs @@ -1,26 +1,22 @@ // Copyright 2024-2025 Irreducible Inc. use binius_core::oracle::{OracleId, ShiftVariant}; -use binius_field::{ - as_packed_field::PackScalar, underlier::UnderlierType, BinaryField1b, BinaryField32b, - ExtensionField, TowerField, -}; +use binius_field::{BinaryField1b, BinaryField32b, TowerField}; use binius_macros::arith_expr; use binius_maybe_rayon::prelude::*; -use bytemuck::Pod; use rand::{thread_rng, Rng}; -use crate::{arithmetic, builder::ConstraintSystemBuilder, transparent::step_down}; +use crate::{ + arithmetic, + builder::{types::F, ConstraintSystemBuilder}, + transparent::step_down, +}; -pub fn u32fib( - builder: &mut ConstraintSystemBuilder, +pub fn u32fib( + builder: &mut ConstraintSystemBuilder, name: impl ToString, log_size: usize, -) -> Result -where - U: UnderlierType + Pod + PackScalar + PackScalar + PackScalar, - F: TowerField + ExtensionField, -{ +) -> Result { builder.push_namespace(name); let current = builder.add_committed("current", log_size, BinaryField1b::TOWER_LEVEL); let next = builder.add_shifted("next", current, 32, log_size, ShiftVariant::LogicalRight)?; @@ -75,3 +71,18 @@ where builder.pop_namespace(); Ok(current) } + +#[cfg(test)] +mod tests { + use crate::builder::test_utils::test_circuit; + + #[test] + fn test_u32fib() { + test_circuit(|builder| { + let log_size_1b = 14; + let _ = super::u32fib(builder, "u32fib", log_size_1b)?; + Ok(vec![]) + }) + .unwrap(); + } +} diff --git a/crates/circuits/src/unconstrained.rs b/crates/circuits/src/unconstrained.rs index a798ec5f1..f39f48cda 100644 --- a/crates/circuits/src/unconstrained.rs +++ b/crates/circuits/src/unconstrained.rs @@ -1,21 +1,22 @@ // Copyright 2024-2025 Irreducible Inc. use binius_core::oracle::OracleId; -use binius_field::{ - as_packed_field::PackScalar, underlier::UnderlierType, ExtensionField, TowerField, -}; +use binius_field::{as_packed_field::PackScalar, ExtensionField, TowerField}; use binius_maybe_rayon::prelude::*; use bytemuck::Pod; use rand::{thread_rng, Rng}; -use crate::builder::ConstraintSystemBuilder; +use crate::builder::{ + types::{F, U}, + ConstraintSystemBuilder, +}; -pub fn unconstrained( - builder: &mut ConstraintSystemBuilder, +pub fn unconstrained( + builder: &mut ConstraintSystemBuilder, name: impl ToString, log_size: usize, ) -> Result where - U: UnderlierType + Pod + PackScalar + PackScalar, + U: PackScalar + Pod, F: TowerField + ExtensionField, FS: TowerField, { @@ -33,3 +34,31 @@ where Ok(rng) } + +// Same as 'unconstrained' but uses some pre-defined values instead of a random ones +pub fn fixed_u32( + builder: &mut ConstraintSystemBuilder, + name: impl ToString, + log_size: usize, + value: Vec, +) -> Result +where + U: PackScalar + Pod, + F: TowerField + ExtensionField, + FS: TowerField, +{ + let rng = builder.add_committed(name, log_size, FS::TOWER_LEVEL); + + if let Some(witness) = builder.witness() { + witness + .new_column::(rng) + .as_mut_slice::() + .into_par_iter() + .zip(value.into_par_iter()) + .for_each(|(data, value)| { + *data = value; + }); + } + + Ok(rng) +} diff --git a/crates/circuits/src/vision.rs b/crates/circuits/src/vision.rs index 5e8fe20a0..182566db5 100644 --- a/crates/circuits/src/vision.rs +++ b/crates/circuits/src/vision.rs @@ -12,36 +12,22 @@ use std::array; use anyhow::Result; use binius_core::{oracle::OracleId, transparent::constant::Constant}; use binius_field::{ - as_packed_field::{PackScalar, PackedType}, - linear_transformation::Transformation, - make_aes_to_binary_packed_transformer, - packed::get_packed_slice, - underlier::UnderlierType, - BinaryField1b, BinaryField32b, BinaryField64b, ExtensionField, PackedAESBinaryField8x32b, - PackedBinaryField8x32b, PackedField, TowerField, + linear_transformation::Transformation, make_aes_to_binary_packed_transformer, + packed::get_packed_slice, BinaryField1b, BinaryField32b, ExtensionField, Field, + PackedAESBinaryField8x32b, PackedBinaryField8x32b, PackedField, TowerField, }; use binius_hash::{Vision32MDSTransform, INV_PACKED_TRANS_AES}; use binius_macros::arith_expr; use binius_math::ArithExpr; -use bytemuck::{must_cast_slice, Pod}; +use bytemuck::must_cast_slice; -use crate::builder::ConstraintSystemBuilder; +use crate::builder::{types::F, ConstraintSystemBuilder}; -pub fn vision_permutation( - builder: &mut ConstraintSystemBuilder, +pub fn vision_permutation( + builder: &mut ConstraintSystemBuilder, log_size: usize, p_in: [OracleId; STATE_SIZE], -) -> Result<[OracleId; STATE_SIZE]> -where - U: UnderlierType - + Pod - + PackScalar - + PackScalar - + PackScalar - + PackScalar, - F: TowerField + ExtensionField + ExtensionField, - PackedType: Pod, -{ +) -> Result<[OracleId; STATE_SIZE]> { // This only acts as a shorthand type B32 = BinaryField32b; @@ -77,7 +63,7 @@ where } let perm_out = (0..N_ROUNDS).try_fold(round_0_input, |state, round_i| { - vision_round::(builder, log_size, round_i, state) + vision_round(builder, log_size, round_i, state) })?; #[cfg(debug_assertions)] @@ -237,22 +223,13 @@ fn inv_constraint_expr() -> Result> { Ok(non_zero_case * zero_case) } -fn vision_round( - builder: &mut ConstraintSystemBuilder, +fn vision_round( + builder: &mut ConstraintSystemBuilder, log_size: usize, round_i: usize, perm_in: [OracleId; STATE_SIZE], ) -> Result<[OracleId; STATE_SIZE]> -where - U: UnderlierType - + Pod - + PackScalar - + PackScalar - + PackScalar - + PackScalar, - F: TowerField + ExtensionField + ExtensionField, - PackedType: Pod, -{ +where { builder.push_namespace(format!("round[{round_i}]")); let inv_0 = builder.add_committed_multiple::( "inv_evens", @@ -318,7 +295,10 @@ where .add_linear_combination( format!("round_out_evens_{}", row), log_size, - [(mds_out_0[row], F::ONE), (even_round_consts[row], F::ONE)], + [ + (mds_out_0[row], Field::ONE), + (even_round_consts[row], Field::ONE), + ], ) .unwrap() }); @@ -328,7 +308,10 @@ where .add_linear_combination( format!("round_out_odd_{}", row), log_size, - [(mds_out_1[row], F::ONE), (odd_round_consts[row], F::ONE)], + [ + (mds_out_1[row], Field::ONE), + (odd_round_consts[row], Field::ONE), + ], ) .unwrap() }); @@ -476,3 +459,25 @@ where Ok(perm_out) } + +#[cfg(test)] +mod tests { + use binius_core::oracle::OracleId; + use binius_field::BinaryField32b; + + use super::vision_permutation; + use crate::{builder::test_utils::test_circuit, unconstrained::unconstrained}; + + #[test] + fn test_vision32b() { + test_circuit(|builder| { + let log_size = 8; + let state_in: [OracleId; 24] = std::array::from_fn(|i| { + unconstrained::(builder, format!("p_in[{i}]"), log_size).unwrap() + }); + let _state_out = vision_permutation(builder, log_size, state_in).unwrap(); + Ok(vec![]) + }) + .unwrap(); + } +} diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index 113cac7a9..7f74200ae 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -10,6 +10,7 @@ workspace = true [dependencies] assert_matches.workspace = true auto_impl.workspace = true +binius_macros = { path = "../macros" } binius_field = { path = "../field" } binius_hal = { path = "../hal" } binius_hash = { path = "../hash" } @@ -23,6 +24,7 @@ derive_more.workspace = true digest.workspace = true either.workspace = true getset.workspace = true +inventory.workspace = true itertools.workspace = true rand.workspace = true stackalloc.workspace = true diff --git a/crates/core/benches/composition_poly.rs b/crates/core/benches/composition_poly.rs index 845c865b4..166dcfe51 100644 --- a/crates/core/benches/composition_poly.rs +++ b/crates/core/benches/composition_poly.rs @@ -8,7 +8,7 @@ use binius_field::{ PackedField, }; use binius_macros::{arith_circuit_poly, composition_poly}; -use binius_math::{ArithExpr as Expr, CompositionPolyOS}; +use binius_math::{ArithExpr as Expr, CompositionPoly}; use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput}; use rand::{thread_rng, RngCore}; @@ -28,7 +28,7 @@ fn generate_input_data(mut rng: impl RngCore) -> Vec> { fn evaluate_arith_circuit_poly( query: &[&[P]], - arith_circuit_poly: &impl CompositionPolyOS

, + arith_circuit_poly: &impl CompositionPoly

, ) { for i in 0..BATCH_SIZE { let result = arith_circuit_poly diff --git a/crates/core/src/composition/index.rs b/crates/core/src/composition/index.rs index 0eff19493..2a918439b 100644 --- a/crates/core/src/composition/index.rs +++ b/crates/core/src/composition/index.rs @@ -3,7 +3,7 @@ use std::fmt::Debug; use binius_field::{Field, PackedField}; -use binius_math::{ArithExpr, CompositionPolyOS}; +use binius_math::{ArithExpr, CompositionPoly}; use binius_utils::bail; use crate::polynomial::Error; @@ -34,7 +34,7 @@ impl IndexComposition { } } -impl, const N: usize> CompositionPolyOS

+impl, const N: usize> CompositionPoly

for IndexComposition { fn n_vars(&self) -> usize { @@ -159,7 +159,7 @@ mod tests { }; assert_eq!( - (&composition as &dyn CompositionPolyOS).expression(), + (&composition as &dyn CompositionPoly).expression(), ArithExpr::Add( Box::new(ArithExpr::Var(1)), Box::new(ArithExpr::Mul( diff --git a/crates/core/src/composition/product_composition.rs b/crates/core/src/composition/product_composition.rs index 245c28653..7fddf3312 100644 --- a/crates/core/src/composition/product_composition.rs +++ b/crates/core/src/composition/product_composition.rs @@ -1,7 +1,7 @@ // Copyright 2024-2025 Irreducible Inc. use binius_field::PackedField; -use binius_math::{ArithExpr, CompositionPolyOS}; +use binius_math::{ArithExpr, CompositionPoly}; use binius_utils::bail; #[derive(Debug, Default, Copy, Clone)] @@ -17,7 +17,7 @@ impl ProductComposition { } } -impl CompositionPolyOS

for ProductComposition { +impl CompositionPoly

for ProductComposition { fn n_vars(&self) -> usize { self.n_vars() } diff --git a/crates/core/src/constraint_system/channel.rs b/crates/core/src/constraint_system/channel.rs index 93a28cb09..01891f180 100644 --- a/crates/core/src/constraint_system/channel.rs +++ b/crates/core/src/constraint_system/channel.rs @@ -52,14 +52,14 @@ use std::collections::HashMap; use binius_field::{as_packed_field::PackScalar, underlier::UnderlierType, TowerField}; -use bytes::BufMut; +use binius_macros::{DeserializeBytes, SerializeBytes}; use super::error::{Error, VerificationError}; -use crate::{oracle::OracleId, transcript::TranscriptWriter, witness::MultilinearExtensionIndex}; +use crate::{oracle::OracleId, witness::MultilinearExtensionIndex}; pub type ChannelId = usize; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, SerializeBytes, DeserializeBytes)] pub struct Flush { pub oracles: Vec, pub channel_id: ChannelId, @@ -68,7 +68,7 @@ pub struct Flush { pub multiplicity: u64, } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, SerializeBytes, DeserializeBytes)] pub struct Boundary { pub values: Vec, pub channel_id: ChannelId, @@ -76,7 +76,7 @@ pub struct Boundary { pub multiplicity: u64, } -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, SerializeBytes, DeserializeBytes)] pub enum FlushDirection { Push, Pull, @@ -220,26 +220,6 @@ impl Channel { } } -impl Boundary { - pub fn write_to(&self, writer: &mut TranscriptWriter) { - writer.buffer().put_u64(self.values.len() as u64); - writer.write_slice( - &self - .values - .iter() - .copied() - .map(F::Canonical::from) - .collect::>(), - ); - writer.buffer().put_u64(self.channel_id as u64); - writer.buffer().put_u64(self.multiplicity); - writer.buffer().put_u64(match self.direction { - FlushDirection::Pull => 0, - FlushDirection::Push => 1, - }); - } -} - #[cfg(test)] mod tests { use binius_field::BinaryField64b; diff --git a/crates/core/src/constraint_system/mod.rs b/crates/core/src/constraint_system/mod.rs index de178b31f..81719eac4 100644 --- a/crates/core/src/constraint_system/mod.rs +++ b/crates/core/src/constraint_system/mod.rs @@ -7,7 +7,9 @@ mod prove; pub mod validate; mod verify; -use binius_field::TowerField; +use binius_field::{BinaryField128b, TowerField}; +use binius_macros::SerializeBytes; +use binius_utils::{DeserializeBytes, SerializationError, SerializationMode}; use channel::{ChannelId, Flush}; pub use prove::prove; pub use verify::verify; @@ -21,7 +23,7 @@ use crate::oracle::{ConstraintSet, MultilinearOracleSet, OracleId}; /// /// As a result, a ConstraintSystem allows us to validate all of these /// constraints against a witness, as well as enabling generic prove/verify -#[derive(Debug, Clone)] +#[derive(Debug, Clone, SerializeBytes)] pub struct ConstraintSystem { pub oracles: MultilinearOracleSet, pub table_constraints: Vec>, @@ -30,6 +32,24 @@ pub struct ConstraintSystem { pub max_channel_id: ChannelId, } +impl DeserializeBytes for ConstraintSystem { + fn deserialize( + mut read_buf: impl bytes::Buf, + mode: SerializationMode, + ) -> Result + where + Self: Sized, + { + Ok(Self { + oracles: DeserializeBytes::deserialize(&mut read_buf, mode)?, + table_constraints: DeserializeBytes::deserialize(&mut read_buf, mode)?, + non_zero_oracle_ids: DeserializeBytes::deserialize(&mut read_buf, mode)?, + flushes: DeserializeBytes::deserialize(&mut read_buf, mode)?, + max_channel_id: DeserializeBytes::deserialize(&mut read_buf, mode)?, + }) + } +} + impl ConstraintSystem { pub const fn no_base_constraints(self) -> Self { self diff --git a/crates/core/src/constraint_system/prove.rs b/crates/core/src/constraint_system/prove.rs index d405405e1..a04539d5e 100644 --- a/crates/core/src/constraint_system/prove.rs +++ b/crates/core/src/constraint_system/prove.rs @@ -12,7 +12,7 @@ use binius_field::{ use binius_hal::ComputationBackend; use binius_hash::PseudoCompressionFunction; use binius_math::{ - ArithExpr, EvaluationDomainFactory, IsomorphicEvaluationDomainFactory, MLEDirectAdapter, + EvaluationDomainFactory, IsomorphicEvaluationDomainFactory, MLEDirectAdapter, MultilinearExtension, MultilinearPoly, }; use binius_maybe_rayon::prelude::*; @@ -54,7 +54,7 @@ use crate::{ }, }, ring_switch, - tower::{PackedTop, ProverTowerFamily, ProverTowerUnderlier, TowerFamily}, + tower::{PackedTop, ProverTowerFamily, ProverTowerUnderlier}, transcript::ProverTranscript, witness::{MultilinearExtensionIndex, MultilinearWitness}, }; @@ -104,12 +104,7 @@ where let fast_domain_factory = IsomorphicEvaluationDomainFactory::>::default(); let mut transcript = ProverTranscript::::new(); - { - let mut observer = transcript.observe(); - for boundary in boundaries { - boundary.write_to(&mut observer); - } - } + transcript.observe().write_slice(boundaries); let ConstraintSystem { mut oracles, @@ -241,7 +236,7 @@ where let (flush_oracle_ids, flush_selectors, flush_final_layer_claims) = reorder_for_flushing_by_n_vars( &oracles, - flush_oracle_ids, + &flush_oracle_ids, flush_selectors, flush_final_layer_claims, ); @@ -302,7 +297,7 @@ where .map(|multilinear| 7 - multilinear.log_extension_degree()), constraints .iter() - .map(|constraint| arith_expr_base_tower_level::(&constraint.composition)) + .map(|constraint| constraint.composition.binary_tower_level()) ) .max() .unwrap_or(0); @@ -423,7 +418,7 @@ where let system = ring_switch::EvalClaimSystem::new( &oracles, &commit_meta, - oracle_to_commit_index, + &oracle_to_commit_index, &eval_claims, )?; @@ -457,30 +452,6 @@ where }) } -fn arith_expr_base_tower_level(composition: &ArithExpr>) -> usize { - if composition.try_convert_field::().is_ok() { - return 0; - } - - if composition.try_convert_field::().is_ok() { - return 3; - } - - if composition.try_convert_field::().is_ok() { - return 4; - } - - if composition.try_convert_field::().is_ok() { - return 5; - } - - if composition.try_convert_field::().is_ok() { - return 6; - } - - 7 -} - type TypeErasedUnivariateZerocheck<'a, F> = Box + 'a>; type TypeErasedSumcheck<'a, F> = Box + 'a>; type TypeErasedProver<'a, F> = @@ -518,7 +489,7 @@ where P: PackedExtension + PackedExtension + PackedExtension, - F: TowerField + ExtensionField + ExtensionField, + F: TowerField, { let univariate_prover = sumcheck::prove::constraint_set_zerocheck_prover::<_, _, FBase, _, _>( diff --git a/crates/core/src/constraint_system/verify.rs b/crates/core/src/constraint_system/verify.rs index 467efd938..415f909d3 100644 --- a/crates/core/src/constraint_system/verify.rs +++ b/crates/core/src/constraint_system/verify.rs @@ -4,7 +4,7 @@ use std::{cmp::Reverse, iter}; use binius_field::{BinaryField, PackedField, TowerField}; use binius_hash::PseudoCompressionFunction; -use binius_math::{ArithExpr, CompositionPolyOS}; +use binius_math::{ArithExpr, CompositionPoly}; use binius_utils::{bail, checked_arithmetics::log2_ceil_usize}; use digest::{core_api::BlockSizeUser, Digest, Output}; use itertools::{izip, multiunzip, Itertools}; @@ -75,12 +75,7 @@ where let Proof { transcript } = proof; let mut transcript = VerifierTranscript::::new(transcript); - { - let mut observer = transcript.observe(); - for boundary in boundaries { - boundary.write_to(&mut observer); - } - } + transcript.observe().write_slice(boundaries); let merkle_scheme = BinaryMerkleTreeScheme::<_, Hash, _>::new(Compress::default()); let (commit_meta, oracle_to_commit_index) = piop::make_oracle_commit_meta(&oracles)?; @@ -155,7 +150,7 @@ where let (flush_oracle_ids, flush_selectors, flush_final_layer_claims) = reorder_for_flushing_by_n_vars( &oracles, - flush_oracle_ids, + &flush_oracle_ids, flush_selectors, flush_final_layer_claims, ); @@ -284,7 +279,7 @@ where let system = ring_switch::EvalClaimSystem::new( &oracles, &commit_meta, - oracle_to_commit_index, + &oracle_to_commit_index, &eval_claims, )?; @@ -315,7 +310,7 @@ pub fn max_n_vars_and_skip_rounds( ) -> (usize, usize) where F: TowerField, - Composition: CompositionPolyOS, + Composition: CompositionPoly, { let max_n_vars = max_n_vars(zerocheck_claims); @@ -339,7 +334,7 @@ where fn max_n_vars(zerocheck_claims: &[ZerocheckClaim]) -> usize where F: TowerField, - Composition: CompositionPolyOS, + Composition: CompositionPoly, { zerocheck_claims .iter() @@ -572,7 +567,7 @@ pub fn get_flush_dedup_sumcheck_metas( #[derive(Debug)] pub struct FlushSumcheckComposition; -impl CompositionPolyOS

for FlushSumcheckComposition { +impl CompositionPoly

for FlushSumcheckComposition { fn n_vars(&self) -> usize { 2 } @@ -644,7 +639,7 @@ pub fn get_post_flush_sumcheck_eval_claims_without_eq( Ok(evalcheck_claims) } -pub struct DedupSumcheckClaims> { +pub struct DedupSumcheckClaims> { sumcheck_claims: Vec>, gkr_eval_points: Vec>, flush_selectors_unique_by_claim: Vec>, @@ -654,7 +649,7 @@ pub struct DedupSumcheckClaims> #[allow(clippy::type_complexity)] pub fn get_flush_dedup_sumcheck_claims( flush_sumcheck_metas: Vec>, -) -> Result>, Error> { +) -> Result>, Error> { let n_claims = flush_sumcheck_metas.len(); let mut sumcheck_claims = Vec::with_capacity(n_claims); let mut gkr_eval_points = Vec::with_capacity(n_claims); @@ -698,7 +693,7 @@ pub fn get_flush_dedup_sumcheck_claims( pub fn reorder_for_flushing_by_n_vars( oracles: &MultilinearOracleSet, - flush_oracle_ids: Vec, + flush_oracle_ids: &[OracleId], flush_selectors: Vec, flush_final_layer_claims: Vec>, ) -> (Vec, Vec, Vec>) { diff --git a/crates/core/src/lib.rs b/crates/core/src/lib.rs index c7ebeab0e..48f1f060d 100644 --- a/crates/core/src/lib.rs +++ b/crates/core/src/lib.rs @@ -7,7 +7,7 @@ //! performance, while verifier-side functions are optimized for auditability and security. // This is to silence clippy errors around suspicious usage of XOR -// in our arithmetic. This is safe to do becasue we're operating +// in our arithmetic. This is safe to do because we're operating // over binary fields. #![allow(clippy::suspicious_arithmetic_impl)] #![allow(clippy::suspicious_op_assign_impl)] @@ -28,3 +28,5 @@ pub mod tower; pub mod transcript; pub mod transparent; pub mod witness; + +pub use inventory; diff --git a/crates/core/src/merkle_tree/binary_merkle_tree.rs b/crates/core/src/merkle_tree/binary_merkle_tree.rs index f15d689be..86383830d 100644 --- a/crates/core/src/merkle_tree/binary_merkle_tree.rs +++ b/crates/core/src/merkle_tree/binary_merkle_tree.rs @@ -2,10 +2,12 @@ use std::{array, fmt::Debug, mem::MaybeUninit}; -use binius_field::{serialize_canonical, TowerField}; +use binius_field::TowerField; use binius_hash::{HashBuffer, PseudoCompressionFunction}; use binius_maybe_rayon::{prelude::*, slice::ParallelSlice}; -use binius_utils::{bail, checked_arithmetics::log2_strict_usize}; +use binius_utils::{ + bail, checked_arithmetics::log2_strict_usize, SerializationMode, SerializeBytes, +}; use digest::{crypto_common::BlockSizeUser, Digest, FixedOutputReset, Output}; use tracing::instrument; @@ -210,7 +212,8 @@ where { let mut hash_buffer = HashBuffer::new(hasher); for elem in elems { - serialize_canonical(elem, &mut hash_buffer) + let mode = SerializationMode::CanonicalTower; + SerializeBytes::serialize(&elem, &mut hash_buffer, mode) .expect("HashBuffer has infinite capacity"); } } diff --git a/crates/core/src/merkle_tree/scheme.rs b/crates/core/src/merkle_tree/scheme.rs index c29940cf1..5055e3549 100644 --- a/crates/core/src/merkle_tree/scheme.rs +++ b/crates/core/src/merkle_tree/scheme.rs @@ -2,11 +2,12 @@ use std::{array, fmt::Debug, marker::PhantomData}; -use binius_field::{serialize_canonical, TowerField}; +use binius_field::TowerField; use binius_hash::{HashBuffer, PseudoCompressionFunction}; use binius_utils::{ bail, checked_arithmetics::{log2_ceil_usize, log2_strict_usize}, + SerializationMode, SerializeBytes, }; use bytes::Buf; use digest::{core_api::BlockSizeUser, Digest, Output}; @@ -108,42 +109,36 @@ where fn verify_opening( &self, - index: usize, + mut index: usize, values: &[F], layer_depth: usize, tree_depth: usize, layer_digests: &[Self::Digest], proof: &mut TranscriptReader, ) -> Result<(), Error> { - if 1 << layer_depth != layer_digests.len() { - bail!(VerificationError::IncorrectVectorLength) + if (1 << layer_depth) != layer_digests.len() { + bail!(VerificationError::IncorrectVectorLength); } - if index > (1 << tree_depth) - 1 { + if index >= (1 << tree_depth) { bail!(Error::IndexOutOfRange { - max: (1 << tree_depth) - 1, + max: (1 << tree_depth) - 1 }); } - let leaf_digest = hash_field_elems::<_, H>(values); - let branch = proof.read_vec(tree_depth - layer_depth)?; - - let mut index = index; - let root = branch.into_iter().fold(leaf_digest, |node, branch_node| { - let next_node = if index & 1 == 0 { - self.compression.compress([node, branch_node]) + let mut leaf_digest = hash_field_elems::<_, H>(values); + for branch_node in proof.read_vec(tree_depth - layer_depth)? { + leaf_digest = self.compression.compress(if index & 1 == 0 { + [leaf_digest, branch_node] } else { - self.compression.compress([branch_node, node]) - }; + [branch_node, leaf_digest] + }); index >>= 1; - next_node - }); - - if root == layer_digests[index] { - Ok(()) - } else { - bail!(VerificationError::InvalidProof) } + + (leaf_digest == layer_digests[index]) + .then_some(()) + .ok_or_else(|| VerificationError::InvalidProof.into()) } } @@ -178,8 +173,10 @@ where let mut hasher = H::new(); { let mut buffer = HashBuffer::new(&mut hasher); - for &elem in elems { - serialize_canonical(elem, &mut buffer).expect("HashBuffer has infinite capacity"); + for elem in elems { + let mode = SerializationMode::CanonicalTower; + SerializeBytes::serialize(elem, &mut buffer, mode) + .expect("HashBuffer has infinite capacity"); } } hasher.finalize() diff --git a/crates/core/src/oracle/composite.rs b/crates/core/src/oracle/composite.rs index 7c904deaf..d90102a22 100644 --- a/crates/core/src/oracle/composite.rs +++ b/crates/core/src/oracle/composite.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use binius_field::TowerField; -use binius_math::CompositionPolyOS; +use binius_math::CompositionPoly; use binius_utils::bail; use crate::oracle::{Error, MultilinearPolyOracle, OracleId}; @@ -12,11 +12,11 @@ use crate::oracle::{Error, MultilinearPolyOracle, OracleId}; pub struct CompositePolyOracle { n_vars: usize, inner: Vec>, - composition: Arc>, + composition: Arc>, } impl CompositePolyOracle { - pub fn new + 'static>( + pub fn new + 'static>( n_vars: usize, inner: Vec>, composition: C, @@ -67,7 +67,7 @@ impl CompositePolyOracle { self.inner.clone() } - pub fn composition(&self) -> Arc> { + pub fn composition(&self) -> Arc> { self.composition.clone() } } @@ -82,7 +82,7 @@ mod tests { #[derive(Clone, Debug)] struct TestByteComposition; - impl CompositionPolyOS for TestByteComposition { + impl CompositionPoly for TestByteComposition { fn n_vars(&self) -> usize { 3 } diff --git a/crates/core/src/oracle/constraint.rs b/crates/core/src/oracle/constraint.rs index 4a9d0225a..df6833945 100644 --- a/crates/core/src/oracle/constraint.rs +++ b/crates/core/src/oracle/constraint.rs @@ -4,7 +4,8 @@ use core::iter::IntoIterator; use std::sync::Arc; use binius_field::{Field, TowerField}; -use binius_math::{ArithExpr, CompositionPolyOS}; +use binius_macros::{DeserializeBytes, SerializeBytes}; +use binius_math::{ArithExpr, CompositionPoly}; use binius_utils::bail; use itertools::Itertools; @@ -12,26 +13,26 @@ use super::{Error, MultilinearOracleSet, MultilinearPolyVariant, OracleId}; /// Composition trait object that can be used to create lists of compositions of differing /// concrete types. -pub type TypeErasedComposition

= Arc>; +pub type TypeErasedComposition

= Arc>; /// Constraint is a type erased composition along with a predicate on its values on the boolean hypercube -#[derive(Debug, Clone)] +#[derive(Debug, Clone, SerializeBytes, DeserializeBytes)] pub struct Constraint { - pub name: Arc, + pub name: String, pub composition: ArithExpr, pub predicate: ConstraintPredicate, } /// Predicate can either be a sum of values of a composition on the hypercube (sumcheck) or equality to zero /// on the hypercube (zerocheck) -#[derive(Clone, Debug)] +#[derive(Clone, Debug, SerializeBytes, DeserializeBytes)] pub enum ConstraintPredicate { Sum(F), Zero, } /// Constraint set is a group of constraints that operate over the same set of oracle-identified multilinears -#[derive(Debug, Clone)] +#[derive(Debug, Clone, SerializeBytes, DeserializeBytes)] pub struct ConstraintSet { pub n_vars: usize, pub oracle_ids: Vec, @@ -41,7 +42,7 @@ pub struct ConstraintSet { // A deferred constraint constructor that instantiates index composition after the superset of oracles is known #[allow(clippy::type_complexity)] struct UngroupedConstraint { - name: Arc, + name: String, oracle_ids: Vec, composition: ArithExpr, predicate: ConstraintPredicate, @@ -82,7 +83,7 @@ impl ConstraintSetBuilder { composition: ArithExpr, ) { self.constraints.push(UngroupedConstraint { - name: name.to_string().into(), + name: name.to_string(), oracle_ids: oracle_ids.into_iter().collect(), composition, predicate: ConstraintPredicate::Zero, diff --git a/crates/core/src/oracle/multilinear.rs b/crates/core/src/oracle/multilinear.rs index bfba63489..46d391681 100644 --- a/crates/core/src/oracle/multilinear.rs +++ b/crates/core/src/oracle/multilinear.rs @@ -2,8 +2,10 @@ use std::{array, fmt::Debug, sync::Arc}; -use binius_field::{Field, TowerField}; -use binius_utils::bail; +use binius_field::{BinaryField128b, Field, TowerField}; +use binius_macros::{DeserializeBytes, SerializeBytes}; +use binius_utils::{bail, DeserializeBytes, SerializationError, SerializationMode, SerializeBytes}; +use bytes::Buf; use getset::{CopyGetters, Getters}; use crate::{ @@ -280,9 +282,20 @@ impl MultilinearOracleSetAddition<'_, F> { /// /// The oracle set also tracks the committed polynomial in batches where each batch is committed /// together with a polynomial commitment scheme. -#[derive(Default, Debug, Clone)] +#[derive(Default, Debug, Clone, SerializeBytes)] pub struct MultilinearOracleSet { - oracles: Vec>>, + oracles: Vec>, +} + +impl DeserializeBytes for MultilinearOracleSet { + fn deserialize(read_buf: impl Buf, mode: SerializationMode) -> Result + where + Self: Sized, + { + Ok(Self { + oracles: DeserializeBytes::deserialize(read_buf, mode)?, + }) + } } impl MultilinearOracleSet { @@ -323,12 +336,11 @@ impl MultilinearOracleSet { oracle: impl FnOnce(OracleId) -> MultilinearPolyOracle, ) -> OracleId { let id = self.oracles.len(); - - self.oracles.push(Arc::new(oracle(id))); + self.oracles.push(oracle(id)); id } - fn get_from_set(&self, id: OracleId) -> Arc> { + fn get_from_set(&self, id: OracleId) -> MultilinearPolyOracle { self.oracles[id].clone() } @@ -401,7 +413,7 @@ impl MultilinearOracleSet { } pub fn oracle(&self, id: OracleId) -> MultilinearPolyOracle { - (*self.oracles[id]).clone() + self.oracles[id].clone() } pub fn n_vars(&self, id: OracleId) -> usize { @@ -438,7 +450,7 @@ impl MultilinearOracleSet { /// other oracles. This is formalized in [DP23] Section 4. /// /// [DP23]: -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, SerializeBytes)] pub struct MultilinearPolyOracle { pub id: OracleId, pub name: Option, @@ -447,7 +459,25 @@ pub struct MultilinearPolyOracle { pub variant: MultilinearPolyVariant, } -#[derive(Debug, Clone, PartialEq, Eq)] +impl DeserializeBytes for MultilinearPolyOracle { + fn deserialize( + mut read_buf: impl bytes::Buf, + mode: SerializationMode, + ) -> Result + where + Self: Sized, + { + Ok(Self { + id: DeserializeBytes::deserialize(&mut read_buf, mode)?, + name: DeserializeBytes::deserialize(&mut read_buf, mode)?, + n_vars: DeserializeBytes::deserialize(&mut read_buf, mode)?, + tower_level: DeserializeBytes::deserialize(&mut read_buf, mode)?, + variant: DeserializeBytes::deserialize(&mut read_buf, mode)?, + }) + } +} + +#[derive(Debug, Clone, PartialEq, Eq, SerializeBytes)] pub enum MultilinearPolyVariant { Committed, Transparent(TransparentPolyOracle), @@ -459,6 +489,36 @@ pub enum MultilinearPolyVariant { ZeroPadded(OracleId), } +impl DeserializeBytes for MultilinearPolyVariant { + fn deserialize( + mut buf: impl bytes::Buf, + mode: SerializationMode, + ) -> Result + where + Self: Sized, + { + Ok(match u8::deserialize(&mut buf, mode)? { + 0 => Self::Committed, + 1 => Self::Transparent(DeserializeBytes::deserialize(buf, mode)?), + 2 => Self::Repeating { + id: DeserializeBytes::deserialize(&mut buf, mode)?, + log_count: DeserializeBytes::deserialize(buf, mode)?, + }, + 3 => Self::Projected(DeserializeBytes::deserialize(buf, mode)?), + 4 => Self::Shifted(DeserializeBytes::deserialize(buf, mode)?), + 5 => Self::Packed(DeserializeBytes::deserialize(buf, mode)?), + 6 => Self::LinearCombination(DeserializeBytes::deserialize(buf, mode)?), + 7 => Self::ZeroPadded(DeserializeBytes::deserialize(buf, mode)?), + variant_index => { + return Err(SerializationError::UnknownEnumVariant { + name: "MultilinearPolyVariant", + index: variant_index, + }); + } + }) + } +} + /// A transparent multilinear polynomial oracle. /// /// See the [`MultilinearPolyOracle`] documentation for context. @@ -468,6 +528,30 @@ pub struct TransparentPolyOracle { poly: Arc>, } +impl SerializeBytes for TransparentPolyOracle { + fn serialize( + &self, + mut write_buf: impl bytes::BufMut, + mode: SerializationMode, + ) -> Result<(), SerializationError> { + self.poly.erased_serialize(&mut write_buf, mode) + } +} + +impl DeserializeBytes for TransparentPolyOracle { + fn deserialize( + read_buf: impl bytes::Buf, + mode: SerializationMode, + ) -> Result + where + Self: Sized, + { + Ok(Self { + poly: Box::>::deserialize(read_buf, mode)?.into(), + }) + } +} + impl TransparentPolyOracle { fn new(poly: Arc>) -> Result { if poly.binary_tower_level() > F::TOWER_LEVEL { @@ -494,13 +578,13 @@ impl PartialEq for TransparentPolyOracle { impl Eq for TransparentPolyOracle {} -#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, SerializeBytes, DeserializeBytes)] pub enum ProjectionVariant { FirstVars, LastVars, } -#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters)] +#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes, DeserializeBytes)] pub struct Projected { #[get_copy = "pub"] id: OracleId, @@ -530,14 +614,14 @@ impl Projected { } } -#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, SerializeBytes, DeserializeBytes)] pub enum ShiftVariant { CircularLeft, LogicalLeft, LogicalRight, } -#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters)] +#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes, DeserializeBytes)] pub struct Shifted { #[get_copy = "pub"] id: OracleId, @@ -579,7 +663,7 @@ impl Shifted { } } -#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters)] +#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes, DeserializeBytes)] pub struct Packed { #[get_copy = "pub"] id: OracleId, @@ -593,7 +677,7 @@ pub struct Packed { log_degree: usize, } -#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters)] +#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes, DeserializeBytes)] pub struct LinearCombination { #[get_copy = "pub"] n_vars: usize, @@ -606,7 +690,7 @@ impl LinearCombination { fn new( n_vars: usize, offset: F, - inner: impl IntoIterator>, F)>, + inner: impl IntoIterator, F)>, ) -> Result { let inner = inner .into_iter() diff --git a/crates/core/src/piop/prove.rs b/crates/core/src/piop/prove.rs index 9faed21c9..2d1331ec6 100644 --- a/crates/core/src/piop/prove.rs +++ b/crates/core/src/piop/prove.rs @@ -1,7 +1,7 @@ // Copyright 2024-2025 Irreducible Inc. use binius_field::{ - packed::set_packed_slice, BinaryField, ExtensionField, Field, PackedExtension, PackedField, + packed::set_packed_slice, BinaryField, Field, PackedExtension, PackedField, PackedFieldIndexable, TowerField, }; use binius_hal::ComputationBackend; @@ -10,7 +10,7 @@ use binius_math::{ }; use binius_maybe_rayon::{iter::IntoParallelIterator, prelude::*}; use binius_ntt::{NTTOptions, ThreadingSettings}; -use binius_utils::{bail, serialization::SerializeBytes, sorting::is_sorted_ascending}; +use binius_utils::{bail, sorting::is_sorted_ascending, SerializeBytes}; use either::Either; use itertools::{chain, Itertools}; @@ -101,7 +101,7 @@ pub fn commit( multilins: &[M], ) -> Result, Error> where - F: BinaryField + ExtensionField, + F: BinaryField, FEncode: BinaryField, P: PackedField + PackedExtension, M: MultilinearPoly

, @@ -133,7 +133,7 @@ where let rs_code = ReedSolomonCode::new( fri_params.rs_code().log_dim(), fri_params.rs_code().log_inv_rate(), - NTTOptions { + &NTTOptions { precompute_twiddles: true, thread_settings: ThreadingSettings::MultithreadedDefault, }, @@ -166,7 +166,7 @@ pub fn prove Result<(), Error> where - F: TowerField + ExtensionField + ExtensionField, + F: TowerField, FDomain: Field, FEncode: BinaryField, P: PackedFieldIndexable @@ -234,7 +234,7 @@ where merkle_prover, sumcheck_provers, codeword, - committed, + &committed, transcript, )?; @@ -247,11 +247,11 @@ fn prove_interleaved_fri_sumcheck>, codeword: &[P], - committed: MTProver::Committed, + committed: &MTProver::Committed, transcript: &mut ProverTranscript, ) -> Result<(), Error> where - F: TowerField + ExtensionField, + F: TowerField, FEncode: BinaryField, P: PackedFieldIndexable + PackedExtension, MTScheme: MerkleTreeScheme, @@ -259,7 +259,7 @@ where Challenger_: Challenger, { let mut fri_prover = - FRIFolder::new(fri_params, merkle_prover, P::unpack_scalars(codeword), &committed)?; + FRIFolder::new(fri_params, merkle_prover, P::unpack_scalars(codeword), committed)?; let mut sumcheck_batch_prover = SumcheckBatchProver::new(sumcheck_provers, transcript)?; diff --git a/crates/core/src/piop/tests.rs b/crates/core/src/piop/tests.rs index 69f8a0548..67a52273b 100644 --- a/crates/core/src/piop/tests.rs +++ b/crates/core/src/piop/tests.rs @@ -3,15 +3,15 @@ use std::iter::repeat_with; use binius_field::{ - BinaryField, BinaryField16b, BinaryField8b, ExtensionField, Field, PackedBinaryField2x128b, - PackedExtension, PackedField, PackedFieldIndexable, TowerField, + BinaryField, BinaryField16b, BinaryField8b, Field, PackedBinaryField2x128b, PackedExtension, + PackedField, PackedFieldIndexable, TowerField, }; use binius_hal::make_portable_backend; use binius_hash::compress::Groestl256ByteCompression; use binius_math::{ DefaultEvaluationDomainFactory, MLEDirectAdapter, MultilinearExtension, MultilinearPoly, }; -use binius_utils::serialization::{DeserializeBytes, SerializeBytes}; +use binius_utils::{DeserializeBytes, SerializeBytes}; use groestl_crypto::Groestl256; use rand::{rngs::StdRng, Rng, SeedableRng}; @@ -104,7 +104,7 @@ fn commit_prove_verify( merkle_prover: &impl MerkleTreeProver, log_inv_rate: usize, ) where - F: TowerField + ExtensionField + ExtensionField, + F: TowerField, FDomain: BinaryField, FEncode: BinaryField, P: PackedFieldIndexable diff --git a/crates/core/src/piop/verify.rs b/crates/core/src/piop/verify.rs index 96858f7fd..a0c5a5e7c 100644 --- a/crates/core/src/piop/verify.rs +++ b/crates/core/src/piop/verify.rs @@ -5,7 +5,7 @@ use std::{borrow::Borrow, cmp::Ordering, iter, ops::Range}; use binius_field::{BinaryField, ExtensionField, Field, TowerField}; use binius_math::evaluate_piecewise_multilinear; use binius_ntt::NTTOptions; -use binius_utils::{bail, serialization::DeserializeBytes}; +use binius_utils::{bail, DeserializeBytes}; use getset::CopyGetters; use tracing::instrument; @@ -137,7 +137,7 @@ where let log_batch_size = fold_arities.first().copied().unwrap_or(0); let log_dim = commit_meta.total_vars - log_batch_size; - let rs_code = ReedSolomonCode::new(log_dim, log_inv_rate, NTTOptions::default())?; + let rs_code = ReedSolomonCode::new(log_dim, log_inv_rate, &NTTOptions::default())?; let n_test_queries = fri::calculate_n_test_queries::(security_bits, &rs_code)?; let fri_params = FRIParams::new(rs_code, log_batch_size, fold_arities, n_test_queries)?; Ok(fri_params) diff --git a/crates/core/src/polynomial/arith_circuit.rs b/crates/core/src/polynomial/arith_circuit.rs index 3d7e1862b..2bec2030e 100644 --- a/crates/core/src/polynomial/arith_circuit.rs +++ b/crates/core/src/polynomial/arith_circuit.rs @@ -3,14 +3,12 @@ use std::{fmt::Debug, mem::MaybeUninit, sync::Arc}; use binius_field::{ExtensionField, Field, PackedField, TowerField}; -use binius_math::{ArithExpr, CompositionPoly, CompositionPolyOS, Error}; +use binius_math::{ArithExpr, CompositionPoly, Error}; use stackalloc::{ helpers::{slice_assume_init, slice_assume_init_mut}, stackalloc_uninit, }; -use super::MultivariatePoly; - /// Convert the expression to a sequence of arithmetic operations that can be evaluated in sequence. fn circuit_steps_for_expr( expr: &ArithExpr, @@ -50,10 +48,22 @@ fn circuit_steps_for_expr( result.push(CircuitStep::Mul(left, right)); CircuitStepArgument::Expr(CircuitNode::Slot(result.len() - 1)) } - ArithExpr::Pow(id, exp) => { - let id = to_circuit_inner(id, result); - result.push(CircuitStep::Pow(id, *exp)); - CircuitStepArgument::Expr(CircuitNode::Slot(result.len() - 1)) + ArithExpr::Pow(base, exp) => { + let mut acc = to_circuit_inner(base, result); + let base_expr = acc; + let highest_bit = exp.ilog2(); + + for i in (0..highest_bit).rev() { + result.push(CircuitStep::Square(acc)); + acc = CircuitStepArgument::Expr(CircuitNode::Slot(result.len() - 1)); + + if (exp >> i) & 1 != 0 { + result.push(CircuitStep::Mul(acc, base_expr)); + acc = CircuitStepArgument::Expr(CircuitNode::Slot(result.len() - 1)); + } + } + + acc } } } @@ -63,7 +73,7 @@ fn circuit_steps_for_expr( } /// Input of the circuit calculation step -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] enum CircuitNode { /// Input variable Var(usize), @@ -87,7 +97,7 @@ impl CircuitNode { } } -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] enum CircuitStepArgument { Expr(CircuitNode), Const(F), @@ -101,15 +111,15 @@ enum CircuitStepArgument { enum CircuitStep { Add(CircuitStepArgument, CircuitStepArgument), Mul(CircuitStepArgument, CircuitStepArgument), - Pow(CircuitStepArgument, u64), + Square(CircuitStepArgument), AddMul(usize, CircuitStepArgument, CircuitStepArgument), } /// Describes polynomial evaluations using a directed acyclic graph of expressions. /// -/// This is meant as an alternative to a hard-coded CompositionPolyOS. +/// This is meant as an alternative to a hard-coded CompositionPoly. /// -/// The advantage over a hard coded CompositionPolyOS is that this can be constructed and manipulated dynamically at runtime +/// The advantage over a hard coded CompositionPoly is that this can be constructed and manipulated dynamically at runtime /// and the object representing different polnomials can be stored in a homogeneous collection. #[derive(Debug, Clone)] pub struct ArithCircuitPoly { @@ -119,12 +129,14 @@ pub struct ArithCircuitPoly { retval: CircuitStepArgument, degree: usize, n_vars: usize, + tower_level: usize, } -impl ArithCircuitPoly { +impl ArithCircuitPoly { pub fn new(expr: ArithExpr) -> Self { let degree = expr.degree(); let n_vars = expr.n_vars(); + let tower_level = expr.binary_tower_level(); let (exprs, retval) = circuit_steps_for_expr(&expr); Self { @@ -133,6 +145,7 @@ impl ArithCircuitPoly { retval, degree, n_vars, + tower_level, } } @@ -142,6 +155,7 @@ impl ArithCircuitPoly { /// arithmetic expression. pub fn with_n_vars(n_vars: usize, expr: ArithExpr) -> Result { let degree = expr.degree(); + let tower_level = expr.binary_tower_level(); if n_vars < expr.n_vars() { return Err(Error::IncorrectNumberOfVariables { expected: expr.n_vars(), @@ -156,11 +170,14 @@ impl ArithCircuitPoly { retval, n_vars, degree, + tower_level, }) } } -impl CompositionPoly for ArithCircuitPoly { +impl>> CompositionPoly

+ for ArithCircuitPoly +{ fn degree(&self) -> usize { self.degree } @@ -170,14 +187,14 @@ impl CompositionPoly for ArithCircuitPoly { } fn binary_tower_level(&self) -> usize { - F::TOWER_LEVEL + self.tower_level } - fn expression>(&self) -> ArithExpr { + fn expression(&self) -> ArithExpr { self.expr.convert_field() } - fn evaluate>>(&self, query: &[P]) -> Result { + fn evaluate(&self, query: &[P]) -> Result { if query.len() != self.n_vars { return Err(Error::IncorrectQuerySize { expected: self.n_vars, @@ -226,8 +243,8 @@ impl CompositionPoly for ArithCircuitPoly { after, get_argument_value(*x, before) * get_argument_value(*y, before), ), - CircuitStep::Pow(id, exp) => { - write_result(after, get_argument_value(*id, before).pow(*exp)) + CircuitStep::Square(x) => { + write_result(after, get_argument_value(*x, before).square()) } }; } @@ -241,11 +258,7 @@ impl CompositionPoly for ArithCircuitPoly { }) } - fn batch_evaluate>>( - &self, - batch_query: &[&[P]], - evals: &mut [P], - ) -> Result<(), Error> { + fn batch_evaluate(&self, batch_query: &[&[P]], evals: &mut [P]) -> Result<(), Error> { let row_len = evals.len(); if batch_query.iter().any(|row| row.len() != row_len) { return Err(Error::BatchEvaluateSizeMismatch); @@ -285,29 +298,31 @@ impl CompositionPoly for ArithCircuitPoly { }, ); } - CircuitStep::Pow(id, exp) => match id { - CircuitStepArgument::Expr(id) => { - let id = id.get_sparse_chunk(batch_query, before, row_len); - for j in 0..row_len { - // Safety: `current` and `id` have length equal to `row_len` - unsafe { - current - .get_unchecked_mut(j) - .write(id.get_unchecked(j).pow(*exp)); + CircuitStep::Square(arg) => { + match arg { + CircuitStepArgument::Expr(node) => { + let id_chunk = node.get_sparse_chunk(batch_query, before, row_len); + for j in 0..row_len { + // Safety: `current` and `id_chunk` have length equal to `row_len` + unsafe { + current + .get_unchecked_mut(j) + .write(id_chunk.get_unchecked(j).square()); + } } } - } - CircuitStepArgument::Const(id) => { - let id: P = P::broadcast((*id).into()); - let result = id.pow(*exp); - for j in 0..row_len { - // Safety: `current` has length equal to `row_len` - unsafe { - current.get_unchecked_mut(j).write(result); + CircuitStepArgument::Const(value) => { + let value: P = P::broadcast((*value).into()); + let result = value.square(); + for j in 0..row_len { + // Safety: `current` has length equal to `row_len` + unsafe { + current.get_unchecked_mut(j).write(result); + } } } } - }, + } CircuitStep::AddMul(target, left, right) => { let target = &before[row_len * target..(target + 1) * row_len]; // Safety: by construction of steps and evaluation order we know @@ -349,52 +364,6 @@ impl CompositionPoly for ArithCircuitPoly { } } -impl>> CompositionPolyOS

- for ArithCircuitPoly -{ - fn degree(&self) -> usize { - CompositionPoly::degree(self) - } - - fn n_vars(&self) -> usize { - CompositionPoly::n_vars(self) - } - - fn expression(&self) -> ArithExpr { - self.expr.convert_field() - } - - fn binary_tower_level(&self) -> usize { - CompositionPoly::binary_tower_level(self) - } - - fn evaluate(&self, query: &[P]) -> Result { - CompositionPoly::evaluate(self, query) - } - - fn batch_evaluate(&self, batch_query: &[&[P]], evals: &mut [P]) -> Result<(), Error> { - CompositionPoly::batch_evaluate(self, batch_query, evals) - } -} - -impl MultivariatePoly for ArithCircuitPoly { - fn degree(&self) -> usize { - CompositionPoly::degree(&self) - } - - fn n_vars(&self) -> usize { - CompositionPoly::n_vars(&self) - } - - fn binary_tower_level(&self) -> usize { - CompositionPoly::binary_tower_level(&self) - } - - fn evaluate(&self, query: &[F]) -> Result { - CompositionPoly::evaluate(&self, query).map_err(|e| e.into()) - } -} - /// Apply a binary operation to two arguments and store the result in `current_evals`. /// `op` must be a function that takes two arguments and initialized the result with the third argument. fn apply_binary_op>>( @@ -466,7 +435,7 @@ mod tests { use binius_field::{ BinaryField16b, BinaryField8b, PackedBinaryField8x16b, PackedField, TowerField, }; - use binius_math::CompositionPolyOS; + use binius_math::CompositionPoly; use binius_utils::felts; use super::*; @@ -479,7 +448,7 @@ mod tests { let expr = ArithExpr::Const(F::new(123)); let circuit = ArithCircuitPoly::::new(expr); - let typed_circuit: &dyn CompositionPolyOS

= &circuit; + let typed_circuit: &dyn CompositionPoly

= &circuit; assert_eq!(typed_circuit.binary_tower_level(), F::TOWER_LEVEL); assert_eq!(typed_circuit.degree(), 0); assert_eq!(typed_circuit.n_vars(), 0); @@ -500,8 +469,8 @@ mod tests { let expr = ArithExpr::Var(0); let circuit = ArithCircuitPoly::::new(expr); - let typed_circuit: &dyn CompositionPolyOS

= &circuit; - assert_eq!(typed_circuit.binary_tower_level(), F::TOWER_LEVEL); + let typed_circuit: &dyn CompositionPoly

= &circuit; + assert_eq!(typed_circuit.binary_tower_level(), 0); assert_eq!(typed_circuit.degree(), 1); assert_eq!(typed_circuit.n_vars(), 1); @@ -528,8 +497,8 @@ mod tests { let expr = ArithExpr::Const(F::new(123)) + ArithExpr::Var(0); let circuit = ArithCircuitPoly::::new(expr); - let typed_circuit: &dyn CompositionPolyOS

= &circuit; - assert_eq!(typed_circuit.binary_tower_level(), F::TOWER_LEVEL); + let typed_circuit: &dyn CompositionPoly

= &circuit; + assert_eq!(typed_circuit.binary_tower_level(), 3); assert_eq!(typed_circuit.degree(), 1); assert_eq!(typed_circuit.n_vars(), 1); @@ -548,8 +517,8 @@ mod tests { let expr = ArithExpr::Const(F::new(123)) * ArithExpr::Var(0); let circuit = ArithCircuitPoly::::new(expr); - let typed_circuit: &dyn CompositionPolyOS

= &circuit; - assert_eq!(typed_circuit.binary_tower_level(), F::TOWER_LEVEL); + let typed_circuit: &dyn CompositionPoly

= &circuit; + assert_eq!(typed_circuit.binary_tower_level(), 3); assert_eq!(typed_circuit.degree(), 1); assert_eq!(typed_circuit.n_vars(), 1); @@ -574,8 +543,8 @@ mod tests { let expr = ArithExpr::Var(0).pow(13); let circuit = ArithCircuitPoly::::new(expr); - let typed_circuit: &dyn CompositionPolyOS

= &circuit; - assert_eq!(typed_circuit.binary_tower_level(), F::TOWER_LEVEL); + let typed_circuit: &dyn CompositionPoly

= &circuit; + assert_eq!(typed_circuit.binary_tower_level(), 0); assert_eq!(typed_circuit.degree(), 13); assert_eq!(typed_circuit.n_vars(), 1); @@ -600,8 +569,8 @@ mod tests { let expr = ArithExpr::Var(0).pow(2) * (ArithExpr::Var(1) + ArithExpr::Const(F::new(123))); let circuit = ArithCircuitPoly::::new(expr); - let typed_circuit: &dyn CompositionPolyOS

= &circuit; - assert_eq!(typed_circuit.binary_tower_level(), F::TOWER_LEVEL); + let typed_circuit: &dyn CompositionPoly

= &circuit; + assert_eq!(typed_circuit.binary_tower_level(), 3); assert_eq!(typed_circuit.degree(), 3); assert_eq!(typed_circuit.n_vars(), 2); @@ -662,7 +631,7 @@ mod tests { let circuit = ArithCircuitPoly::::new(expr); assert_eq!(circuit.steps.len(), 2); - let typed_circuit: &dyn CompositionPolyOS

= &circuit; + let typed_circuit: &dyn CompositionPoly

= &circuit; assert_eq!(typed_circuit.binary_tower_level(), F::TOWER_LEVEL); assert_eq!(typed_circuit.degree(), 1); assert_eq!(typed_circuit.n_vars(), 2); @@ -721,8 +690,8 @@ mod tests { let circuit = ArithCircuitPoly::::new(expr); assert_eq!(circuit.steps.len(), 1); - let typed_circuit: &dyn CompositionPolyOS

= &circuit; - assert_eq!(typed_circuit.binary_tower_level(), F::TOWER_LEVEL); + let typed_circuit: &dyn CompositionPoly

= &circuit; + assert_eq!(typed_circuit.binary_tower_level(), 1); assert_eq!(typed_circuit.degree(), 1); assert_eq!(typed_circuit.n_vars(), 1); @@ -746,10 +715,10 @@ mod tests { // ((x0^2)^3)^4 let expr = ArithExpr::Var(0).pow(2).pow(3).pow(4); let circuit = ArithCircuitPoly::::new(expr); - assert_eq!(circuit.steps.len(), 1); + assert_eq!(circuit.steps.len(), 5); - let typed_circuit: &dyn CompositionPolyOS

= &circuit; - assert_eq!(typed_circuit.binary_tower_level(), F::TOWER_LEVEL); + let typed_circuit: &dyn CompositionPoly

= &circuit; + assert_eq!(typed_circuit.binary_tower_level(), 0); assert_eq!(typed_circuit.degree(), 24); assert_eq!(typed_circuit.n_vars(), 1); @@ -764,4 +733,358 @@ mod tests { P::from_scalars(felts!(BinaryField16b[0, 1, 1, 1, 20, 152, 41, 170])), ); } + + #[test] + fn test_circuit_steps_for_expr_constant() { + type F = BinaryField8b; + + let expr = ArithExpr::Const(F::new(5)); + let (steps, retval) = circuit_steps_for_expr(&expr); + + assert!(steps.is_empty(), "No steps should be generated for a constant"); + assert_eq!(retval, CircuitStepArgument::Const(F::new(5))); + } + + #[test] + fn test_circuit_steps_for_expr_variable() { + type F = BinaryField8b; + + let expr = ArithExpr::::Var(18); + let (steps, retval) = circuit_steps_for_expr(&expr); + + assert!(steps.is_empty(), "No steps should be generated for a variable"); + assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Var(18)))); + } + + #[test] + fn test_circuit_steps_for_expr_addition() { + type F = BinaryField8b; + + let expr = ArithExpr::::Var(14) + ArithExpr::::Var(56); + let (steps, retval) = circuit_steps_for_expr(&expr); + + assert_eq!(steps.len(), 1, "One addition step should be generated"); + assert!(matches!( + steps[0], + CircuitStep::Add( + CircuitStepArgument::Expr(CircuitNode::Var(14)), + CircuitStepArgument::Expr(CircuitNode::Var(56)) + ) + )); + assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(0)))); + } + + #[test] + fn test_circuit_steps_for_expr_multiplication() { + type F = BinaryField8b; + + let expr = ArithExpr::::Var(36) * ArithExpr::Var(26); + let (steps, retval) = circuit_steps_for_expr(&expr); + + assert_eq!(steps.len(), 1, "One multiplication step should be generated"); + assert!(matches!( + steps[0], + CircuitStep::Mul( + CircuitStepArgument::Expr(CircuitNode::Var(36)), + CircuitStepArgument::Expr(CircuitNode::Var(26)) + ) + )); + assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(0)))); + } + + #[test] + fn test_circuit_steps_for_expr_pow_1() { + type F = BinaryField8b; + + let expr = ArithExpr::::Var(12).pow(1); + let (steps, retval) = circuit_steps_for_expr(&expr); + + // No steps should be generated for x^1 + assert_eq!(steps.len(), 0, "Pow(1) should not generate any computation steps"); + + // The return value should just be the variable itself + assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Var(12)))); + } + + #[test] + fn test_circuit_steps_for_expr_pow_2() { + type F = BinaryField8b; + + let expr = ArithExpr::::Var(10).pow(2); + let (steps, retval) = circuit_steps_for_expr(&expr); + + assert_eq!(steps.len(), 1, "Pow(2) should generate one squaring step"); + assert!(matches!( + steps[0], + CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(10))) + )); + assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(0)))); + } + + #[test] + fn test_circuit_steps_for_expr_pow_3() { + type F = BinaryField8b; + + let expr = ArithExpr::::Var(5).pow(3); + let (steps, retval) = circuit_steps_for_expr(&expr); + + assert_eq!( + steps.len(), + 2, + "Pow(3) should generate one squaring and one multiplication step" + ); + assert!(matches!( + steps[0], + CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(5))) + )); + assert!(matches!( + steps[1], + CircuitStep::Mul( + CircuitStepArgument::Expr(CircuitNode::Slot(0)), + CircuitStepArgument::Expr(CircuitNode::Var(5)) + ) + )); + assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(1)))); + } + + #[test] + fn test_circuit_steps_for_expr_pow_4() { + type F = BinaryField8b; + + let expr = ArithExpr::::Var(7).pow(4); + let (steps, retval) = circuit_steps_for_expr(&expr); + + assert_eq!(steps.len(), 2, "Pow(4) should generate two squaring steps"); + assert!(matches!( + steps[0], + CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(7))) + )); + + assert!(matches!( + steps[1], + CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(0))) + )); + + assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(1)))); + } + + #[test] + fn test_circuit_steps_for_expr_pow_5() { + type F = BinaryField8b; + + let expr = ArithExpr::::Var(3).pow(5); + let (steps, retval) = circuit_steps_for_expr(&expr); + + assert_eq!( + steps.len(), + 3, + "Pow(5) should generate two squaring steps and one multiplication" + ); + assert!(matches!( + steps[0], + CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(3))) + )); + assert!(matches!( + steps[1], + CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(0))) + )); + assert!(matches!( + steps[2], + CircuitStep::Mul( + CircuitStepArgument::Expr(CircuitNode::Slot(1)), + CircuitStepArgument::Expr(CircuitNode::Var(3)) + ) + )); + + assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(2)))); + } + + #[test] + fn test_circuit_steps_for_expr_pow_8() { + type F = BinaryField8b; + + let expr = ArithExpr::::Var(4).pow(8); + let (steps, retval) = circuit_steps_for_expr(&expr); + + assert_eq!(steps.len(), 3, "Pow(8) should generate three squaring steps"); + assert!(matches!( + steps[0], + CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(4))) + )); + assert!(matches!( + steps[1], + CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(0))) + )); + assert!(matches!( + steps[2], + CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(1))) + )); + + assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(2)))); + } + + #[test] + fn test_circuit_steps_for_expr_pow_9() { + type F = BinaryField8b; + + let expr = ArithExpr::::Var(8).pow(9); + let (steps, retval) = circuit_steps_for_expr(&expr); + + assert_eq!( + steps.len(), + 4, + "Pow(9) should generate three squaring steps and one multiplication" + ); + assert!(matches!( + steps[0], + CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(8))) + )); + assert!(matches!( + steps[1], + CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(0))) + )); + assert!(matches!( + steps[2], + CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(1))) + )); + assert!(matches!( + steps[3], + CircuitStep::Mul( + CircuitStepArgument::Expr(CircuitNode::Slot(2)), + CircuitStepArgument::Expr(CircuitNode::Var(8)) + ) + )); + + assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(3)))); + } + + #[test] + fn test_circuit_steps_for_expr_pow_12() { + type F = BinaryField8b; + let expr = ArithExpr::::Var(6).pow(12); + let (steps, retval) = circuit_steps_for_expr(&expr); + + assert_eq!(steps.len(), 4, "Pow(12) should use 4 steps."); + + assert!(matches!( + steps[0], + CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(6))) + )); + assert!(matches!( + steps[1], + CircuitStep::Mul( + CircuitStepArgument::Expr(CircuitNode::Slot(0)), + CircuitStepArgument::Expr(CircuitNode::Var(6)) + ) + )); + assert!(matches!( + steps[2], + CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(1))) + )); + assert!(matches!( + steps[3], + CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(2))) + )); + + assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(3)))); + } + + #[test] + fn test_circuit_steps_for_expr_pow_13() { + type F = BinaryField8b; + let expr = ArithExpr::::Var(7).pow(13); + let (steps, retval) = circuit_steps_for_expr(&expr); + + assert_eq!(steps.len(), 5, "Pow(13) should use 5 steps."); + assert!(matches!( + steps[0], + CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(7))) + )); + assert!(matches!( + steps[1], + CircuitStep::Mul( + CircuitStepArgument::Expr(CircuitNode::Slot(0)), + CircuitStepArgument::Expr(CircuitNode::Var(7)) + ) + )); + assert!(matches!( + steps[2], + CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(1))) + )); + assert!(matches!( + steps[3], + CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(2))) + )); + assert!(matches!( + steps[4], + CircuitStep::Mul( + CircuitStepArgument::Expr(CircuitNode::Slot(3)), + CircuitStepArgument::Expr(CircuitNode::Var(7)) + ) + )); + assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(4)))); + } + + #[test] + fn test_circuit_steps_for_expr_complex() { + type F = BinaryField8b; + + let expr = (ArithExpr::::Var(0) * ArithExpr::Var(1)) + + (ArithExpr::Const(F::ONE) - ArithExpr::Var(0)) * ArithExpr::Var(2) + - ArithExpr::Var(3); + + let (steps, retval) = circuit_steps_for_expr(&expr); + + assert_eq!(steps.len(), 4, "Expression should generate 4 computation steps"); + + assert!( + matches!( + steps[0], + CircuitStep::Mul( + CircuitStepArgument::Expr(CircuitNode::Var(0)), + CircuitStepArgument::Expr(CircuitNode::Var(1)) + ) + ), + "First step should be multiplication x0 * x1" + ); + + assert!( + matches!( + steps[1], + CircuitStep::Add( + CircuitStepArgument::Const(F::ONE), + CircuitStepArgument::Expr(CircuitNode::Var(0)) + ) + ), + "Second step should be (1 - x0)" + ); + + assert!( + matches!( + steps[2], + CircuitStep::AddMul( + 0, + CircuitStepArgument::Expr(CircuitNode::Slot(1)), + CircuitStepArgument::Expr(CircuitNode::Var(2)) + ) + ), + "Third step should be (1 - x0) * x2" + ); + + assert!( + matches!( + steps[3], + CircuitStep::Add( + CircuitStepArgument::Expr(CircuitNode::Slot(0)), + CircuitStepArgument::Expr(CircuitNode::Var(3)) + ) + ), + "Fourth step should be x0 * x1 + (1 - x0) * x2 + x3" + ); + + assert!( + matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(3))), + "Final result should be stored in Slot(3)" + ); + } } diff --git a/crates/core/src/polynomial/cached.rs b/crates/core/src/polynomial/cached.rs deleted file mode 100644 index 181257e5f..000000000 --- a/crates/core/src/polynomial/cached.rs +++ /dev/null @@ -1,272 +0,0 @@ -// Copyright 2024-2025 Irreducible Inc. - -use std::{ - any::{Any, TypeId}, - collections::HashMap, - fmt::Debug, - marker::PhantomData, -}; - -use binius_field::{ExtensionField, Field, PackedField}; -use binius_math::{ArithExpr, CompositionPoly, CompositionPolyOS, Error}; - -/// Cached composition poly wrapper. -/// -/// It stores the efficient implementations of the composition poly for some known set of packed field types. -/// We are usually able to use this when the inner poly is constructed with a macro for the known field and packed field types. -#[derive(Default, Debug)] -pub struct CachedPoly> { - inner: Inner, - cache: PackedFieldCache, -} - -impl> CachedPoly { - /// Create a new cached polynomial with the given inner polynomial. - pub fn new(inner: Inner) -> Self { - Self { - inner, - cache: Default::default(), - } - } - - /// Register efficient implementations for the `P` packed field type in the cache. - pub fn register>>( - &mut self, - composition: impl CompositionPolyOS

+ 'static, - ) { - self.cache.register(composition); - } -} - -impl> CompositionPoly for CachedPoly { - fn n_vars(&self) -> usize { - self.inner.n_vars() - } - - fn degree(&self) -> usize { - self.inner.degree() - } - - fn binary_tower_level(&self) -> usize { - self.inner.binary_tower_level() - } - - fn expression>(&self) -> ArithExpr { - self.inner.expression() - } - - fn evaluate>>(&self, query: &[P]) -> Result { - if let Some(result) = self.cache.try_evaluate(query) { - result - } else { - self.inner.evaluate(query) - } - } - - fn batch_evaluate>>( - &self, - batch_query: &[&[P]], - evals: &mut [P], - ) -> Result<(), Error> { - if let Some(result) = self.cache.try_batch_evaluate(batch_query, evals) { - result - } else { - self.inner.batch_evaluate(batch_query, evals) - } - } -} - -impl, P: PackedField>> - CompositionPolyOS

for CachedPoly -{ - fn binary_tower_level(&self) -> usize { - CompositionPoly::binary_tower_level(&self) - } - - fn n_vars(&self) -> usize { - CompositionPoly::n_vars(&self) - } - - fn degree(&self) -> usize { - CompositionPoly::degree(&self) - } - - fn expression(&self) -> ArithExpr { - CompositionPoly::expression(&self) - } - - fn evaluate(&self, query: &[P]) -> Result { - CompositionPoly::evaluate(&self, query) - } - - fn batch_evaluate(&self, batch_query: &[&[P]], evals: &mut [P]) -> Result<(), Error> { - CompositionPoly::batch_evaluate(&self, batch_query, evals) - } -} - -#[derive(Default)] -struct PackedFieldCache { - /// Map from the packed field type 'P to the efficient implementation of the composition polynomial - /// with actual type `Box>`. - entries: HashMap>, - _pd: PhantomData, -} - -impl PackedFieldCache { - /// Register efficient implementations for the `P` packed field type in the cache. - fn register>>( - &mut self, - composition: impl CompositionPolyOS

+ 'static, - ) { - let boxed_composition = Box::new(composition) as Box>; - self.entries - .insert(TypeId::of::

(), Box::new(boxed_composition) as Box); - } - - /// Try to evaluate the expression using the efficient implementation for the `P` packed field type. - /// If no implementation is found, return None. - fn try_evaluate>>( - &self, - query: &[P], - ) -> Option> { - if let Some(entry) = self.entries.get(&TypeId::of::

()) { - let entry = entry - .downcast_ref::>>() - .expect("cast must succeed"); - Some(entry.evaluate(query)) - } else { - None - } - } - - /// Try to batch evaluate the expression using the efficient implementation for the `P` packed field type. - /// If no implementation is found, return None. - fn try_batch_evaluate>>( - &self, - batch_query: &[&[P]], - evals: &mut [P], - ) -> Option> { - if let Some(entry) = self.entries.get(&TypeId::of::

()) { - let entry = entry - .downcast_ref::>>() - .expect("cast must succeed"); - Some(entry.batch_evaluate(batch_query, evals)) - } else { - None - } - } -} - -impl Debug for PackedFieldCache { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("PackedFieldCache") - .field("cached_implementations", &self.entries.len()) - .finish() - } -} - -#[cfg(test)] -mod tests { - use std::iter::zip; - - use binius_field::{BinaryField8b, ExtensionField, PackedBinaryField16x8b, PackedField}; - use binius_math::{ArithExpr, CompositionPolyOS}; - - use super::*; - use crate::polynomial::{cached::CachedPoly, ArithCircuitPoly}; - - fn ensure_equal_batch_eval_results( - circuit_1: &impl CompositionPolyOS

, - circuit_2: &impl CompositionPolyOS

, - batch_query: &[&[P]], - ) { - for row in 0..batch_query[0].len() { - let query = batch_query.iter().map(|q| q[row]).collect::>(); - - assert_eq!(circuit_1.evaluate(&query).unwrap(), circuit_2.evaluate(&query).unwrap()); - } - - let result_1 = { - let mut uncached_evals = vec![P::zero(); batch_query[0].len()]; - circuit_1 - .batch_evaluate(batch_query, &mut uncached_evals) - .unwrap(); - uncached_evals - }; - - let result_2 = { - let mut cached_evals = vec![P::zero(); batch_query[0].len()]; - circuit_2 - .batch_evaluate(batch_query, &mut cached_evals) - .unwrap(); - cached_evals - }; - - assert_eq!(result_1, result_2); - } - - #[derive(Debug, Copy, Clone)] - struct AddComposition; - - impl>> CompositionPolyOS

- for AddComposition - { - fn binary_tower_level(&self) -> usize { - 0 - } - - fn n_vars(&self) -> usize { - 1 - } - - fn degree(&self) -> usize { - 1 - } - - fn expression(&self) -> ArithExpr { - ArithExpr::Const(BinaryField8b::new(123).into()) + ArithExpr::Var(0) - } - - fn evaluate(&self, query: &[P]) -> Result { - Ok(query[0] + P::broadcast(BinaryField8b::new(123).into())) - } - - fn batch_evaluate(&self, batch_query: &[&[P]], evals: &mut [P]) -> Result<(), Error> { - for (input, output) in zip(batch_query[0], evals) { - *output = *input + P::broadcast(BinaryField8b::new(123).into()); - } - - Ok(()) - } - } - - #[test] - fn test_cached_impl() { - let expr = ArithExpr::Const(BinaryField8b::new(123)) + ArithExpr::Var(0); - let circuit = ArithCircuitPoly::::new(expr); - - let composition = AddComposition; - - let mut cached_circuit = CachedPoly::new(circuit.clone()); - cached_circuit.register::(composition); - - let batch_query = [(0..255).map(BinaryField8b::new).collect::>()]; - let batch_query = batch_query.iter().map(|q| q.as_slice()).collect::>(); - ensure_equal_batch_eval_results(&circuit, &cached_circuit, &batch_query); - } - - #[test] - fn test_uncached_impl() { - let expr = ArithExpr::Const(BinaryField8b::new(123)) + ArithExpr::Var(0); - let circuit = ArithCircuitPoly::::new(expr); - - let composition = AddComposition; - - let mut cached_circuit = CachedPoly::new(circuit.clone()); - cached_circuit.register::(composition); - - let batch_query = [(0..255).map(BinaryField8b::new).collect::>()]; - let batch_query = batch_query.iter().map(|q| q.as_slice()).collect::>(); - ensure_equal_batch_eval_results(&circuit, &cached_circuit, &batch_query); - } -} diff --git a/crates/core/src/polynomial/mod.rs b/crates/core/src/polynomial/mod.rs index 1c21083ad..7f8e4d762 100644 --- a/crates/core/src/polynomial/mod.rs +++ b/crates/core/src/polynomial/mod.rs @@ -1,7 +1,6 @@ // Copyright 2024-2025 Irreducible Inc. mod arith_circuit; -mod cached; mod error; mod multivariate; #[allow(dead_code)] @@ -9,6 +8,5 @@ mod multivariate; pub mod test_utils; pub use arith_circuit::*; -pub use cached::*; pub use error::*; pub use multivariate::*; diff --git a/crates/core/src/polynomial/multivariate.rs b/crates/core/src/polynomial/multivariate.rs index a8e6fdcf8..e670f7aca 100644 --- a/crates/core/src/polynomial/multivariate.rs +++ b/crates/core/src/polynomial/multivariate.rs @@ -4,9 +4,10 @@ use std::{borrow::Borrow, fmt::Debug, iter::repeat_with, marker::PhantomData, sy use binius_field::{Field, PackedField}; use binius_math::{ - ArithExpr, CompositionPolyOS, MLEDirectAdapter, MultilinearPoly, MultilinearQueryRef, + ArithExpr, CompositionPoly, MLEDirectAdapter, MultilinearPoly, MultilinearQueryRef, }; -use binius_utils::bail; +use binius_utils::{bail, SerializationError, SerializationMode}; +use bytes::BufMut; use itertools::Itertools; use rand::{rngs::StdRng, SeedableRng}; @@ -14,8 +15,8 @@ use super::error::Error; /// A multivariate polynomial over a binary tower field. /// -/// The definition `MultivariatePoly` is nearly identical to that of [`CompositionPolyOS`], except that -/// `MultivariatePoly` is _object safe_, whereas `CompositionPolyOS` is not. +/// The definition `MultivariatePoly` is nearly identical to that of [`CompositionPoly`], except that +/// `MultivariatePoly` is _object safe_, whereas `CompositionPoly` is not. pub trait MultivariatePoly

: Debug + Send + Sync { /// The number of variables. fn n_vars(&self) -> usize; @@ -28,13 +29,24 @@ pub trait MultivariatePoly

: Debug + Send + Sync { /// Returns the maximum binary tower level of all constants in the arithmetic expression. fn binary_tower_level(&self) -> usize; + + /// Serialize a type erased MultivariatePoly. + /// Since not every MultivariatePoly implements serialization, this defaults to returning an error. + fn erased_serialize( + &self, + write_buf: &mut dyn BufMut, + mode: SerializationMode, + ) -> Result<(), SerializationError> { + let _ = (write_buf, mode); + Err(SerializationError::SerializationNotImplemented) + } } /// Identity composition function $g(X) = X$. #[derive(Clone, Debug)] pub struct IdentityCompositionPoly; -impl CompositionPolyOS

for IdentityCompositionPoly { +impl CompositionPoly

for IdentityCompositionPoly { fn n_vars(&self) -> usize { 1 } @@ -59,7 +71,7 @@ impl CompositionPolyOS

for IdentityCompositionPoly { } } -/// An adapter that constructs a [`CompositionPolyOS`] for a field from a [`CompositionPolyOS`] for a +/// An adapter that constructs a [`CompositionPoly`] for a field from a [`CompositionPoly`] for a /// packing of that field. /// /// This is not intended for use in performance-critical code sections. @@ -72,7 +84,7 @@ pub struct CompositionScalarAdapter { impl CompositionScalarAdapter where P: PackedField, - Composition: CompositionPolyOS

, + Composition: CompositionPoly

, { pub const fn new(composition: Composition) -> Self { Self { @@ -82,11 +94,11 @@ where } } -impl CompositionPolyOS for CompositionScalarAdapter +impl CompositionPoly for CompositionScalarAdapter where F: Field, P: PackedField, - Composition: CompositionPolyOS

, + Composition: CompositionPoly

, { fn n_vars(&self) -> usize { self.composition.n_vars() @@ -141,7 +153,7 @@ where impl MultilinearComposite where P: PackedField, - C: CompositionPolyOS

, + C: CompositionPoly

, M: MultilinearPoly

, { pub fn new(n_vars: usize, composition: C, multilinears: Vec) -> Result { @@ -207,12 +219,10 @@ where impl MultilinearComposite where P: PackedField, - C: CompositionPolyOS

+ 'static, + C: CompositionPoly

+ 'static, M: MultilinearPoly

, { - pub fn to_arc_dyn_composition( - self, - ) -> MultilinearComposite>, M> { + pub fn to_arc_dyn_composition(self) -> MultilinearComposite>, M> { MultilinearComposite { n_vars: self.n_vars, composition: Arc::new(self.composition), @@ -267,7 +277,7 @@ where /// for two distinct multivariate polynomials f and g. /// /// NOTE: THIS IS NOT ADVERSARIALLY COLLISION RESISTANT, COLLISIONS CAN BE MANUFACTURED EASILY -pub fn composition_hash>(composition: &C) -> P { +pub fn composition_hash>(composition: &C) -> P { let mut rng = StdRng::from_seed([0; 32]); let random_point = repeat_with(|| P::random(&mut rng)) @@ -281,7 +291,7 @@ pub fn composition_hash>(composition: &C #[cfg(test)] mod tests { - use binius_math::{ArithExpr, CompositionPolyOS}; + use binius_math::{ArithExpr, CompositionPoly}; #[test] fn test_fingerprint_same_32b() { @@ -291,7 +301,7 @@ mod tests { let expr = (ArithExpr::Var(0) + ArithExpr::Var(1)) * ArithExpr::Var(0) + ArithExpr::Var(0).pow(2); let circuit_poly = &crate::polynomial::ArithCircuitPoly::::new(expr) - as &dyn CompositionPolyOS; + as &dyn CompositionPoly; let product_composition = crate::composition::ProductComposition::<2> {}; @@ -308,7 +318,7 @@ mod tests { let expr = ArithExpr::Var(0) + ArithExpr::Var(1); let circuit_poly = &crate::polynomial::ArithCircuitPoly::::new(expr) - as &dyn CompositionPolyOS; + as &dyn CompositionPoly; let product_composition = crate::composition::ProductComposition::<2> {}; @@ -326,7 +336,7 @@ mod tests { let expr = (ArithExpr::Var(0) + ArithExpr::Var(1)) * ArithExpr::Var(0) + ArithExpr::Var(0).pow(2); let circuit_poly = &crate::polynomial::ArithCircuitPoly::::new(expr) - as &dyn CompositionPolyOS; + as &dyn CompositionPoly; let product_composition = crate::composition::ProductComposition::<2> {}; @@ -342,7 +352,7 @@ mod tests { let expr = ArithExpr::Var(0) + ArithExpr::Var(1); let circuit_poly = &crate::polynomial::ArithCircuitPoly::::new(expr) - as &dyn CompositionPolyOS; + as &dyn CompositionPoly; let product_composition = crate::composition::ProductComposition::<2> {}; @@ -360,7 +370,7 @@ mod tests { let expr = (ArithExpr::Var(0) + ArithExpr::Var(1)) * ArithExpr::Var(0) + ArithExpr::Var(0).pow(2); let circuit_poly = &crate::polynomial::ArithCircuitPoly::::new(expr) - as &dyn CompositionPolyOS; + as &dyn CompositionPoly; let product_composition = crate::composition::ProductComposition::<2> {}; @@ -376,7 +386,7 @@ mod tests { let expr = ArithExpr::Var(0) + ArithExpr::Var(1); let circuit_poly = &crate::polynomial::ArithCircuitPoly::::new(expr) - as &dyn CompositionPolyOS; + as &dyn CompositionPoly; let product_composition = crate::composition::ProductComposition::<2> {}; diff --git a/crates/core/src/protocols/evalcheck/error.rs b/crates/core/src/protocols/evalcheck/error.rs index 54426fe8f..d5bf9447b 100644 --- a/crates/core/src/protocols/evalcheck/error.rs +++ b/crates/core/src/protocols/evalcheck/error.rs @@ -45,7 +45,7 @@ pub enum VerificationError { impl VerificationError { pub fn incorrect_composite_poly_evaluation( - oracle: CompositePolyOracle, + oracle: &CompositePolyOracle, ) -> Self { let names = oracle .inner_polys() diff --git a/crates/core/src/protocols/fri/common.rs b/crates/core/src/protocols/fri/common.rs index e8e7318fd..9d74660b6 100644 --- a/crates/core/src/protocols/fri/common.rs +++ b/crates/core/src/protocols/fri/common.rs @@ -343,13 +343,13 @@ mod tests { #[test] fn test_calculate_n_test_queries() { let security_bits = 96; - let rs_code = ReedSolomonCode::new(28, 1, NTTOptions::default()).unwrap(); + let rs_code = ReedSolomonCode::new(28, 1, &NTTOptions::default()).unwrap(); let n_test_queries = calculate_n_test_queries::(security_bits, &rs_code) .unwrap(); assert_eq!(n_test_queries, 232); - let rs_code = ReedSolomonCode::new(28, 2, NTTOptions::default()).unwrap(); + let rs_code = ReedSolomonCode::new(28, 2, &NTTOptions::default()).unwrap(); let n_test_queries = calculate_n_test_queries::(security_bits, &rs_code) .unwrap(); @@ -359,7 +359,7 @@ mod tests { #[test] fn test_calculate_n_test_queries_unsatisfiable() { let security_bits = 128; - let rs_code = ReedSolomonCode::new(28, 1, NTTOptions::default()).unwrap(); + let rs_code = ReedSolomonCode::new(28, 1, &NTTOptions::default()).unwrap(); assert_matches!( calculate_n_test_queries::(security_bits, &rs_code), Err(Error::ParameterError) diff --git a/crates/core/src/protocols/fri/prove.rs b/crates/core/src/protocols/fri/prove.rs index ba07ebb72..2a3b9f49f 100644 --- a/crates/core/src/protocols/fri/prove.rs +++ b/crates/core/src/protocols/fri/prove.rs @@ -3,7 +3,7 @@ use binius_field::{BinaryField, ExtensionField, PackedExtension, PackedField, TowerField}; use binius_hal::{make_portable_backend, ComputationBackend}; use binius_maybe_rayon::prelude::*; -use binius_utils::{bail, serialization::SerializeBytes}; +use binius_utils::{bail, SerializeBytes}; use bytemuck::zeroed_vec; use bytes::BufMut; use itertools::izip; @@ -174,7 +174,7 @@ pub fn commit_interleaved( message: &[P], ) -> Result, Error> where - F: BinaryField + ExtensionField, + F: BinaryField, FA: BinaryField, P: PackedField + PackedExtension, PA: PackedField, @@ -209,7 +209,7 @@ pub fn commit_interleaved_with( message_writer: impl FnOnce(&mut [P]), ) -> Result, Error> where - F: BinaryField + ExtensionField, + F: BinaryField, FA: BinaryField, P: PackedField + PackedExtension, PA: PackedField, diff --git a/crates/core/src/protocols/fri/tests.rs b/crates/core/src/protocols/fri/tests.rs index df4ecf22a..1798b3c14 100644 --- a/crates/core/src/protocols/fri/tests.rs +++ b/crates/core/src/protocols/fri/tests.rs @@ -46,14 +46,14 @@ fn test_commit_prove_verify_success( let committed_rs_code_packed = ReedSolomonCode::>::new( log_dimension, log_inv_rate, - NTTOptions::default(), + &NTTOptions::default(), ) .unwrap(); let merkle_prover = BinaryMerkleTreeProver::<_, Groestl256, _>::new(Groestl256ByteCompression); let committed_rs_code = - ReedSolomonCode::::new(log_dimension, log_inv_rate, NTTOptions::default()).unwrap(); + ReedSolomonCode::::new(log_dimension, log_inv_rate, &NTTOptions::default()).unwrap(); let n_test_queries = 3; let params = diff --git a/crates/core/src/protocols/fri/verify.rs b/crates/core/src/protocols/fri/verify.rs index 0abc46044..85e548144 100644 --- a/crates/core/src/protocols/fri/verify.rs +++ b/crates/core/src/protocols/fri/verify.rs @@ -4,7 +4,7 @@ use std::iter; use binius_field::{BinaryField, ExtensionField, TowerField}; use binius_hal::{make_portable_backend, ComputationBackend}; -use binius_utils::{bail, serialization::DeserializeBytes}; +use binius_utils::{bail, DeserializeBytes}; use bytes::Buf; use itertools::izip; use tracing::instrument; diff --git a/crates/core/src/protocols/gkr_gpa/gpa_sumcheck/prove.rs b/crates/core/src/protocols/gkr_gpa/gpa_sumcheck/prove.rs index 800ff7f7b..ca11ea7ff 100644 --- a/crates/core/src/protocols/gkr_gpa/gpa_sumcheck/prove.rs +++ b/crates/core/src/protocols/gkr_gpa/gpa_sumcheck/prove.rs @@ -2,13 +2,9 @@ use std::ops::Range; -use binius_field::{ - util::eq, ExtensionField, Field, PackedExtension, PackedField, PackedFieldIndexable, -}; +use binius_field::{util::eq, Field, PackedExtension, PackedField, PackedFieldIndexable}; use binius_hal::{ComputationBackend, SumcheckEvaluator}; -use binius_math::{ - CompositionPolyOS, EvaluationDomainFactory, InterpolationDomain, MultilinearPoly, -}; +use binius_math::{CompositionPoly, EvaluationDomainFactory, InterpolationDomain, MultilinearPoly}; use binius_maybe_rayon::prelude::*; use binius_utils::bail; use itertools::izip; @@ -48,12 +44,10 @@ where impl<'a, F, FDomain, P, Composition, M, Backend> GPAProver<'a, FDomain, P, Composition, M, Backend> where - F: Field + ExtensionField, + F: Field, FDomain: Field, - P: PackedFieldIndexable - + PackedExtension - + PackedExtension, - Composition: CompositionPolyOS

, + P: PackedFieldIndexable + PackedExtension, + Composition: CompositionPoly

, M: MultilinearPoly

+ Send + Sync, Backend: ComputationBackend, { @@ -104,8 +98,8 @@ where let evaluation_points = domains .iter() - .max_by_key(|domain| domain.points().len()) - .map_or_else(|| Vec::new(), |domain| domain.points().to_vec()); + .max_by_key(|domain| domain.size()) + .map_or_else(|| Vec::new(), |domain| domain.finite_points().to_vec()); let state = ProverState::new( multilinears, @@ -193,12 +187,12 @@ where impl SumcheckProver for GPAProver<'_, FDomain, P, Composition, M, Backend> where - F: Field + ExtensionField, + F: Field, FDomain: Field, P: PackedFieldIndexable + PackedExtension + PackedExtension, - Composition: CompositionPolyOS

, + Composition: CompositionPoly

, M: MultilinearPoly

+ Send + Sync, Backend: ComputationBackend, { @@ -229,7 +223,7 @@ where }) .collect::>(); - let evals = self.state.calculate_later_round_evals(&evaluators)?; + let evals = self.state.calculate_round_evals(&evaluators)?; let coeffs = self.state .calculate_round_coeffs_from_evals(&evaluators, batch_coeff, evals)?; @@ -287,13 +281,13 @@ where gpa_round_challenge: P::Scalar, } -impl SumcheckEvaluator +impl SumcheckEvaluator for GPAEvaluator<'_, P, FDomain, Composition> where - F: Field + ExtensionField, + F: Field, P: PackedField + PackedExtension + PackedExtension, FDomain: Field, - Composition: CompositionPolyOS

, + Composition: CompositionPoly

, { fn eval_point_indices(&self) -> Range { // By definition of grand product GKR circuit, the composition evaluation is a multilinear @@ -344,10 +338,10 @@ where impl SumcheckInterpolator for GPAEvaluator<'_, P, FDomain, Composition> where - F: Field + ExtensionField, + F: Field, P: PackedField + PackedExtension, FDomain: Field, - Composition: CompositionPolyOS

, + Composition: CompositionPoly

, { #[instrument( skip_all, diff --git a/crates/core/src/protocols/gkr_gpa/prove.rs b/crates/core/src/protocols/gkr_gpa/prove.rs index c1ab27bcb..e48d0c272 100644 --- a/crates/core/src/protocols/gkr_gpa/prove.rs +++ b/crates/core/src/protocols/gkr_gpa/prove.rs @@ -1,8 +1,6 @@ // Copyright 2024-2025 Irreducible Inc. -use binius_field::{ - ExtensionField, Field, PackedExtension, PackedField, PackedFieldIndexable, TowerField, -}; +use binius_field::{Field, PackedExtension, PackedField, PackedFieldIndexable, TowerField}; use binius_hal::ComputationBackend; use binius_math::{ extrapolate_line_scalar, EvaluationDomainFactory, MLEDirectAdapter, MultilinearExtension, @@ -46,7 +44,6 @@ where + PackedExtension + PackedExtension, FDomain: Field, - P::Scalar: Field + ExtensionField, Challenger_: Challenger, Backend: ComputationBackend, { @@ -266,7 +263,6 @@ where where FDomain: Field, P: PackedExtension, - F: ExtensionField, { // test same layer let Some(first_prover) = provers.first() else { diff --git a/crates/core/src/protocols/gkr_gpa/tests.rs b/crates/core/src/protocols/gkr_gpa/tests.rs index 87fcefa27..263544606 100644 --- a/crates/core/src/protocols/gkr_gpa/tests.rs +++ b/crates/core/src/protocols/gkr_gpa/tests.rs @@ -7,8 +7,8 @@ use binius_field::{ as_packed_field::{PackScalar, PackedType}, packed::set_packed_slice, underlier::{UnderlierType, WithUnderlier}, - BinaryField128b, BinaryField32b, ExtensionField, Field, PackedExtension, PackedField, - PackedFieldIndexable, RepackedExtension, TowerField, + BinaryField128b, BinaryField32b, Field, PackedExtension, PackedField, PackedFieldIndexable, + RepackedExtension, TowerField, }; use binius_math::{IsomorphicEvaluationDomainFactory, MultilinearExtension}; use bytemuck::zeroed_vec; @@ -24,15 +24,11 @@ use crate::{ witness::MultilinearExtensionIndex, }; -fn generate_poly_helper( +fn generate_poly_helper, F: Field>( rng: &mut StdRng, n_vars: usize, n_multilinears: usize, -) -> Vec<(MultilinearExtension

, F)> -where - P: PackedField>, - F: Field, -{ +) -> Vec<(MultilinearExtension

, F)> { repeat_with(|| { let values = repeat_with(|| F::random(&mut *rng)) .take(1 << n_vars) @@ -119,7 +115,7 @@ fn run_prove_verify_batch_test() where U: UnderlierType + PackScalar, P: PackedExtension + RepackedExtension

+ PackedFieldIndexable, - F: TowerField + ExtensionField, + F: TowerField, FS: TowerField, { let rng = StdRng::seed_from_u64(0); diff --git a/crates/core/src/protocols/gkr_gpa/verify.rs b/crates/core/src/protocols/gkr_gpa/verify.rs index 056f4cb76..a9d96366d 100644 --- a/crates/core/src/protocols/gkr_gpa/verify.rs +++ b/crates/core/src/protocols/gkr_gpa/verify.rs @@ -54,7 +54,7 @@ where &mut reverse_sorted_evalcheck_claims, ); - layer_claims = reduce_layer_claim_batch(layer_claims, transcript)?; + layer_claims = reduce_layer_claim_batch(&layer_claims, transcript)?; } process_finished_claims( n_claims, @@ -102,7 +102,7 @@ fn process_finished_claims( /// * `proof` - The batch layer proof that reduces the kth layer claims of the product circuits to the (k+1)th /// * `transcript` - The verifier transcript fn reduce_layer_claim_batch( - claims: Vec>, + claims: &[LayerClaim], transcript: &mut VerifierTranscript, ) -> Result>, Error> where diff --git a/crates/core/src/protocols/gkr_int_mul/error.rs b/crates/core/src/protocols/gkr_int_mul/error.rs index caa824f00..6da9cf33d 100644 --- a/crates/core/src/protocols/gkr_int_mul/error.rs +++ b/crates/core/src/protocols/gkr_int_mul/error.rs @@ -13,4 +13,12 @@ pub enum Error { SumcheckError(#[from] SumcheckError), #[error("polynomial error: {0}")] Polynomial(#[from] PolynomialError), + #[error("verification failure: {0}")] + Verification(#[from] VerificationError), +} + +#[derive(Debug, thiserror::Error)] +pub enum VerificationError { + #[error("the proof contains an incorrect evaluation of the eq indicator")] + IncorrectEqIndEvaluation, } diff --git a/crates/core/src/protocols/gkr_int_mul/generator_exponent/compositions.rs b/crates/core/src/protocols/gkr_int_mul/generator_exponent/compositions.rs index a15251587..35cbebf91 100644 --- a/crates/core/src/protocols/gkr_int_mul/generator_exponent/compositions.rs +++ b/crates/core/src/protocols/gkr_int_mul/generator_exponent/compositions.rs @@ -1,7 +1,7 @@ // Copyright 2024-2025 Irreducible Inc. use binius_field::{Field, PackedField}; -use binius_math::{ArithExpr, CompositionPolyOS}; +use binius_math::{ArithExpr, CompositionPoly}; use binius_utils::bail; #[derive(Debug)] @@ -12,7 +12,7 @@ where pub generator_power_constant: F, } -impl CompositionPolyOS

for MultiplyOrDont { +impl CompositionPoly

for MultiplyOrDont { fn n_vars(&self) -> usize { 2 } diff --git a/crates/core/src/protocols/gkr_int_mul/generator_exponent/prove.rs b/crates/core/src/protocols/gkr_int_mul/generator_exponent/prove.rs index 070d0637e..f6c0ae388 100644 --- a/crates/core/src/protocols/gkr_int_mul/generator_exponent/prove.rs +++ b/crates/core/src/protocols/gkr_int_mul/generator_exponent/prove.rs @@ -3,7 +3,7 @@ use std::array; use binius_field::{ - BinaryField, ExtensionField, Field, PackedExtension, PackedField, PackedFieldIndexable, + BinaryField1b, ExtensionField, Field, PackedExtension, PackedField, PackedFieldIndexable, TowerField, }; use binius_hal::ComputationBackend; @@ -40,18 +40,17 @@ pub fn prove< backend: &Backend, ) -> Result, Error> where + F: ExtensionField + ExtensionField + TowerField, FDomain: Field, - PBits: PackedField, - PGenerator: PackedExtension - + PackedFieldIndexable + FGenerator: TowerField + ExtensionField + ExtensionField, + PBits: PackedField, + PGenerator: PackedField + + PackedExtension + PackedExtension, - PGenerator::Scalar: ExtensionField + ExtensionField, - PChallenge: PackedField - + PackedFieldIndexable + PChallenge: PackedFieldIndexable + + PackedExtension + PackedExtension + PackedExtension, - F: ExtensionField + ExtensionField + BinaryField + TowerField, - FGenerator: Field + TowerField, Backend: ComputationBackend, Challenger_: Challenger, { @@ -60,10 +59,11 @@ where let mut eval_point = claim.eval_point.clone(); let mut eval = claim.eval; + for exponent_bit_number in (1..EXPONENT_BIT_WIDTH).rev() { let this_round_exponent_bit = witness.exponent[exponent_bit_number].clone(); let this_round_generator_power_constant = - F::from(FGenerator::MULTIPLICATIVE_GENERATOR.pow([1 << exponent_bit_number])); + F::from(FGenerator::MULTIPLICATIVE_GENERATOR.pow(1 << exponent_bit_number)); let this_round_input_data = witness.single_bit_output_layers_data[exponent_bit_number - 1].clone(); diff --git a/crates/core/src/protocols/gkr_int_mul/generator_exponent/tests.rs b/crates/core/src/protocols/gkr_int_mul/generator_exponent/tests.rs index 97508b59f..421b1efd7 100644 --- a/crates/core/src/protocols/gkr_int_mul/generator_exponent/tests.rs +++ b/crates/core/src/protocols/gkr_int_mul/generator_exponent/tests.rs @@ -124,7 +124,7 @@ fn witness_gen_happens_correctly() { for (row_idx, this_row_exponent) in exponent.into_iter().enumerate() { assert_eq!( ::Scalar::MULTIPLICATIVE_GENERATOR - .pow([this_row_exponent as u64]), + .pow(this_row_exponent as u64), get_packed_slice(results, row_idx) ); } diff --git a/crates/core/src/protocols/gkr_int_mul/generator_exponent/verify.rs b/crates/core/src/protocols/gkr_int_mul/generator_exponent/verify.rs index 1024481ab..66df9910e 100644 --- a/crates/core/src/protocols/gkr_int_mul/generator_exponent/verify.rs +++ b/crates/core/src/protocols/gkr_int_mul/generator_exponent/verify.rs @@ -3,7 +3,6 @@ use std::array; use binius_field::{ExtensionField, TowerField}; -use binius_utils::bail; use super::{ super::error::Error, common::GeneratorExponentReductionOutput, utils::first_layer_inverse, @@ -13,7 +12,7 @@ use crate::{ polynomial::MultivariatePoly, protocols::{ gkr_gpa::LayerClaim, - gkr_int_mul::generator_exponent::compositions::MultiplyOrDont, + gkr_int_mul::{error::VerificationError, generator_exponent::compositions::MultiplyOrDont}, sumcheck::{self, zerocheck::ExtraProduct, CompositeSumClaim, SumcheckClaim}, }, transcript::VerifierTranscript, @@ -78,7 +77,7 @@ where EqIndPartialEval::new(log_size, sumcheck_query_point.clone())?.evaluate(&eval_point)?; if sumcheck_verification_output.multilinear_evals[0][2] != eq_eval { - bail!(Error::EqEvalDoesntVerify) + return Err(VerificationError::IncorrectEqIndEvaluation.into()); } eval_claims_on_bit_columns[exponent_bit_number] = LayerClaim { diff --git a/crates/core/src/protocols/gkr_int_mul/generator_exponent/witness.rs b/crates/core/src/protocols/gkr_int_mul/generator_exponent/witness.rs index 885c47f9e..569c8f491 100644 --- a/crates/core/src/protocols/gkr_int_mul/generator_exponent/witness.rs +++ b/crates/core/src/protocols/gkr_int_mul/generator_exponent/witness.rs @@ -3,8 +3,7 @@ use std::{array, cmp::min, slice}; use binius_field::{ - ext_base_op_par, BinaryField, BinaryField1b, ExtensionField, Field, PackedExtension, - PackedField, PackedFieldIndexable, + ext_base_op_par, BinaryField, BinaryField1b, ExtensionField, PackedExtension, PackedField, }; use binius_maybe_rayon::{ prelude::{IndexedParallelIterator, ParallelIterator}, @@ -28,7 +27,7 @@ pub struct GeneratorExponentWitness< fn copy_witness_into_vec(poly: &MultilinearWitness) -> Vec

where P: PackedField, - PE: PackedField + PackedExtension, + PE: PackedExtension, PE::Scalar: ExtensionField, { let mut input_layer: Vec

= zeroed_vec(1 << poly.n_vars().saturating_sub(P::LOG_WIDTH)); @@ -68,9 +67,8 @@ fn evaluate_single_bit_output_packed( ) -> Vec where PBits: PackedField, - PGenerator: - PackedField + PackedFieldIndexable + PackedExtension, - PGenerator::Scalar: ExtensionField + BinaryField, + PGenerator: PackedExtension, + PGenerator::Scalar: BinaryField, { debug_assert_eq!( PBits::WIDTH * exponent_bit.len(), @@ -98,9 +96,8 @@ fn evaluate_first_layer_output_packed( ) -> Vec where PBits: PackedField, - PGenerator: - PackedField + PackedFieldIndexable + PackedExtension, - PGenerator::Scalar: ExtensionField, + PGenerator: PackedExtension, + PGenerator::Scalar: BinaryField, { let mut result = vec![PGenerator::zero(); exponent_bit.len() * PGenerator::Scalar::DEGREE]; @@ -119,11 +116,10 @@ impl<'a, PBits, PGenerator, PChallenge, const EXPONENT_BIT_WIDTH: usize> GeneratorExponentWitness<'a, PBits, PGenerator, PChallenge, EXPONENT_BIT_WIDTH> where PBits: PackedField, - PGenerator: - PackedField + PackedFieldIndexable + PackedExtension, - PGenerator::Scalar: ExtensionField + BinaryField, - PChallenge: PackedField + PackedExtension, - PChallenge::Scalar: ExtensionField, + PGenerator: PackedExtension, + PGenerator::Scalar: BinaryField, + PChallenge: PackedExtension, + PChallenge::Scalar: BinaryField, { pub fn new( exponent: [MultilinearWitness<'a, PChallenge>; EXPONENT_BIT_WIDTH], @@ -140,12 +136,15 @@ where PGenerator::Scalar::MULTIPLICATIVE_GENERATOR, ); + let mut generator_power_constant = PGenerator::Scalar::MULTIPLICATIVE_GENERATOR.square(); + for layer_idx_from_left in 1..EXPONENT_BIT_WIDTH { single_bit_output_layers_data[layer_idx_from_left] = evaluate_single_bit_output_packed( &exponent_data[layer_idx_from_left], - PGenerator::Scalar::MULTIPLICATIVE_GENERATOR.pow([1 << layer_idx_from_left]), + generator_power_constant, &single_bit_output_layers_data[layer_idx_from_left - 1], - ) + ); + generator_power_constant = generator_power_constant.square(); } Ok(Self { diff --git a/crates/core/src/protocols/gkr_int_mul/mod.rs b/crates/core/src/protocols/gkr_int_mul/mod.rs index 1e8861bff..02f0dbf2b 100644 --- a/crates/core/src/protocols/gkr_int_mul/mod.rs +++ b/crates/core/src/protocols/gkr_int_mul/mod.rs @@ -1,4 +1,4 @@ // Copyright 2024-2025 Irreducible Inc. mod error; -//pub mod generator_exponent; +pub mod generator_exponent; diff --git a/crates/core/src/protocols/sumcheck/common.rs b/crates/core/src/protocols/sumcheck/common.rs index c0d1e2cc7..8a9bb6f2d 100644 --- a/crates/core/src/protocols/sumcheck/common.rs +++ b/crates/core/src/protocols/sumcheck/common.rs @@ -6,7 +6,7 @@ use binius_field::{ util::{inner_product_unchecked, powers}, ExtensionField, Field, PackedField, }; -use binius_math::{CompositionPolyOS, MultilinearPoly}; +use binius_math::{CompositionPoly, MultilinearPoly}; use binius_utils::bail; use getset::{CopyGetters, Getters}; use tracing::instrument; @@ -45,7 +45,7 @@ pub struct SumcheckClaim { impl SumcheckClaim where - Composition: CompositionPolyOS, + Composition: CompositionPoly, { /// Constructs a new sumcheck claim. /// diff --git a/crates/core/src/protocols/sumcheck/error.rs b/crates/core/src/protocols/sumcheck/error.rs index a0195e68b..848123a2a 100644 --- a/crates/core/src/protocols/sumcheck/error.rs +++ b/crates/core/src/protocols/sumcheck/error.rs @@ -45,7 +45,7 @@ pub enum Error { oracle: String, hypercube_index: usize, }, - #[error("constraint set containts multilinears of different heights")] + #[error("constraint set contains multilinears of different heights")] ConstraintSetNumberOfVariablesMismatch, #[error("batching sumchecks and zerochecks is not supported yet")] MixedBatchingNotSupported, diff --git a/crates/core/src/protocols/sumcheck/front_loaded.rs b/crates/core/src/protocols/sumcheck/front_loaded.rs index bdba716ad..3b1c4c937 100644 --- a/crates/core/src/protocols/sumcheck/front_loaded.rs +++ b/crates/core/src/protocols/sumcheck/front_loaded.rs @@ -3,7 +3,7 @@ use std::{cmp, cmp::Ordering, collections::VecDeque, iter}; use binius_field::{Field, TowerField}; -use binius_math::{evaluate_univariate, CompositionPolyOS}; +use binius_math::{evaluate_univariate, CompositionPoly}; use binius_utils::sorting::is_sorted_ascending; use bytes::Buf; @@ -60,7 +60,7 @@ pub struct BatchVerifier { impl BatchVerifier where F: TowerField, - C: CompositionPolyOS + Clone, + C: CompositionPoly + Clone, { /// Constructs a new verifier for the front-loaded batched sumcheck. /// diff --git a/crates/core/src/protocols/sumcheck/prove/concrete_prover.rs b/crates/core/src/protocols/sumcheck/prove/concrete_prover.rs deleted file mode 100644 index b67e99c2f..000000000 --- a/crates/core/src/protocols/sumcheck/prove/concrete_prover.rs +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright 2024-2025 Irreducible Inc. - -use binius_field::{ExtensionField, Field, PackedExtension, PackedField, PackedFieldIndexable}; -use binius_hal::ComputationBackend; -use binius_math::{CompositionPolyOS, MultilinearPoly}; - -use super::{batch_prove::SumcheckProver, RegularSumcheckProver, ZerocheckProver}; -use crate::protocols::sumcheck::{common::RoundCoeffs, error::Error}; - -/// A sum type that is used to put both regular sumchecks and zerochecks into the same `batch_prove` call. -pub enum ConcreteProver<'a, FDomain, PBase, P, CompositionBase, Composition, M, Backend> -where - FDomain: Field, - PBase: PackedField, - P: PackedField, - M: MultilinearPoly

+ Send + Sync, - Backend: ComputationBackend, -{ - Sumcheck(RegularSumcheckProver<'a, FDomain, P, Composition, M, Backend>), - Zerocheck(ZerocheckProver<'a, FDomain, PBase, P, CompositionBase, Composition, M, Backend>), -} - -impl SumcheckProver - for ConcreteProver<'_, FDomain, FBase, P, CompositionBase, Composition, M, Backend> -where - F: Field + ExtensionField + ExtensionField, - FDomain: Field, - FBase: ExtensionField, - P: PackedFieldIndexable - + PackedExtension - + PackedExtension - + PackedExtension, - CompositionBase: CompositionPolyOS<

>::PackedSubfield>, - Composition: CompositionPolyOS

, - M: MultilinearPoly

+ Send + Sync, - Backend: ComputationBackend, -{ - fn n_vars(&self) -> usize { - match self { - ConcreteProver::Sumcheck(prover) => prover.n_vars(), - ConcreteProver::Zerocheck(prover) => prover.n_vars(), - } - } - - fn execute(&mut self, batch_coeff: F) -> Result, Error> { - match self { - ConcreteProver::Sumcheck(prover) => prover.execute(batch_coeff), - ConcreteProver::Zerocheck(prover) => prover.execute(batch_coeff), - } - } - - fn fold(&mut self, challenge: F) -> Result<(), Error> { - match self { - ConcreteProver::Sumcheck(prover) => prover.fold(challenge), - ConcreteProver::Zerocheck(prover) => prover.fold(challenge), - } - } - - fn finish(self: Box) -> Result, Error> { - match *self { - ConcreteProver::Sumcheck(prover) => Box::new(prover).finish(), - ConcreteProver::Zerocheck(prover) => Box::new(prover).finish(), - } - } -} diff --git a/crates/core/src/protocols/sumcheck/prove/mod.rs b/crates/core/src/protocols/sumcheck/prove/mod.rs index 94029d716..44c1f587c 100644 --- a/crates/core/src/protocols/sumcheck/prove/mod.rs +++ b/crates/core/src/protocols/sumcheck/prove/mod.rs @@ -3,7 +3,6 @@ mod batch_prove; mod batch_prove_univariate_zerocheck; pub(crate) mod common; -mod concrete_prover; pub mod front_loaded; pub mod oracles; pub mod prover_state; @@ -15,7 +14,6 @@ pub use batch_prove::{batch_prove, batch_prove_with_start, SumcheckProver}; pub use batch_prove_univariate_zerocheck::{ batch_prove_zerocheck_univariate_round, UnivariateZerocheckProver, }; -pub use concrete_prover::ConcreteProver; pub use oracles::{ constraint_set_sumcheck_prover, constraint_set_zerocheck_prover, split_constraint_set, }; diff --git a/crates/core/src/protocols/sumcheck/prove/oracles.rs b/crates/core/src/protocols/sumcheck/prove/oracles.rs index 8b34f4e57..c6f4b7980 100644 --- a/crates/core/src/protocols/sumcheck/prove/oracles.rs +++ b/crates/core/src/protocols/sumcheck/prove/oracles.rs @@ -54,7 +54,7 @@ where + PackedExtension + PackedExtension + PackedExtension, - F: TowerField + ExtensionField + ExtensionField, + F: TowerField, FBase: TowerField + ExtensionField + TryFrom, FDomain: Field, Backend: ComputationBackend, diff --git a/crates/core/src/protocols/sumcheck/prove/prover_state.rs b/crates/core/src/protocols/sumcheck/prove/prover_state.rs index d05e72523..d88272f11 100644 --- a/crates/core/src/protocols/sumcheck/prove/prover_state.rs +++ b/crates/core/src/protocols/sumcheck/prove/prover_state.rs @@ -5,10 +5,10 @@ use std::{ sync::atomic::{AtomicBool, Ordering}, }; -use binius_field::{util::powers, ExtensionField, Field, PackedExtension, PackedField}; +use binius_field::{util::powers, Field, PackedExtension, PackedField}; use binius_hal::{ComputationBackend, RoundEvals, SumcheckEvaluator, SumcheckMultilinear}; use binius_math::{ - evaluate_univariate, CompositionPolyOS, MLEDirectAdapter, MultilinearPoly, MultilinearQuery, + evaluate_univariate, CompositionPoly, MLEDirectAdapter, MultilinearPoly, MultilinearQuery, }; use binius_maybe_rayon::prelude::*; use binius_utils::bail; @@ -70,8 +70,8 @@ where impl<'a, FDomain, F, P, M, Backend> ProverState<'a, FDomain, P, M, Backend> where FDomain: Field, - F: Field + ExtensionField, - P: PackedField + PackedExtension + PackedExtension, + F: Field, + P: PackedField + PackedExtension, M: MultilinearPoly

+ Send + Sync, Backend: ComputationBackend, { @@ -191,8 +191,11 @@ where ref mut large_field_folded_multilinear, } => { // Post-switchover, simply plug in challenge for the zeroth variable. + let single_variable_query = MultilinearQuery::expand(&[challenge]); *large_field_folded_multilinear = MLEDirectAdapter::from( - large_field_folded_multilinear.evaluate_zeroth_variable(challenge)?, + large_field_folded_multilinear + .as_ref() + .evaluate_partial_low(single_variable_query.to_ref())?, ); } }; @@ -242,41 +245,17 @@ where .collect() } - /// Calculate the accumulated evaluations for the first sumcheck round. - #[instrument(skip_all, level = "debug")] - pub fn calculate_first_round_evals( - &self, - evaluators: &[Evaluator], - ) -> Result>, Error> - where - FBase: ExtensionField, - F: ExtensionField, - P: PackedExtension, - Evaluator: SumcheckEvaluator + Sync, - Composition: CompositionPolyOS

, - { - Ok(self.backend.sumcheck_compute_first_round_evals( - self.n_vars, - &self.multilinears, - evaluators, - &self.evaluation_points, - )?) - } - /// Calculate the accumulated evaluations for an arbitrary sumcheck round. - /// - /// See [`Self::calculate_first_round_evals`] for an optimized version of this method that - /// operates over small fields in the first round. #[instrument(skip_all, level = "debug")] - pub fn calculate_later_round_evals( + pub fn calculate_round_evals( &self, evaluators: &[Evaluator], ) -> Result>, Error> where - Evaluator: SumcheckEvaluator + Sync, - Composition: CompositionPolyOS

, + Evaluator: SumcheckEvaluator + Sync, + Composition: CompositionPoly

, { - Ok(self.backend.sumcheck_compute_later_round_evals( + Ok(self.backend.sumcheck_compute_round_evals( self.n_vars, self.tensor_query.as_ref().map(Into::into), &self.multilinears, diff --git a/crates/core/src/protocols/sumcheck/prove/regular_sumcheck.rs b/crates/core/src/protocols/sumcheck/prove/regular_sumcheck.rs index a3695046a..2a6ccb3f7 100644 --- a/crates/core/src/protocols/sumcheck/prove/regular_sumcheck.rs +++ b/crates/core/src/protocols/sumcheck/prove/regular_sumcheck.rs @@ -2,11 +2,9 @@ use std::{marker::PhantomData, ops::Range}; -use binius_field::{ExtensionField, Field, PackedExtension, PackedField}; +use binius_field::{Field, PackedExtension, PackedField}; use binius_hal::{ComputationBackend, SumcheckEvaluator}; -use binius_math::{ - CompositionPolyOS, EvaluationDomainFactory, InterpolationDomain, MultilinearPoly, -}; +use binius_math::{CompositionPoly, EvaluationDomainFactory, InterpolationDomain, MultilinearPoly}; use binius_maybe_rayon::prelude::*; use binius_utils::bail; use itertools::izip; @@ -31,7 +29,7 @@ where F: Field, P: PackedField, M: MultilinearPoly

+ Send + Sync, - Composition: CompositionPolyOS

+ 'a, + Composition: CompositionPoly

+ 'a, { let n_vars = multilinears .first() @@ -82,10 +80,10 @@ where impl<'a, F, FDomain, P, Composition, M, Backend> RegularSumcheckProver<'a, FDomain, P, Composition, M, Backend> where - F: Field + ExtensionField, + F: Field, FDomain: Field, P: PackedField + PackedExtension + PackedExtension, - Composition: CompositionPolyOS

, + Composition: CompositionPoly

, M: MultilinearPoly

+ Send + Sync, Backend: ComputationBackend, { @@ -142,8 +140,8 @@ where let evaluation_points = domains .iter() - .max_by_key(|domain| domain.points().len()) - .map_or_else(|| Vec::new(), |domain| domain.points().to_vec()); + .max_by_key(|domain| domain.size()) + .map_or_else(|| Vec::new(), |domain| domain.finite_points().to_vec()); let state = ProverState::new( multilinears, @@ -166,10 +164,10 @@ where impl SumcheckProver for RegularSumcheckProver<'_, FDomain, P, Composition, M, Backend> where - F: Field + ExtensionField, + F: Field, FDomain: Field, P: PackedField + PackedExtension + PackedExtension, - Composition: CompositionPolyOS

, + Composition: CompositionPoly

, M: MultilinearPoly

+ Send + Sync, Backend: ComputationBackend, { @@ -193,7 +191,7 @@ where }) .collect::>(); - let evals = self.state.calculate_later_round_evals(&evaluators)?; + let evals = self.state.calculate_round_evals(&evaluators)?; self.state .calculate_round_coeffs_from_evals(&evaluators, batch_coeff, evals) } @@ -213,13 +211,13 @@ where _marker: PhantomData

, } -impl SumcheckEvaluator +impl SumcheckEvaluator for RegularSumcheckEvaluator<'_, P, FDomain, Composition> where - F: Field + ExtensionField, + F: Field, P: PackedField + PackedExtension + PackedExtension, FDomain: Field, - Composition: CompositionPolyOS

, + Composition: CompositionPoly

, { fn eval_point_indices(&self) -> Range { // NB: We skip evaluation of $r(X)$ at $X = 0$ as it is derivable from the @@ -256,7 +254,7 @@ where impl SumcheckInterpolator for RegularSumcheckEvaluator<'_, P, FDomain, Composition> where - F: Field + ExtensionField, + F: Field, P: PackedField + PackedExtension, FDomain: Field, { diff --git a/crates/core/src/protocols/sumcheck/prove/univariate.rs b/crates/core/src/protocols/sumcheck/prove/univariate.rs index 4d77666b9..fbcf3efe3 100644 --- a/crates/core/src/protocols/sumcheck/prove/univariate.rs +++ b/crates/core/src/protocols/sumcheck/prove/univariate.rs @@ -9,7 +9,7 @@ use binius_field::{ }; use binius_hal::{ComputationBackend, ComputationBackendExt}; use binius_math::{ - CompositionPolyOS, Error as MathError, EvaluationDomainFactory, + CompositionPoly, Error as MathError, EvaluationDomainFactory, IsomorphicEvaluationDomainFactory, MLEDirectAdapter, MultilinearPoly, }; use binius_maybe_rayon::prelude::*; @@ -95,7 +95,7 @@ pub fn univariatizing_reduction_prover<'a, F, FDomain, P, Backend>( backend: &'a Backend, ) -> Result, Error> where - F: TowerField + ExtensionField, + F: TowerField, FDomain: TowerField, P: PackedFieldIndexable + PackedExtension @@ -132,12 +132,7 @@ where } #[derive(Debug)] -struct ParFoldStates -where - FBase: Field + PackedField, - P: PackedField + PackedExtension, - P::Scalar: ExtensionField, -{ +struct ParFoldStates> { /// Evaluations of a multilinear subcube, embedded into P (see MultilinearPoly::subcube_evals). Scratch space. evals: Vec

, /// `evals` cast to base field and transposed to 2^skip_rounds * 2^log_batch row-major form. Scratch space. @@ -150,12 +145,7 @@ where round_evals: Vec>, } -impl ParFoldStates -where - FBase: Field, - P: PackedField + PackedExtension, - P::Scalar: ExtensionField, -{ +impl> ParFoldStates { fn new( n_multilinears: usize, skip_rounds: usize, @@ -335,11 +325,11 @@ pub fn zerocheck_univariate_evals where FDomain: TowerField, FBase: ExtensionField, - F: TowerField + ExtensionField + ExtensionField, + F: TowerField, P: PackedFieldIndexable + PackedExtension + PackedExtension, - Composition: CompositionPolyOS>, + Composition: CompositionPoly>, M: MultilinearPoly

+ Send + Sync, Backend: ComputationBackend, { @@ -388,7 +378,7 @@ where // univariatized subcube. // NB: expansion of the first `skip_rounds` variables is applied to the round evals sum let partial_eq_ind_evals = backend.tensor_product_full_query(zerocheck_challenges)?; - let partial_eq_ind_evals_scalars = P::unpack_scalars(&partial_eq_ind_evals[..]); + let partial_eq_ind_evals_scalars = P::unpack_scalars(&partial_eq_ind_evals); // Evaluate each composition on a minimal packed prefix corresponding to the degree let pbase_prefix_lens = composition_degrees @@ -754,7 +744,7 @@ mod tests { }; use binius_hal::make_portable_backend; use binius_math::{ - CompositionPolyOS, DefaultEvaluationDomainFactory, EvaluationDomainFactory, MultilinearPoly, + CompositionPoly, DefaultEvaluationDomainFactory, EvaluationDomainFactory, MultilinearPoly, }; use binius_ntt::SingleThreadedNTT; use rand::{prelude::StdRng, SeedableRng}; @@ -823,7 +813,7 @@ mod tests { .map(|i| interleaved_scalars[(i << log_batch) + batch_idx]) .collect::>(); - for (i, &point) in max_domain.points()[1 << skip_rounds..] + for (i, &point) in max_domain.finite_points()[1 << skip_rounds..] [..extrapolated_scalars_cnt] .iter() .enumerate() @@ -889,11 +879,11 @@ mod tests { let compositions = [ Arc::new(IndexComposition::new(9, [0, 1], ProductComposition::<2> {}).unwrap()) - as Arc>>, + as Arc>>, Arc::new(IndexComposition::new(9, [2, 3, 4], ProductComposition::<3> {}).unwrap()) - as Arc>>, + as Arc>>, Arc::new(IndexComposition::new(9, [5, 6, 7, 8], ProductComposition::<4> {}).unwrap()) - as Arc>>, + as Arc>>, ]; let backend = make_portable_backend(); diff --git a/crates/core/src/protocols/sumcheck/prove/zerocheck.rs b/crates/core/src/protocols/sumcheck/prove/zerocheck.rs index c3325604d..bb9867c61 100644 --- a/crates/core/src/protocols/sumcheck/prove/zerocheck.rs +++ b/crates/core/src/protocols/sumcheck/prove/zerocheck.rs @@ -9,7 +9,7 @@ use binius_field::{ }; use binius_hal::{ComputationBackend, SumcheckEvaluator}; use binius_math::{ - CompositionPolyOS, EvaluationDomainFactory, InterpolationDomain, MLEDirectAdapter, + CompositionPoly, EvaluationDomainFactory, InterpolationDomain, MLEDirectAdapter, MultilinearPoly, MultilinearQuery, }; use binius_maybe_rayon::prelude::*; @@ -23,7 +23,7 @@ use tracing::instrument; use crate::{ polynomial::{Error as PolynomialError, MultilinearComposite}, protocols::sumcheck::{ - common::{determine_switchovers, equal_n_vars_check, small_field_embedding_degree_check}, + common::{determine_switchovers, equal_n_vars_check}, prove::{ common::fold_partial_eq_ind, univariate::{ @@ -41,13 +41,13 @@ use crate::{ pub fn validate_witness<'a, F, P, M, Composition>( multilinears: &[M], - zero_claims: impl IntoIterator, Composition)>, + zero_claims: impl IntoIterator, ) -> Result<(), Error> where F: Field, P: PackedField, M: MultilinearPoly

+ Send + Sync, - Composition: CompositionPolyOS

+ 'a, + Composition: CompositionPoly

+ 'a, { let n_vars = multilinears .first() @@ -99,7 +99,7 @@ where #[getset(get = "pub")] multilinears: Vec, switchover_rounds: Vec, - compositions: Vec<(Arc, CompositionBase, Composition)>, + compositions: Vec<(String, CompositionBase, Composition)>, zerocheck_challenges: Vec, domains: Vec>, backend: &'a Backend, @@ -111,21 +111,21 @@ where impl<'a, 'm, F, FDomain, FBase, P, CompositionBase, Composition, M, Backend> UnivariateZerocheck<'a, 'm, FDomain, FBase, P, CompositionBase, Composition, M, Backend> where - F: Field + ExtensionField + ExtensionField, + F: Field, FDomain: Field, FBase: ExtensionField, P: PackedFieldIndexable + PackedExtension + PackedExtension + PackedExtension, - CompositionBase: CompositionPolyOS<

>::PackedSubfield>, - Composition: CompositionPolyOS

, + CompositionBase: CompositionPoly<

>::PackedSubfield>, + Composition: CompositionPoly

, M: MultilinearPoly

+ Send + Sync + 'm, Backend: ComputationBackend, { pub fn new( multilinears: Vec, - zero_claims: impl IntoIterator, CompositionBase, Composition)>, + zero_claims: impl IntoIterator, zerocheck_challenges: &[F], evaluation_domain_factory: impl EvaluationDomainFactory, switchover_fn: impl Fn(usize) -> usize, @@ -154,8 +154,6 @@ where validate_witness(&multilinears, &compositions)?; } - small_field_embedding_degree_check::<_, FBase, P, _>(&multilinears)?; - let switchover_rounds = determine_switchovers(&multilinears, switchover_fn); let zerocheck_challenges = zerocheck_challenges.to_vec(); @@ -188,16 +186,7 @@ where pub fn into_regular_zerocheck( self, ) -> Result< - ZerocheckProver< - 'a, - FDomain, - FBase, - P, - CompositionBase, - Composition, - MultilinearWitness<'m, P>, - Backend, - >, + ZerocheckProver<'a, FDomain, P, Composition, MultilinearWitness<'m, P>, Backend>, Error, > { if self.univariate_evals_output.is_some() { @@ -224,26 +213,29 @@ where validate_witness(&multilinears, &compositions)?; } + let compositions = self + .compositions + .into_iter() + .map(|(_, _, composition)| composition) + .collect::>(); + // Evaluate zerocheck partial indicator in variables 1..n_vars let start = self.n_vars.min(1); let partial_eq_ind_evals = self .backend .tensor_product_full_query(&self.zerocheck_challenges[start..])?; - let claimed_sums = vec![F::ZERO; self.compositions.len()]; + let claimed_sums = vec![F::ZERO; compositions.len()]; // This is a regular multilinear zerocheck constructor, split over two creation stages. ZerocheckProver::new( multilinears, - self.switchover_rounds, - self.compositions - .into_iter() - .map(|(_, a, b)| (a, b)) - .collect(), + &self.switchover_rounds, + compositions, partial_eq_ind_evals, self.zerocheck_challenges, claimed_sums, self.domains, - RegularFirstRound::BaseField, + RegularFirstRound::SkipCube, self.backend, ) } @@ -253,15 +245,15 @@ impl<'a, 'm, F, FDomain, FBase, P, CompositionBase, Composition, M, Backend> UnivariateZerocheckProver<'a, F> for UnivariateZerocheck<'a, 'm, FDomain, FBase, P, CompositionBase, Composition, M, Backend> where - F: TowerField + ExtensionField + ExtensionField, + F: TowerField, FDomain: TowerField, FBase: ExtensionField, P: PackedFieldIndexable + PackedExtension + PackedExtension + PackedExtension, - CompositionBase: CompositionPolyOS> + 'static, - Composition: CompositionPolyOS

+ 'static, + CompositionBase: CompositionPoly> + 'static, + Composition: CompositionPoly

+ 'static, M: MultilinearPoly

+ Send + Sync + 'm, Backend: ComputationBackend, { @@ -383,27 +375,30 @@ where .switchover_rounds .into_iter() .map(|switchover_round| switchover_round.saturating_sub(skip_rounds)) - .collect(); + .collect::>(); let zerocheck_challenges = self.zerocheck_challenges.clone(); + let compositions = self + .compositions + .into_iter() + .map(|(_, _, composition)| composition) + .collect(); + // This is also regular multilinear zerocheck constructor, but "jump started" in round // `skip_rounds` while using witness with a projected univariate round. - // NB: first round evaluator has to be overriden due to issues proving + // NB: first round evaluator has to be overridden due to issues proving // `P: RepackedExtension

` relation in the generic context, as well as the need // to use later round evaluator (as this _is_ a "later" round, albeit numbered at zero) - let regular_prover = ZerocheckProver::<_, FBase, _, _, _, _, _>::new( + let regular_prover = ZerocheckProver::new( partial_low_multilinears, - switchover_rounds, - self.compositions - .into_iter() - .map(|(_, a, b)| (a, b)) - .collect(), + &switchover_rounds, + compositions, partial_eq_ind_evals, zerocheck_challenges, claimed_prime_sums, self.domains, - RegularFirstRound::LargeField, + RegularFirstRound::LaterRound, self.backend, )?; @@ -413,8 +408,8 @@ where #[derive(Debug, Clone, Copy)] enum RegularFirstRound { - BaseField, - LargeField, + SkipCube, + LaterRound, } /// A "regular" multilinear zerocheck prover. @@ -432,10 +427,9 @@ enum RegularFirstRound { /// /// [Gruen24]: #[derive(Debug)] -pub struct ZerocheckProver<'a, FDomain, FBase, P, CompositionBase, Composition, M, Backend> +pub struct ZerocheckProver<'a, FDomain, P, Composition, M, Backend> where FDomain: Field, - FBase: PackedField, P: PackedField, M: MultilinearPoly

+ Send + Sync, Backend: ComputationBackend, @@ -445,32 +439,26 @@ where eq_ind_eval: P::Scalar, partial_eq_ind_evals: Backend::Vec

, zerocheck_challenges: Vec, - compositions: Vec<(CompositionBase, Composition)>, + compositions: Vec, domains: Vec>, first_round: RegularFirstRound, - _f_base_marker: PhantomData, } -impl<'a, F, FDomain, FBase, P, CompositionBase, Composition, M, Backend> - ZerocheckProver<'a, FDomain, FBase, P, CompositionBase, Composition, M, Backend> +impl<'a, F, FDomain, P, Composition, M, Backend> + ZerocheckProver<'a, FDomain, P, Composition, M, Backend> where - F: Field + ExtensionField + ExtensionField, + F: Field, FDomain: Field, - FBase: ExtensionField, - P: PackedFieldIndexable - + PackedExtension - + PackedExtension - + PackedExtension, - CompositionBase: CompositionPolyOS>, - Composition: CompositionPolyOS

, + P: PackedFieldIndexable + PackedExtension, + Composition: CompositionPoly

, M: MultilinearPoly

+ Send + Sync, Backend: ComputationBackend, { #[allow(clippy::too_many_arguments)] fn new( multilinears: Vec, - switchover_rounds: Vec, - compositions: Vec<(CompositionBase, Composition)>, + switchover_rounds: &[usize], + compositions: Vec, partial_eq_ind_evals: Backend::Vec

, zerocheck_challenges: Vec, claimed_prime_sums: Vec, @@ -480,8 +468,8 @@ where ) -> Result { let evaluation_points = domains .iter() - .max_by_key(|domain| domain.points().len()) - .map_or_else(|| Vec::new(), |domain| domain.points().to_vec()); + .max_by_key(|domain| domain.size()) + .map_or_else(|| Vec::new(), |domain| domain.finite_points().to_vec()); if claimed_prime_sums.len() != compositions.len() { bail!(Error::IncorrectClaimedPrimeSumsLength); @@ -489,7 +477,7 @@ where let state = ProverState::new_with_switchover_rounds( multilinears, - &switchover_rounds, + switchover_rounds, claimed_prime_sums, evaluation_points, backend, @@ -517,7 +505,6 @@ where compositions, domains, first_round, - _f_base_marker: PhantomData, }) } @@ -547,18 +534,13 @@ where } } -impl SumcheckProver - for ZerocheckProver<'_, FDomain, FBase, P, CompositionBase, Composition, M, Backend> +impl SumcheckProver + for ZerocheckProver<'_, FDomain, P, Composition, M, Backend> where - F: Field + ExtensionField + ExtensionField, + F: Field, FDomain: Field, - FBase: ExtensionField, - P: PackedFieldIndexable - + PackedExtension - + PackedExtension - + PackedExtension, - CompositionBase: CompositionPolyOS<

>::PackedSubfield>, - Composition: CompositionPolyOS

, + P: PackedFieldIndexable + PackedExtension, + Composition: CompositionPoly

, M: MultilinearPoly

+ Send + Sync, Backend: ComputationBackend, { @@ -580,33 +562,29 @@ where #[instrument(skip_all, name = "ZerocheckProver::execute", level = "debug")] fn execute(&mut self, batch_coeff: F) -> Result, Error> { let round = self.round(); - let base_field_first_round = - round == 0 && matches!(self.first_round, RegularFirstRound::BaseField); - let coeffs = if base_field_first_round { + let skip_cube_first_round = + round == 0 && matches!(self.first_round, RegularFirstRound::SkipCube); + let coeffs = if skip_cube_first_round { let evaluators = izip!(&self.compositions, &self.domains) - .map(|((composition_base, composition), interpolation_domain)| { - ZerocheckFirstRoundEvaluator { - composition_base, - composition, - interpolation_domain, - partial_eq_ind_evals: &self.partial_eq_ind_evals, - _f_base_marker: PhantomData::, - } + .map(|(composition, interpolation_domain)| ZerocheckFirstRoundEvaluator { + composition, + interpolation_domain, + partial_eq_ind_evals: &self.partial_eq_ind_evals, }) .collect::>(); - let evals = self.state.calculate_first_round_evals(&evaluators)?; + let evals = self.state.calculate_round_evals(&evaluators)?; self.state .calculate_round_coeffs_from_evals(&evaluators, batch_coeff, evals)? } else { let evaluators = izip!(&self.compositions, &self.domains) - .map(|((_, composition), interpolation_domain)| ZerocheckLaterRoundEvaluator { + .map(|(composition, interpolation_domain)| ZerocheckLaterRoundEvaluator { composition, interpolation_domain, partial_eq_ind_evals: &self.partial_eq_ind_evals, round_zerocheck_challenge: self.zerocheck_challenges[round], }) .collect::>(); - let evals = self.state.calculate_later_round_evals(&evaluators)?; + let evals = self.state.calculate_round_evals(&evaluators)?; self.state .calculate_round_coeffs_from_evals(&evaluators, batch_coeff, evals)? }; @@ -636,28 +614,22 @@ where } } -struct ZerocheckFirstRoundEvaluator<'a, P, FBase, FDomain, CompositionBase, Composition> +struct ZerocheckFirstRoundEvaluator<'a, P, FDomain, Composition> where P: PackedField, - FBase: Field, FDomain: Field, { - composition_base: &'a CompositionBase, composition: &'a Composition, interpolation_domain: &'a InterpolationDomain, partial_eq_ind_evals: &'a [P], - _f_base_marker: PhantomData, } -impl SumcheckEvaluator - for ZerocheckFirstRoundEvaluator<'_, P, FBase, FDomain, CompositionBase, Composition> +impl SumcheckEvaluator + for ZerocheckFirstRoundEvaluator<'_, P, FDomain, Composition> where - F: Field + ExtensionField + ExtensionField, - FBase: Field, - P: PackedField + PackedExtension, + P: PackedField>, FDomain: Field, - CompositionBase: CompositionPolyOS>, - Composition: CompositionPolyOS

, + Composition: CompositionPoly

, { fn eval_point_indices(&self) -> Range { // In the first round of zerocheck we can uniquely determine the degree d @@ -670,7 +642,7 @@ where &self, subcube_vars: usize, subcube_index: usize, - batch_query: &[&[PackedSubfield]], + batch_query: &[&[P]], ) -> P { // If the composition is a linear polynomial, then the composite multivariate polynomial // is multilinear. If the prover is honest, then this multilinear is identically zero, @@ -681,7 +653,7 @@ where let row_len = batch_query.first().map_or(0, |row| row.len()); stackalloc_with_default(row_len, |evals| { - self.composition_base + self.composition .batch_evaluate(batch_query, evals) .expect("correct by query construction invariant"); @@ -705,13 +677,12 @@ where } } -impl SumcheckInterpolator - for ZerocheckFirstRoundEvaluator<'_, P, FBase, FDomain, CompositionBase, Composition> +impl SumcheckInterpolator + for ZerocheckFirstRoundEvaluator<'_, P, FDomain, Composition> where - F: Field + ExtensionField + ExtensionField, - FBase: Field, - FDomain: Field, + F: Field + ExtensionField, P: PackedField, + FDomain: Field, { fn round_evals_to_coeffs( &self, @@ -741,13 +712,12 @@ where round_zerocheck_challenge: P::Scalar, } -impl SumcheckEvaluator +impl SumcheckEvaluator for ZerocheckLaterRoundEvaluator<'_, P, FDomain, Composition> where - F: Field + ExtensionField, - P: PackedField + PackedExtension + PackedExtension, + P: PackedField>, FDomain: Field, - Composition: CompositionPolyOS

, + Composition: CompositionPoly

, { fn eval_point_indices(&self) -> Range { // We can uniquely derive the degree d univariate round polynomial r from evaluations at @@ -796,7 +766,7 @@ where impl SumcheckInterpolator for ZerocheckLaterRoundEvaluator<'_, P, FDomain, Composition> where - F: Field + ExtensionField, + F: Field, P: PackedField + PackedExtension, FDomain: Field, { diff --git a/crates/core/src/protocols/sumcheck/tests.rs b/crates/core/src/protocols/sumcheck/tests.rs index 667f63162..a77209641 100644 --- a/crates/core/src/protocols/sumcheck/tests.rs +++ b/crates/core/src/protocols/sumcheck/tests.rs @@ -16,7 +16,7 @@ use binius_field::{ }; use binius_hal::{make_portable_backend, ComputationBackend, ComputationBackendExt}; use binius_math::{ - ArithExpr, CompositionPolyOS, EvaluationDomainFactory, IsomorphicEvaluationDomainFactory, + ArithExpr, CompositionPoly, EvaluationDomainFactory, IsomorphicEvaluationDomainFactory, MLEEmbeddingAdapter, MultilinearExtension, MultilinearPoly, MultilinearQuery, }; use binius_maybe_rayon::{current_num_threads, prelude::*}; @@ -50,7 +50,7 @@ struct PowerComposition { exponent: usize, } -impl CompositionPolyOS

for PowerComposition { +impl CompositionPoly

for PowerComposition { fn n_vars(&self) -> usize { 1 } @@ -103,7 +103,7 @@ fn compute_composite_sum( where P: PackedField, M: MultilinearPoly

+ Send + Sync, - Composition: CompositionPolyOS

, + Composition: CompositionPoly

, { let n_vars = multilinears .first() @@ -263,11 +263,11 @@ fn make_test_sumcheck<'a, F, FDomain, P, PExt, Backend>( backend: &'a Backend, ) -> ( Vec>, - SumcheckClaim + Clone + 'static>, + SumcheckClaim + Clone + 'static>, impl SumcheckProver + 'a, ) where - F: Field + ExtensionField + ExtensionField, + F: Field, FDomain: Field, P: PackedField, PExt: PackedField @@ -287,10 +287,9 @@ where .map(MLEEmbeddingAdapter::<_, PExt, _>::from) .collect::>(); - let mut claim_composite_sums = - Vec::>>>::new(); + let mut claim_composite_sums = Vec::>>>::new(); let mut prover_composite_sums = - Vec::>>>::new(); + Vec::>>>::new(); if max_degree >= 1 { let identity_composition = diff --git a/crates/core/src/protocols/sumcheck/univariate.rs b/crates/core/src/protocols/sumcheck/univariate.rs index d433b5d68..df833eef4 100644 --- a/crates/core/src/protocols/sumcheck/univariate.rs +++ b/crates/core/src/protocols/sumcheck/univariate.rs @@ -230,7 +230,7 @@ mod tests { }; use binius_hal::ComputationBackend; use binius_math::{ - CompositionPolyOS, DefaultEvaluationDomainFactory, EvaluationDomainFactory, + CompositionPoly, DefaultEvaluationDomainFactory, EvaluationDomainFactory, IsomorphicEvaluationDomainFactory, MultilinearPoly, }; use groestl_crypto::Groestl256; @@ -437,31 +437,31 @@ mod tests { let prover_compositions = [ ( "pair".into(), - pair.clone() as Arc>>, - pair.clone() as Arc>>, + pair.clone() as Arc>>, + pair.clone() as Arc>>, ), ( "triple".into(), - triple.clone() as Arc>>, - triple.clone() as Arc>>, + triple.clone() as Arc>>, + triple.clone() as Arc>>, ), ( "quad".into(), - quad.clone() as Arc>>, - quad.clone() as Arc>>, + quad.clone() as Arc>>, + quad.clone() as Arc>>, ), ]; let prover_adapter_compositions = [ - CompositionScalarAdapter::new(pair.clone() as Arc>), - CompositionScalarAdapter::new(triple.clone() as Arc>), - CompositionScalarAdapter::new(quad.clone() as Arc>), + CompositionScalarAdapter::new(pair.clone() as Arc>), + CompositionScalarAdapter::new(triple.clone() as Arc>), + CompositionScalarAdapter::new(quad.clone() as Arc>), ]; let verifier_compositions = [ - pair as Arc>, - triple as Arc>, - quad as Arc>, + pair as Arc>, + triple as Arc>, + quad as Arc>, ]; for skip_rounds in 0..=max_n_vars { diff --git a/crates/core/src/protocols/sumcheck/univariate_zerocheck.rs b/crates/core/src/protocols/sumcheck/univariate_zerocheck.rs index eb14a3bba..75756851a 100644 --- a/crates/core/src/protocols/sumcheck/univariate_zerocheck.rs +++ b/crates/core/src/protocols/sumcheck/univariate_zerocheck.rs @@ -1,7 +1,7 @@ // Copyright 2024-2025 Irreducible Inc. use binius_field::{util::inner_product_unchecked, Field, TowerField}; -use binius_math::{CompositionPolyOS, EvaluationDomainFactory, IsomorphicEvaluationDomainFactory}; +use binius_math::{CompositionPoly, EvaluationDomainFactory, IsomorphicEvaluationDomainFactory}; use binius_utils::{bail, sorting::is_sorted_ascending}; use tracing::instrument; @@ -50,7 +50,7 @@ pub fn batch_verify_zerocheck_univariate_round( ) -> Result, Error> where F: TowerField, - Composition: CompositionPolyOS, + Composition: CompositionPoly, Challenger_: Challenger, { // Check that the claims are in descending order by n_vars diff --git a/crates/core/src/protocols/sumcheck/verify.rs b/crates/core/src/protocols/sumcheck/verify.rs index b06de94ac..fe778c1a2 100644 --- a/crates/core/src/protocols/sumcheck/verify.rs +++ b/crates/core/src/protocols/sumcheck/verify.rs @@ -1,7 +1,7 @@ // Copyright 2024-2025 Irreducible Inc. use binius_field::{Field, TowerField}; -use binius_math::{evaluate_univariate, CompositionPolyOS}; +use binius_math::{evaluate_univariate, CompositionPoly}; use binius_utils::{bail, sorting::is_sorted_ascending}; use itertools::izip; use tracing::instrument; @@ -34,7 +34,7 @@ pub fn batch_verify( ) -> Result, Error> where F: TowerField, - Composition: CompositionPolyOS, + Composition: CompositionPoly, Challenger_: Challenger, { let start = BatchVerifyStart { @@ -69,7 +69,7 @@ pub fn batch_verify_with_start( ) -> Result, Error> where F: TowerField, - Composition: CompositionPolyOS, + Composition: CompositionPoly, Challenger_: Challenger, { let BatchVerifyStart { @@ -177,7 +177,7 @@ pub fn compute_expected_batch_composite_evaluation_single_claim Result where - Composition: CompositionPolyOS, + Composition: CompositionPoly, { let composite_evals = claim .composite_sums() @@ -193,7 +193,7 @@ fn compute_expected_batch_composite_evaluation_multi_claim], ) -> Result where - Composition: CompositionPolyOS, + Composition: CompositionPoly, { izip!(batch_coeffs, claims, multilinear_evals.iter()) .map(|(batch_coeff, claim, multilinear_evals)| { diff --git a/crates/core/src/protocols/sumcheck/zerocheck.rs b/crates/core/src/protocols/sumcheck/zerocheck.rs index 4bce691c9..5d47e28bd 100644 --- a/crates/core/src/protocols/sumcheck/zerocheck.rs +++ b/crates/core/src/protocols/sumcheck/zerocheck.rs @@ -3,7 +3,7 @@ use std::marker::PhantomData; use binius_field::{util::eq, Field, PackedField}; -use binius_math::{ArithExpr, CompositionPolyOS}; +use binius_math::{ArithExpr, CompositionPoly}; use binius_utils::{bail, sorting::is_sorted_ascending}; use getset::CopyGetters; @@ -22,7 +22,7 @@ pub struct ZerocheckClaim { impl ZerocheckClaim where - Composition: CompositionPolyOS, + Composition: CompositionPoly, { pub fn new( n_vars: usize, @@ -60,7 +60,7 @@ where } /// Requirement: zerocheck challenges have been sampled before this is called -pub fn reduce_to_sumchecks>( +pub fn reduce_to_sumchecks>( claims: &[ZerocheckClaim], ) -> Result>>, Error> { // Check that the claims are in descending order by n_vars @@ -100,7 +100,7 @@ pub fn reduce_to_sumchecks>( /// /// Note that due to univariatization of some rounds the number of challenges may be less than /// the maximum number of variables among claims. -pub fn verify_sumcheck_outputs>( +pub fn verify_sumcheck_outputs>( claims: &[ZerocheckClaim], zerocheck_challenges: &[F], sumcheck_output: BatchSumcheckOutput, @@ -158,10 +158,10 @@ pub struct ExtraProduct { pub inner: Composition, } -impl CompositionPolyOS

for ExtraProduct +impl CompositionPoly

for ExtraProduct where P: PackedField, - Composition: CompositionPolyOS

, + Composition: CompositionPoly

, { fn n_vars(&self) -> usize { self.inner.n_vars() + 1 @@ -195,8 +195,8 @@ mod tests { use std::{iter, sync::Arc}; use binius_field::{ - BinaryField128b, BinaryField32b, BinaryField8b, ExtensionField, PackedBinaryField1x128b, - PackedExtension, PackedFieldIndexable, PackedSubfield, RepackedExtension, + BinaryField128b, BinaryField32b, BinaryField8b, PackedBinaryField1x128b, PackedExtension, + PackedFieldIndexable, PackedSubfield, RepackedExtension, }; use binius_hal::{make_portable_backend, ComputationBackend, ComputationBackendExt}; use binius_math::{ @@ -236,10 +236,10 @@ mod tests { Backend, > where - F: Field + ExtensionField, + F: Field, FDomain: Field, P: PackedFieldIndexable + PackedExtension + RepackedExtension

, - Composition: CompositionPolyOS

, + Composition: CompositionPoly

, M: MultilinearPoly

+ Send + Sync + 'static, Backend: ComputationBackend, { diff --git a/crates/core/src/protocols/test_utils.rs b/crates/core/src/protocols/test_utils.rs index 1ad370ad1..bfa4a0b47 100644 --- a/crates/core/src/protocols/test_utils.rs +++ b/crates/core/src/protocols/test_utils.rs @@ -3,7 +3,7 @@ use std::ops::Deref; use binius_field::{ExtensionField, Field, PackedField}; -use binius_math::{ArithExpr, CompositionPolyOS, MLEEmbeddingAdapter, MultilinearExtension}; +use binius_math::{ArithExpr, CompositionPoly, MLEEmbeddingAdapter, MultilinearExtension}; use rand::Rng; use crate::polynomial::Error as PolynomialError; @@ -19,10 +19,10 @@ impl AddOneComposition { } } -impl CompositionPolyOS

for AddOneComposition +impl CompositionPoly

for AddOneComposition where P: PackedField, - Inner: CompositionPolyOS

, + Inner: CompositionPoly

, { fn n_vars(&self) -> usize { self.inner.n_vars() @@ -56,7 +56,7 @@ impl TestProductComposition { } } -impl

CompositionPolyOS

for TestProductComposition +impl

CompositionPoly

for TestProductComposition where P: PackedField, { @@ -122,7 +122,7 @@ where } pub fn transform_poly( - multilin: MultilinearExtension, + multilin: &MultilinearExtension, ) -> Result, PolynomialError> where F: Field, diff --git a/crates/core/src/reed_solomon/reed_solomon.rs b/crates/core/src/reed_solomon/reed_solomon.rs index 771a306f1..ade7bace2 100644 --- a/crates/core/src/reed_solomon/reed_solomon.rs +++ b/crates/core/src/reed_solomon/reed_solomon.rs @@ -15,7 +15,7 @@ use std::marker::PhantomData; use binius_field::{BinaryField, ExtensionField, PackedField, RepackedExtension}; use binius_maybe_rayon::prelude::*; use binius_ntt::{AdditiveNTT, DynamicDispatchNTT, Error, NTTOptions, ThreadingSettings}; -use binius_utils::{bail, checked_arithmetics::checked_log_2}; +use binius_utils::bail; use getset::CopyGetters; use tracing::instrument; @@ -40,7 +40,7 @@ where pub fn new( log_dimension: usize, log_inv_rate: usize, - ntt_options: NTTOptions, + ntt_options: &NTTOptions, ) -> Result { // Since we split work between log_inv_rate threads, we need to decrease the number of threads per each NTT transformation. let ntt_log_threads = ntt_options @@ -49,11 +49,11 @@ where .saturating_sub(log_inv_rate); let ntt = DynamicDispatchNTT::new( log_dimension + log_inv_rate, - NTTOptions { + &NTTOptions { thread_settings: ThreadingSettings::ExplicitThreadsCount { log_threads: ntt_log_threads, }, - ..ntt_options + precompute_twiddles: ntt_options.precompute_twiddles, }, )?; @@ -160,16 +160,11 @@ where /// /// * If the `code` buffer does not have capacity for `len() << log_batch_size` field elements. #[instrument(skip_all, level = "debug")] - pub fn encode_ext_batch_inplace( + pub fn encode_ext_batch_inplace>( &self, code: &mut [PE], log_batch_size: usize, - ) -> Result<(), Error> - where - PE: RepackedExtension

, - PE::Scalar: ExtensionField<

::Scalar>, - { - let log_degree = checked_log_2(PE::Scalar::DEGREE); - self.encode_batch_inplace(PE::cast_bases_mut(code), log_batch_size + log_degree) + ) -> Result<(), Error> { + self.encode_batch_inplace(PE::cast_bases_mut(code), log_batch_size + PE::Scalar::LOG_DEGREE) } } diff --git a/crates/core/src/ring_switch/common.rs b/crates/core/src/ring_switch/common.rs index e86876e49..fb76fe80d 100644 --- a/crates/core/src/ring_switch/common.rs +++ b/crates/core/src/ring_switch/common.rs @@ -72,7 +72,7 @@ impl<'a, F: TowerField> EvalClaimSystem<'a, F> { pub fn new( oracles: &MultilinearOracleSet, commit_meta: &'a CommitMeta, - oracle_to_commit_index: SparseIndex, + oracle_to_commit_index: &SparseIndex, eval_claims: &'a [EvalcheckMultilinearClaim], ) -> Result { // Sort evaluation claims in ascending order by number of packed variables. This must diff --git a/crates/core/src/ring_switch/eq_ind.rs b/crates/core/src/ring_switch/eq_ind.rs index daffe0ba8..3505969fd 100644 --- a/crates/core/src/ring_switch/eq_ind.rs +++ b/crates/core/src/ring_switch/eq_ind.rs @@ -1,10 +1,14 @@ // Copyright 2024-2025 Irreducible Inc. -use std::{iter, marker::PhantomData, sync::Arc}; +use std::{any::TypeId, iter, marker::PhantomData, sync::Arc}; use binius_field::{ - util::inner_product_unchecked, ExtensionField, Field, PackedExtension, PackedField, - PackedFieldIndexable, TowerField, + byte_iteration::{ + can_iterate_bytes, create_partial_sums_lookup_tables, iterate_bytes, ByteIteratorCallback, + }, + util::inner_product_unchecked, + BinaryField1b, ExtensionField, Field, PackedExtension, PackedField, PackedFieldIndexable, + TowerField, }; use binius_math::{tensor_prod_eq_ind, MultilinearExtension}; use binius_maybe_rayon::prelude::*; @@ -17,6 +21,34 @@ use crate::{ tensor_algebra::TensorAlgebra, }; +/// Information about the row-batching coefficients. +#[derive(Debug)] +pub struct RowBatchCoeffs { + coeffs: Vec, + /// This is a lookup table for the partial sums of the coefficients + /// that is used to efficiently fold with 1-bit coefficients. + partial_sums_lookup_table: Vec, +} + +impl RowBatchCoeffs { + pub fn new(coeffs: Vec) -> Self { + let partial_sums_lookup_table = if coeffs.len() >= 8 { + create_partial_sums_lookup_tables(coeffs.as_slice()) + } else { + Vec::new() + }; + + Self { + coeffs, + partial_sums_lookup_table, + } + } + + pub fn coeffs(&self) -> &[F] { + &self.coeffs + } +} + /// The multilinear function $A$ from [DP24] Section 5. /// /// The function $A$ is $\ell':= \ell - \kappa$-variate and depends on the last $\ell'$ coordinates @@ -27,7 +59,7 @@ use crate::{ pub struct RingSwitchEqInd { /// $z_{\kappa}, \ldots, z_{\ell-1}$ z_vals: Arc<[F]>, - row_batch_coeffs: Arc<[F]>, + row_batch_coeffs: Arc>, mixing_coeff: F, _marker: PhantomData, } @@ -39,10 +71,10 @@ where { pub fn new( z_vals: Arc<[F]>, - row_batch_coeffs: Arc<[F]>, + row_batch_coeffs: Arc>, mixing_coeff: F, ) -> Result { - if row_batch_coeffs.len() < F::DEGREE { + if row_batch_coeffs.coeffs.len() < F::DEGREE { bail!(Error::InvalidArgs( "RingSwitchEqInd::new expects row_batch_coeffs length greater than or equal to \ the extension degree" @@ -67,20 +99,53 @@ where P::unpack_scalars_mut(&mut evals) .par_iter_mut() .for_each(|val| { - let vert = *val; - *val = inner_product_unchecked( - self.row_batch_coeffs.iter().copied(), - ExtensionField::::iter_bases(&vert), - ); + *val = inner_product_subfield(*val, &self.row_batch_coeffs); }); Ok(MultilinearExtension::from_values(evals)?) } } +#[inline(always)] +fn inner_product_subfield(value: F, row_batch_coeffs: &RowBatchCoeffs) -> F +where + FSub: Field, + F: ExtensionField, +{ + if TypeId::of::() == TypeId::of::() && can_iterate_bytes::() { + // Special case when we are folding with 1-bit coefficients. + // Use partial sums lookup table to speed up the computation. + + struct Callback<'a, F> { + partial_sums_lookup: &'a [F], + result: F, + } + + impl ByteIteratorCallback for Callback<'_, F> { + #[inline(always)] + fn call(&mut self, iter: impl Iterator) { + for (byte_index, byte) in iter.enumerate() { + self.result += self.partial_sums_lookup[(byte_index << 8) + byte as usize]; + } + } + } + + let mut callback = Callback { + partial_sums_lookup: &row_batch_coeffs.partial_sums_lookup_table, + result: F::ZERO, + }; + iterate_bytes(std::slice::from_ref(&value), &mut callback); + + callback.result + } else { + // fall back to the general case + inner_product_unchecked(row_batch_coeffs.coeffs.iter().copied(), F::iter_bases(&value)) + } +} + impl MultivariatePoly for RingSwitchEqInd where FSub: TowerField, - F: TowerField + PackedField + ExtensionField + PackedExtension, + F: TowerField + PackedField + PackedExtension, { fn n_vars(&self) -> usize { self.z_vals.len() @@ -108,7 +173,7 @@ where }, ); - let folded_eval = tensor_eval.fold_vertical(&self.row_batch_coeffs); + let folded_eval = tensor_eval.fold_vertical(&self.row_batch_coeffs.coeffs); Ok(folded_eval) } @@ -141,7 +206,8 @@ mod tests { let row_batch_coeffs = repeat_with(|| ::random(&mut rng)) .take(1 << kappa) - .collect::>(); + .collect::>(); + let row_batch_coeffs = Arc::new(RowBatchCoeffs::new(row_batch_coeffs)); let eval_point = repeat_with(|| ::random(&mut rng)) .take(n_vars) diff --git a/crates/core/src/ring_switch/prove.rs b/crates/core/src/ring_switch/prove.rs index ab2d6aee8..5a7423aad 100644 --- a/crates/core/src/ring_switch/prove.rs +++ b/crates/core/src/ring_switch/prove.rs @@ -11,6 +11,7 @@ use tracing::instrument; use super::{ common::{EvalClaimPrefixDesc, EvalClaimSystem, PIOPSumcheckClaimDesc}, + eq_ind::RowBatchCoeffs, error::Error, tower_tensor_algebra::TowerTensorAlgebra, }; @@ -75,11 +76,12 @@ where // Sample the row-batching randomness. let row_batch_challenges = transcript.sample_vec(system.max_claim_kappa()); - let row_batch_coeffs = - Arc::from(MultilinearQuery::::expand(&row_batch_challenges).into_expansion()); + let row_batch_coeffs = Arc::new(RowBatchCoeffs::new( + MultilinearQuery::::expand(&row_batch_challenges).into_expansion(), + )); let row_batched_evals = - compute_row_batched_sumcheck_evals(scaled_tensor_elems, &row_batch_coeffs); + compute_row_batched_sumcheck_evals(scaled_tensor_elems, row_batch_coeffs.coeffs()); transcript.message().write_scalar_slice(&row_batched_evals); // Create the reduced PIOP sumcheck witnesses. @@ -217,7 +219,7 @@ where fn make_ring_switch_eq_inds( sumcheck_claim_descs: &[PIOPSumcheckClaimDesc], suffix_descs: &[EvalClaimSuffixDesc], - row_batch_coeffs: Arc<[F]>, + row_batch_coeffs: Arc>, mixing_coeffs: &[F], ) -> Result>, Error> where @@ -238,7 +240,7 @@ where fn make_ring_switch_eq_ind( suffix_desc: &EvalClaimSuffixDesc>, - row_batch_coeffs: Arc<[FExt]>, + row_batch_coeffs: Arc>>, mixing_coeff: FExt, ) -> Result, Error> where diff --git a/crates/core/src/ring_switch/tests.rs b/crates/core/src/ring_switch/tests.rs index 5546e786e..2cb0ce77d 100644 --- a/crates/core/src/ring_switch/tests.rs +++ b/crates/core/src/ring_switch/tests.rs @@ -14,7 +14,7 @@ use binius_math::{ DefaultEvaluationDomainFactory, MLEEmbeddingAdapter, MultilinearExtension, MultilinearPoly, MultilinearQuery, }; -use binius_utils::serialization::{DeserializeBytes, SerializeBytes}; +use binius_utils::{DeserializeBytes, SerializeBytes}; use groestl_crypto::Groestl256; use rand::prelude::*; @@ -208,7 +208,7 @@ fn with_test_instance_from_oracles( // Finish setting up the test case let system = - EvalClaimSystem::new(oracles, &commit_meta, oracle_to_commit_index, &eval_claims).unwrap(); + EvalClaimSystem::new(oracles, &commit_meta, &oracle_to_commit_index, &eval_claims).unwrap(); check_eval_point_consistency(&system); func(rng, system, witnesses) @@ -304,7 +304,7 @@ fn commit_prove_verify_piop( // Finish setting up the test case let system = - EvalClaimSystem::new(oracles, &commit_meta, oracle_to_commit_index, &eval_claims).unwrap(); + EvalClaimSystem::new(oracles, &commit_meta, &oracle_to_commit_index, &eval_claims).unwrap(); check_eval_point_consistency(&system); let mut proof = ProverTranscript::>::new(); diff --git a/crates/core/src/ring_switch/verify.rs b/crates/core/src/ring_switch/verify.rs index 330b0ffce..8a841b2aa 100644 --- a/crates/core/src/ring_switch/verify.rs +++ b/crates/core/src/ring_switch/verify.rs @@ -8,6 +8,7 @@ use binius_utils::checked_arithmetics::log2_ceil_usize; use bytes::Buf; use itertools::izip; +use super::eq_ind::RowBatchCoeffs; use crate::{ fiat_shamir::{CanSample, Challenger}, piop::PIOPSumcheckClaim, @@ -50,8 +51,9 @@ where // Sample the row-batching randomness. let row_batch_challenges = transcript.sample_vec(system.max_claim_kappa()); - let row_batch_coeffs = - Arc::from(MultilinearQuery::::expand(&row_batch_challenges).into_expansion()); + let row_batch_coeffs = Arc::new(RowBatchCoeffs::new( + MultilinearQuery::::expand(&row_batch_challenges).into_expansion(), + )); // For each original evaluation claim, receive the row-batched evaluation claim. let row_batched_evals = transcript @@ -66,7 +68,7 @@ where &system.eval_claim_to_prefix_desc_index, ); for (expected, tensor_elem) in iter::zip(mixed_row_batched_evals, tensor_elems) { - if tensor_elem.fold_vertical(&row_batch_coeffs) != expected { + if tensor_elem.fold_vertical(row_batch_coeffs.coeffs()) != expected { return Err(VerificationError::IncorrectRowBatchedSum.into()); } } @@ -75,7 +77,7 @@ where let ring_switch_eq_inds = make_ring_switch_eq_inds::<_, Tower>( &system.sumcheck_claim_descs, &system.suffix_descs, - row_batch_coeffs, + &row_batch_coeffs, &mixing_coeffs, )?; let sumcheck_claims = iter::zip(&system.sumcheck_claim_descs, row_batched_evals) @@ -173,7 +175,7 @@ fn accumulate_evaluations_by_prefixes( fn make_ring_switch_eq_inds( sumcheck_claim_descs: &[PIOPSumcheckClaimDesc], suffix_descs: &[EvalClaimSuffixDesc], - row_batch_coeffs: Arc<[F]>, + row_batch_coeffs: &Arc>, mixing_coeffs: &[F], ) -> Result>>, Error> where @@ -190,7 +192,7 @@ where fn make_ring_switch_eq_ind( suffix_desc: &EvalClaimSuffixDesc>, - row_batch_coeffs: Arc<[FExt]>, + row_batch_coeffs: Arc>>, mixing_coeff: FExt, ) -> Result>>, Error> where diff --git a/crates/core/src/tensor_algebra.rs b/crates/core/src/tensor_algebra.rs index a5d7ea5a7..cb071f259 100644 --- a/crates/core/src/tensor_algebra.rs +++ b/crates/core/src/tensor_algebra.rs @@ -10,7 +10,6 @@ use std::{ use binius_field::{ square_transpose, util::inner_product_unchecked, ExtensionField, Field, PackedExtension, }; -use binius_utils::checked_arithmetics::checked_log_2; /// An element of the tensor algebra defined as the tensor product of `FE` and `FE` as fields. /// @@ -64,7 +63,7 @@ where /// Returns $\kappa$, the base-2 logarithm of the extension degree. pub const fn kappa() -> usize { - checked_log_2(FE::DEGREE) + FE::LOG_DEGREE } /// Returns the byte size of an element. @@ -124,12 +123,7 @@ where } } -impl TensorAlgebra -where - F: Field, - FE: ExtensionField + PackedExtension, - FE::Scalar: ExtensionField, -{ +impl + PackedExtension> TensorAlgebra { /// Multiply by an element from the vertical subring. /// /// Internally, this performs a transpose, vertical scaling, then transpose sequence. If diff --git a/crates/core/src/transcript/error.rs b/crates/core/src/transcript/error.rs index 68a5eefd7..ee57c376f 100644 --- a/crates/core/src/transcript/error.rs +++ b/crates/core/src/transcript/error.rs @@ -1,7 +1,5 @@ // Copyright 2024-2025 Irreducible Inc. -use binius_utils::serialization::Error as SerializationError; - #[derive(Debug, thiserror::Error)] pub enum Error { #[error("Transcript is not empty, {remaining} bytes")] @@ -9,5 +7,5 @@ pub enum Error { #[error("Not enough bytes in the buffer")] NotEnoughBytes, #[error("Serialization error: {0}")] - Serialization(#[from] SerializationError), + Serialization(#[from] binius_utils::SerializationError), } diff --git a/crates/core/src/transcript/mod.rs b/crates/core/src/transcript/mod.rs index e6081b6c5..d014ea9a0 100644 --- a/crates/core/src/transcript/mod.rs +++ b/crates/core/src/transcript/mod.rs @@ -16,8 +16,8 @@ mod error; use std::{iter::repeat_with, slice}; -use binius_field::{deserialize_canonical, serialize_canonical, PackedField, TowerField}; -use binius_utils::serialization::{DeserializeBytes, SerializeBytes}; +use binius_field::{PackedField, TowerField}; +use binius_utils::{DeserializeBytes, SerializationMode, SerializeBytes}; use bytes::{buf::UninitSlice, Buf, BufMut, Bytes, BytesMut}; pub use error::Error; use tracing::warn; @@ -259,12 +259,14 @@ impl TranscriptReader<'_, B> { } pub fn read(&mut self) -> Result { - T::deserialize(self.buffer()).map_err(Into::into) + let mode = SerializationMode::CanonicalTower; + T::deserialize(self.buffer(), mode).map_err(Into::into) } pub fn read_vec(&mut self, n: usize) -> Result, Error> { + let mode = SerializationMode::CanonicalTower; let mut buffer = self.buffer(); - repeat_with(move || T::deserialize(&mut buffer).map_err(Into::into)) + repeat_with(move || T::deserialize(&mut buffer, mode).map_err(Into::into)) .take(n) .collect() } @@ -287,7 +289,8 @@ impl TranscriptReader<'_, B> { pub fn read_scalar_slice_into(&mut self, buf: &mut [F]) -> Result<(), Error> { let mut buffer = self.buffer(); for elem in buf { - *elem = deserialize_canonical(&mut buffer)?; + let mode = SerializationMode::CanonicalTower; + *elem = DeserializeBytes::deserialize(&mut buffer, mode)?; } Ok(()) } @@ -334,20 +337,27 @@ impl TranscriptWriter<'_, B> { } pub fn write(&mut self, value: &T) { - value - .serialize(self.buffer()) - .expect("TODO: propagate error") + self.proof_size_event_wrapper(|buffer| { + value + .serialize(buffer, SerializationMode::CanonicalTower) + .expect("TODO: propagate error"); + }); } pub fn write_slice(&mut self, values: &[T]) { - let mut buffer = self.buffer(); - for value in values { - value.serialize(&mut buffer).expect("TODO: propagate error") - } + self.proof_size_event_wrapper(|buffer| { + for value in values { + value + .serialize(&mut *buffer, SerializationMode::CanonicalTower) + .expect("TODO: propagate error"); + } + }); } pub fn write_bytes(&mut self, data: &[u8]) { - self.buffer().put_slice(data); + self.proof_size_event_wrapper(|buffer| { + buffer.put_slice(data); + }); } pub fn write_scalar(&mut self, f: F) { @@ -355,10 +365,12 @@ impl TranscriptWriter<'_, B> { } pub fn write_scalar_slice(&mut self, elems: &[F]) { - let mut buffer = self.buffer(); - for elem in elems { - serialize_canonical(*elem, &mut buffer).expect("TODO: propagate error"); - } + self.proof_size_event_wrapper(|buffer| { + for elem in elems { + SerializeBytes::serialize(elem, &mut *buffer, SerializationMode::CanonicalTower) + .expect("TODO: propagate error"); + } + }); } pub fn write_packed>(&mut self, packed: P) { @@ -378,6 +390,14 @@ impl TranscriptWriter<'_, B> { self.write_bytes(msg.as_bytes()) } } + + fn proof_size_event_wrapper(&mut self, f: F) { + let buffer = self.buffer(); + let start_bytes = buffer.remaining_mut(); + f(buffer); + let end_bytes = buffer.remaining_mut(); + tracing::event!(name: "proof_size", tracing::Level::INFO, counter=true, incremental=true, value=start_bytes - end_bytes); + } } impl CanSample for VerifierTranscript @@ -386,7 +406,8 @@ where Challenger_: Challenger, { fn sample(&mut self) -> F { - deserialize_canonical(self.combined.challenger.sampler()) + let mode = SerializationMode::CanonicalTower; + DeserializeBytes::deserialize(self.combined.challenger.sampler(), mode) .expect("challenger has infinite buffer") } } @@ -397,7 +418,8 @@ where Challenger_: Challenger, { fn sample(&mut self) -> F { - deserialize_canonical(self.combined.challenger.sampler()) + let mode = SerializationMode::CanonicalTower; + DeserializeBytes::deserialize(self.combined.challenger.sampler(), mode) .expect("challenger has infinite buffer") } } diff --git a/crates/core/src/transparent/constant.rs b/crates/core/src/transparent/constant.rs index 1c0108739..860ced321 100644 --- a/crates/core/src/transparent/constant.rs +++ b/crates/core/src/transparent/constant.rs @@ -1,18 +1,26 @@ // Copyright 2024-2025 Irreducible Inc. -use binius_field::{ExtensionField, TowerField}; -use binius_utils::bail; +use binius_field::{BinaryField128b, ExtensionField, TowerField}; +use binius_macros::{erased_serialize_bytes, DeserializeBytes, SerializeBytes}; +use binius_utils::{bail, DeserializeBytes}; use crate::polynomial::{Error, MultivariatePoly}; /// A constant polynomial. -#[derive(Debug, Copy, Clone)] -pub struct Constant { +#[derive(Debug, Copy, Clone, SerializeBytes, DeserializeBytes)] +pub struct Constant { n_vars: usize, value: F, tower_level: usize, } +inventory::submit! { + >::register_deserializer( + "Constant", + |buf, mode| Ok(Box::new(Constant::::deserialize(&mut *buf, mode)?)) + ) +} + impl Constant { pub fn new(n_vars: usize, value: FS) -> Self where @@ -26,6 +34,7 @@ impl Constant { } } +#[erased_serialize_bytes] impl MultivariatePoly for Constant { fn n_vars(&self) -> usize { self.n_vars diff --git a/crates/core/src/transparent/mod.rs b/crates/core/src/transparent/mod.rs index e95864fe8..a1e79368c 100644 --- a/crates/core/src/transparent/mod.rs +++ b/crates/core/src/transparent/mod.rs @@ -6,6 +6,7 @@ pub mod eq_ind; pub mod multilinear_extension; pub mod powers; pub mod select_row; +pub mod serialization; pub mod shift_ind; pub mod step_down; pub mod step_up; diff --git a/crates/core/src/transparent/multilinear_extension.rs b/crates/core/src/transparent/multilinear_extension.rs index 7df54751a..ef55e4d76 100644 --- a/crates/core/src/transparent/multilinear_extension.rs +++ b/crates/core/src/transparent/multilinear_extension.rs @@ -2,9 +2,15 @@ use std::{fmt::Debug, ops::Deref}; -use binius_field::{ExtensionField, PackedField, RepackedExtension, TowerField}; +use binius_field::{ + arch::OptimalUnderlier, as_packed_field::PackedType, packed::pack_slice, BinaryField128b, + BinaryField16b, BinaryField1b, BinaryField2b, BinaryField32b, BinaryField4b, BinaryField64b, + BinaryField8b, ExtensionField, PackedField, RepackedExtension, TowerField, +}; use binius_hal::{make_portable_backend, ComputationBackendExt}; +use binius_macros::erased_serialize_bytes; use binius_math::{MLEEmbeddingAdapter, MultilinearExtension, MultilinearPoly}; +use binius_utils::{DeserializeBytes, SerializationError, SerializationMode, SerializeBytes}; use crate::polynomial::{Error, MultivariatePoly}; @@ -26,6 +32,72 @@ where data: MLEEmbeddingAdapter, } +impl SerializeBytes for MultilinearExtensionTransparent +where + P: PackedField, + PE: RepackedExtension

, + PE::Scalar: TowerField + ExtensionField, + Data: Deref + Debug + Send + Sync, +{ + fn serialize( + &self, + write_buf: impl bytes::BufMut, + mode: SerializationMode, + ) -> Result<(), SerializationError> { + let elems = PE::iter_slice( + self.data + .packed_evals() + .expect("Evals should always be available here"), + ) + .collect::>(); + SerializeBytes::serialize(&elems, write_buf, mode) + } +} + +inventory::submit! { + >::register_deserializer( + "MultilinearExtensionTransparent", + |buf, mode| { + type U = OptimalUnderlier; + type F = BinaryField128b; + type P = PackedType; + let hypercube_evals = Vec::::deserialize(&mut *buf, mode)?; + let result: Box> = if let Some(packed_evals) = try_pack_slice(&hypercube_evals) { + Box::new(MultilinearExtensionTransparent::, P, _>::from_values(packed_evals).unwrap()) + } else if let Some(packed_evals) = try_pack_slice(&hypercube_evals) { + Box::new(MultilinearExtensionTransparent::, P, _>::from_values(packed_evals).unwrap()) + } else if let Some(packed_evals) = try_pack_slice(&hypercube_evals) { + Box::new(MultilinearExtensionTransparent::, P, _>::from_values(packed_evals).unwrap()) + } else if let Some(packed_evals) = try_pack_slice(&hypercube_evals) { + Box::new(MultilinearExtensionTransparent::, P, _>::from_values(packed_evals).unwrap()) + } else if let Some(packed_evals) = try_pack_slice(&hypercube_evals) { + Box::new(MultilinearExtensionTransparent::, P, _>::from_values(packed_evals).unwrap()) + } else if let Some(packed_evals) = try_pack_slice(&hypercube_evals) { + Box::new(MultilinearExtensionTransparent::, P, _>::from_values(packed_evals).unwrap()) + } else if let Some(packed_evals) = try_pack_slice(&hypercube_evals) { + Box::new(MultilinearExtensionTransparent::, P, _>::from_values(packed_evals).unwrap()) + } else { + Box::new(MultilinearExtensionTransparent::::from_values(pack_slice(&hypercube_evals)).unwrap()) + }; + Ok(result) + } + ) +} + +fn try_pack_slice(xs: &[F]) -> Option> +where + PS: PackedField, + F: ExtensionField, +{ + Some(pack_slice( + &xs.iter() + .copied() + .map(TryInto::try_into) + .collect::, _>>() + .ok()?, + )) +} + impl MultilinearExtensionTransparent where P: PackedField, @@ -49,6 +121,7 @@ where } } +#[erased_serialize_bytes] impl MultivariatePoly for MultilinearExtensionTransparent where F: TowerField + ExtensionField, diff --git a/crates/core/src/transparent/powers.rs b/crates/core/src/transparent/powers.rs index c0c912d9f..01a925a10 100644 --- a/crates/core/src/transparent/powers.rs +++ b/crates/core/src/transparent/powers.rs @@ -2,10 +2,11 @@ use std::iter::successors; -use binius_field::{Field, PackedField, TowerField}; +use binius_field::{BinaryField128b, PackedField, TowerField}; +use binius_macros::{erased_serialize_bytes, DeserializeBytes, SerializeBytes}; use binius_math::MultilinearExtension; use binius_maybe_rayon::prelude::*; -use binius_utils::bail; +use binius_utils::{bail, DeserializeBytes}; use bytemuck::zeroed_vec; use itertools::{izip, Itertools}; @@ -13,13 +14,20 @@ use crate::polynomial::{Error, MultivariatePoly}; /// A transparent multilinear polynomial whose evaluation at index $i$ is $g^i$ for /// some field element $g$. -#[derive(Debug)] -pub struct Powers { +#[derive(Debug, SerializeBytes, DeserializeBytes)] +pub struct Powers { n_vars: usize, base: F, } -impl Powers { +inventory::submit! { + >::register_deserializer( + "Powers", + |buf, mode| Ok(Box::new(Powers::::deserialize(&mut *buf, mode)?)) + ) +} + +impl Powers { pub const fn new(n_vars: usize, base: F) -> Self { Self { n_vars, base } } @@ -49,6 +57,7 @@ impl Powers { } } +#[erased_serialize_bytes] impl> MultivariatePoly

for Powers { fn n_vars(&self) -> usize { self.n_vars diff --git a/crates/core/src/transparent/select_row.rs b/crates/core/src/transparent/select_row.rs index fdcd32c5b..9ef3bf0e6 100644 --- a/crates/core/src/transparent/select_row.rs +++ b/crates/core/src/transparent/select_row.rs @@ -1,8 +1,9 @@ // Copyright 2024-2025 Irreducible Inc. -use binius_field::{packed::set_packed_slice, BinaryField1b, Field, PackedField}; +use binius_field::{packed::set_packed_slice, BinaryField128b, BinaryField1b, Field, PackedField}; +use binius_macros::{erased_serialize_bytes, DeserializeBytes, SerializeBytes}; use binius_math::MultilinearExtension; -use binius_utils::bail; +use binius_utils::{bail, DeserializeBytes}; use crate::polynomial::{Error, MultivariatePoly}; @@ -18,12 +19,19 @@ use crate::polynomial::{Error, MultivariatePoly}; /// ``` /// /// This is useful for defining boundary constraints -#[derive(Debug, Clone)] +#[derive(Debug, Clone, SerializeBytes, DeserializeBytes)] pub struct SelectRow { n_vars: usize, index: usize, } +inventory::submit! { + >::register_deserializer( + "SelectRow", + |buf, mode| Ok(Box::new(SelectRow::deserialize(&mut *buf, mode)?)) + ) +} + impl SelectRow { pub fn new(n_vars: usize, index: usize) -> Result { if index >= (1 << n_vars) { @@ -50,6 +58,7 @@ impl SelectRow { } } +#[erased_serialize_bytes] impl MultivariatePoly for SelectRow { fn degree(&self) -> usize { self.n_vars diff --git a/crates/core/src/transparent/serialization.rs b/crates/core/src/transparent/serialization.rs new file mode 100644 index 000000000..2d5ef14f5 --- /dev/null +++ b/crates/core/src/transparent/serialization.rs @@ -0,0 +1,82 @@ +// Copyright 2025 Irreducible Inc. + +//! The purpose of this module is to enable serialization/deserialization of generic MultivariatePoly implementations +//! +//! The simplest way to do this would be to create an enum with all the possible structs that implement MultivariatePoly +//! +//! This has a few problems, though: +//! - Third party code is not able to define custom transparent polynomials +//! - The enum would inherit, or be forced to enumerate possible type parameters of every struct variant + +use std::{collections::HashMap, sync::LazyLock}; + +use binius_field::{BinaryField128b, TowerField}; +use binius_utils::{DeserializeBytes, SerializationError, SerializationMode, SerializeBytes}; + +use crate::polynomial::MultivariatePoly; + +impl SerializeBytes for Box> { + fn serialize( + &self, + mut write_buf: impl bytes::BufMut, + mode: SerializationMode, + ) -> Result<(), SerializationError> { + self.erased_serialize(&mut write_buf, mode) + } +} + +impl DeserializeBytes for Box> { + fn deserialize( + mut read_buf: impl bytes::Buf, + mode: SerializationMode, + ) -> Result + where + Self: Sized, + { + let name = String::deserialize(&mut read_buf, mode)?; + match REGISTRY.get(name.as_str()) { + Some(Some(erased_deserialize)) => erased_deserialize(&mut read_buf, mode), + Some(None) => Err(SerializationError::DeserializerNameConflict { name }), + None => Err(SerializationError::DeserializerNotImplented), + } + } +} + +// Using the inventory crate we can collect all deserializers before the main function runs +// This allows third party code to submit their own deserializers as well +inventory::collect!(DeserializerEntry); + +static REGISTRY: LazyLock>>> = + LazyLock::new(|| { + let mut registry = HashMap::new(); + inventory::iter::> + .into_iter() + .for_each(|&DeserializerEntry { name, deserializer }| match registry.entry(name) { + std::collections::hash_map::Entry::Vacant(entry) => { + entry.insert(Some(deserializer)); + } + std::collections::hash_map::Entry::Occupied(mut entry) => { + entry.insert(None); + } + }); + registry + }); + +impl dyn MultivariatePoly { + pub const fn register_deserializer( + name: &'static str, + deserializer: ErasedDeserializeBytes, + ) -> DeserializerEntry { + DeserializerEntry { name, deserializer } + } +} + +pub struct DeserializerEntry { + name: &'static str, + deserializer: ErasedDeserializeBytes, +} + +type ErasedDeserializeBytes = fn( + &mut dyn bytes::Buf, + mode: SerializationMode, +) -> Result>, SerializationError>; diff --git a/crates/core/src/transparent/step_down.rs b/crates/core/src/transparent/step_down.rs index 8c588f738..00dec6771 100644 --- a/crates/core/src/transparent/step_down.rs +++ b/crates/core/src/transparent/step_down.rs @@ -1,8 +1,9 @@ // Copyright 2024-2025 Irreducible Inc. -use binius_field::{Field, PackedField}; +use binius_field::{BinaryField128b, Field, PackedField}; +use binius_macros::{erased_serialize_bytes, DeserializeBytes, SerializeBytes}; use binius_math::MultilinearExtension; -use binius_utils::bail; +use binius_utils::{bail, DeserializeBytes}; use crate::polynomial::{Error, MultivariatePoly}; @@ -20,12 +21,19 @@ use crate::polynomial::{Error, MultivariatePoly}; /// ``` /// /// This is useful for making constraints that are not enforced at the last rows of the trace -#[derive(Debug, Clone)] +#[derive(Debug, Clone, SerializeBytes, DeserializeBytes)] pub struct StepDown { n_vars: usize, index: usize, } +inventory::submit! { + >::register_deserializer( + "StepDown", + |buf, mode| Ok(Box::new(StepDown::deserialize(&mut *buf, mode)?)) + ) +} + impl StepDown { pub fn new(n_vars: usize, index: usize) -> Result { if index > 1 << n_vars { @@ -68,6 +76,7 @@ impl StepDown { } } +#[erased_serialize_bytes] impl MultivariatePoly for StepDown { fn degree(&self) -> usize { self.n_vars diff --git a/crates/core/src/transparent/step_up.rs b/crates/core/src/transparent/step_up.rs index 3a24b9f53..ad022df9a 100644 --- a/crates/core/src/transparent/step_up.rs +++ b/crates/core/src/transparent/step_up.rs @@ -1,8 +1,9 @@ // Copyright 2024-2025 Irreducible Inc. -use binius_field::{Field, PackedField}; +use binius_field::{BinaryField128b, Field, PackedField}; +use binius_macros::{erased_serialize_bytes, DeserializeBytes, SerializeBytes}; use binius_math::MultilinearExtension; -use binius_utils::bail; +use binius_utils::{bail, DeserializeBytes}; use crate::polynomial::{Error, MultivariatePoly}; @@ -20,12 +21,19 @@ use crate::polynomial::{Error, MultivariatePoly}; /// ``` /// /// This is useful for making constraints that are not enforced at the first rows of the trace -#[derive(Debug, Clone)] +#[derive(Debug, Clone, SerializeBytes, DeserializeBytes)] pub struct StepUp { n_vars: usize, index: usize, } +inventory::submit! { + >::register_deserializer( + "StepUp", + |buf, mode| Ok(Box::new(StepUp::deserialize(&mut *buf, mode)?)) + ) +} + impl StepUp { pub fn new(n_vars: usize, index: usize) -> Result { if index > 1 << n_vars { @@ -64,6 +72,7 @@ impl StepUp { } } +#[erased_serialize_bytes] impl MultivariatePoly for StepUp { fn degree(&self) -> usize { self.n_vars diff --git a/crates/core/src/transparent/tower_basis.rs b/crates/core/src/transparent/tower_basis.rs index 20992c467..8b8d32bf0 100644 --- a/crates/core/src/transparent/tower_basis.rs +++ b/crates/core/src/transparent/tower_basis.rs @@ -2,9 +2,10 @@ use std::marker::PhantomData; -use binius_field::{Field, PackedField, TowerField}; +use binius_field::{BinaryField128b, Field, PackedField, TowerField}; +use binius_macros::{erased_serialize_bytes, DeserializeBytes, SerializeBytes}; use binius_math::MultilinearExtension; -use binius_utils::bail; +use binius_utils::{bail, DeserializeBytes}; use crate::polynomial::{Error, MultivariatePoly}; @@ -20,13 +21,20 @@ use crate::polynomial::{Error, MultivariatePoly}; /// /// Thus, $\mathcal{T}_{\iota+k}$ has a $\mathcal{T}_{\iota}$-basis of size $2^k$: /// * $1, X_{\iota}, X_{\iota+1}, X_{\iota}X_{\iota+1}, X_{\iota+2}, \ldots, X_{\iota} X_{\iota+1} \ldots X_{\iota+k-1}$ -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Copy, Clone, SerializeBytes, DeserializeBytes)] pub struct TowerBasis { k: usize, iota: usize, _marker: PhantomData, } +inventory::submit! { + >::register_deserializer( + "TowerBasis", + |buf, mode| Ok(Box::new(TowerBasis::::deserialize(&mut *buf, mode)?)) + ) +} + impl TowerBasis { pub fn new(k: usize, iota: usize) -> Result { if iota + k > F::TOWER_LEVEL { @@ -62,6 +70,7 @@ impl TowerBasis { } } +#[erased_serialize_bytes] impl MultivariatePoly for TowerBasis where F: TowerField, diff --git a/crates/field/Cargo.toml b/crates/field/Cargo.toml index 36de13de8..11903020b 100644 --- a/crates/field/Cargo.toml +++ b/crates/field/Cargo.toml @@ -11,7 +11,6 @@ workspace = true binius_maybe_rayon = { path = "../maybe_rayon", default-features = false } binius_utils = { path = "../utils", default-features = false } bytemuck.workspace = true -bytes.workspace = true cfg-if.workspace = true derive_more.workspace = true rand.workspace = true diff --git a/crates/field/benches/packed_extension_mul.rs b/crates/field/benches/packed_extension_mul.rs index a119dc9a7..836b8474f 100644 --- a/crates/field/benches/packed_extension_mul.rs +++ b/crates/field/benches/packed_extension_mul.rs @@ -16,7 +16,6 @@ fn benchmark_packed_extension_mul( label: &str, ) where F: Field, - BinaryField128b: ExtensionField, PackedBinaryField2x128b: PackedExtension, { let mut rng = thread_rng(); diff --git a/crates/field/benches/packed_field_element_access.rs b/crates/field/benches/packed_field_element_access.rs index d8f4d0ee4..834f280ad 100644 --- a/crates/field/benches/packed_field_element_access.rs +++ b/crates/field/benches/packed_field_element_access.rs @@ -3,11 +3,15 @@ use std::array; use binius_field::{ - PackedBinaryField128x1b, PackedBinaryField16x32b, PackedBinaryField16x8b, - PackedBinaryField1x128b, PackedBinaryField256x1b, PackedBinaryField2x128b, - PackedBinaryField2x64b, PackedBinaryField32x8b, PackedBinaryField4x128b, - PackedBinaryField4x32b, PackedBinaryField4x64b, PackedBinaryField512x1b, - PackedBinaryField64x8b, PackedBinaryField8x32b, PackedBinaryField8x64b, PackedField, + ByteSlicedAES16x128b, ByteSlicedAES16x16b, ByteSlicedAES16x32b, ByteSlicedAES16x64b, + ByteSlicedAES16x8b, ByteSlicedAES32x128b, ByteSlicedAES32x16b, ByteSlicedAES32x32b, + ByteSlicedAES32x64b, ByteSlicedAES32x8b, ByteSlicedAES64x128b, ByteSlicedAES64x16b, + ByteSlicedAES64x32b, ByteSlicedAES64x64b, ByteSlicedAES64x8b, PackedBinaryField128x1b, + PackedBinaryField16x32b, PackedBinaryField16x8b, PackedBinaryField1x128b, + PackedBinaryField256x1b, PackedBinaryField2x128b, PackedBinaryField2x64b, + PackedBinaryField32x8b, PackedBinaryField4x128b, PackedBinaryField4x32b, + PackedBinaryField4x64b, PackedBinaryField512x1b, PackedBinaryField64x8b, + PackedBinaryField8x32b, PackedBinaryField8x64b, PackedField, }; use criterion::{ criterion_group, criterion_main, measurement::WallTime, BenchmarkGroup, Criterion, Throughput, @@ -86,5 +90,43 @@ fn packed_512(c: &mut Criterion) { benchmark_get_set!(PackedBinaryField4x128b, group); } -criterion_group!(get_set, packed_128, packed_256, packed_512); +fn byte_sliced_128(c: &mut Criterion) { + let mut group = c.benchmark_group("bytes_sliced_128"); + + benchmark_get_set!(ByteSlicedAES16x8b, group); + benchmark_get_set!(ByteSlicedAES16x16b, group); + benchmark_get_set!(ByteSlicedAES16x32b, group); + benchmark_get_set!(ByteSlicedAES16x64b, group); + benchmark_get_set!(ByteSlicedAES16x128b, group); +} + +fn byte_sliced_256(c: &mut Criterion) { + let mut group = c.benchmark_group("bytes_sliced_256"); + + benchmark_get_set!(ByteSlicedAES32x8b, group); + benchmark_get_set!(ByteSlicedAES32x16b, group); + benchmark_get_set!(ByteSlicedAES32x32b, group); + benchmark_get_set!(ByteSlicedAES32x64b, group); + benchmark_get_set!(ByteSlicedAES32x128b, group); +} + +fn byte_sliced_512(c: &mut Criterion) { + let mut group = c.benchmark_group("bytes_sliced_512"); + + benchmark_get_set!(ByteSlicedAES64x8b, group); + benchmark_get_set!(ByteSlicedAES64x16b, group); + benchmark_get_set!(ByteSlicedAES64x32b, group); + benchmark_get_set!(ByteSlicedAES64x64b, group); + benchmark_get_set!(ByteSlicedAES64x128b, group); +} + +criterion_group!( + get_set, + packed_128, + packed_256, + packed_512, + byte_sliced_128, + byte_sliced_256, + byte_sliced_512 +); criterion_main!(get_set); diff --git a/crates/field/benches/packed_field_init.rs b/crates/field/benches/packed_field_init.rs index b2a3feb54..43c644e93 100644 --- a/crates/field/benches/packed_field_init.rs +++ b/crates/field/benches/packed_field_init.rs @@ -3,11 +3,15 @@ use std::array; use binius_field::{ - PackedBinaryField128x1b, PackedBinaryField16x32b, PackedBinaryField16x8b, - PackedBinaryField1x128b, PackedBinaryField256x1b, PackedBinaryField2x128b, - PackedBinaryField2x64b, PackedBinaryField32x8b, PackedBinaryField4x128b, - PackedBinaryField4x32b, PackedBinaryField4x64b, PackedBinaryField512x1b, - PackedBinaryField64x8b, PackedBinaryField8x32b, PackedBinaryField8x64b, PackedField, + ByteSlicedAES16x128b, ByteSlicedAES16x16b, ByteSlicedAES16x32b, ByteSlicedAES16x64b, + ByteSlicedAES16x8b, ByteSlicedAES32x128b, ByteSlicedAES32x16b, ByteSlicedAES32x32b, + ByteSlicedAES32x64b, ByteSlicedAES32x8b, ByteSlicedAES64x128b, ByteSlicedAES64x16b, + ByteSlicedAES64x32b, ByteSlicedAES64x64b, ByteSlicedAES64x8b, PackedBinaryField128x1b, + PackedBinaryField16x32b, PackedBinaryField16x8b, PackedBinaryField1x128b, + PackedBinaryField256x1b, PackedBinaryField2x128b, PackedBinaryField2x64b, + PackedBinaryField32x8b, PackedBinaryField4x128b, PackedBinaryField4x32b, + PackedBinaryField4x64b, PackedBinaryField512x1b, PackedBinaryField64x8b, + PackedBinaryField8x32b, PackedBinaryField8x64b, PackedField, }; use criterion::{ criterion_group, criterion_main, measurement::WallTime, BenchmarkGroup, Criterion, Throughput, @@ -71,5 +75,43 @@ fn packed_512(c: &mut Criterion) { benchmark_from_fn!(PackedBinaryField4x128b, group); } -criterion_group!(initialization, packed_128, packed_256, packed_512); +fn byte_sliced_128(c: &mut Criterion) { + let mut group = c.benchmark_group("bytes_sliced_128"); + + benchmark_from_fn!(ByteSlicedAES16x8b, group); + benchmark_from_fn!(ByteSlicedAES16x16b, group); + benchmark_from_fn!(ByteSlicedAES16x32b, group); + benchmark_from_fn!(ByteSlicedAES16x64b, group); + benchmark_from_fn!(ByteSlicedAES16x128b, group); +} + +fn byte_sliced_256(c: &mut Criterion) { + let mut group = c.benchmark_group("bytes_sliced_256"); + + benchmark_from_fn!(ByteSlicedAES32x8b, group); + benchmark_from_fn!(ByteSlicedAES32x16b, group); + benchmark_from_fn!(ByteSlicedAES32x32b, group); + benchmark_from_fn!(ByteSlicedAES32x64b, group); + benchmark_from_fn!(ByteSlicedAES32x128b, group); +} + +fn byte_sliced_512(c: &mut Criterion) { + let mut group = c.benchmark_group("bytes_sliced_512"); + + benchmark_from_fn!(ByteSlicedAES64x8b, group); + benchmark_from_fn!(ByteSlicedAES64x16b, group); + benchmark_from_fn!(ByteSlicedAES64x32b, group); + benchmark_from_fn!(ByteSlicedAES64x64b, group); + benchmark_from_fn!(ByteSlicedAES64x128b, group); +} + +criterion_group!( + initialization, + packed_128, + packed_256, + packed_512, + byte_sliced_128, + byte_sliced_256, + byte_sliced_512 +); criterion_main!(initialization); diff --git a/crates/field/benches/packed_field_subfield_ops.rs b/crates/field/benches/packed_field_subfield_ops.rs index cfd2dff39..bc2a36539 100644 --- a/crates/field/benches/packed_field_subfield_ops.rs +++ b/crates/field/benches/packed_field_subfield_ops.rs @@ -5,8 +5,8 @@ use std::array; use binius_field::{ packed::mul_by_subfield_scalar, underlier::{UnderlierType, WithUnderlier}, - BinaryField1b, BinaryField32b, BinaryField4b, BinaryField64b, BinaryField8b, ExtensionField, - Field, PackedBinaryField16x8b, PackedBinaryField1x128b, PackedBinaryField2x128b, + BinaryField1b, BinaryField32b, BinaryField4b, BinaryField64b, BinaryField8b, Field, + PackedBinaryField16x8b, PackedBinaryField1x128b, PackedBinaryField2x128b, PackedBinaryField32x8b, PackedBinaryField4x128b, PackedBinaryField4x32b, PackedBinaryField64x8b, PackedBinaryField8x32b, PackedBinaryField8x64b, PackedExtension, }; @@ -17,11 +17,7 @@ use rand::thread_rng; const BATCH_SIZE: usize = 32; -fn bench_mul_subfield(group: &mut BenchmarkGroup<'_, WallTime>) -where - PE: PackedExtension>, - F: Field, -{ +fn bench_mul_subfield, F: Field>(group: &mut BenchmarkGroup<'_, WallTime>) { let mut rng = thread_rng(); let packed: [PE; BATCH_SIZE] = array::from_fn(|_| PE::random(&mut rng)); let scalars: [F; BATCH_SIZE] = array::from_fn(|_| F::random(&mut rng)); diff --git a/crates/field/benches/packed_field_utils.rs b/crates/field/benches/packed_field_utils.rs index 516af5e6e..622a49afa 100644 --- a/crates/field/benches/packed_field_utils.rs +++ b/crates/field/benches/packed_field_utils.rs @@ -274,11 +274,23 @@ macro_rules! benchmark_packed_operation { PackedBinaryPolyval4x128b // Byte sliced AES fields + ByteSlicedAES16x8b + ByteSlicedAES16x16b + ByteSlicedAES16x32b + ByteSlicedAES16x64b + ByteSlicedAES16x128b + ByteSlicedAES32x8b ByteSlicedAES32x16b ByteSlicedAES32x32b ByteSlicedAES32x64b ByteSlicedAES32x128b + + ByteSlicedAES64x8b + ByteSlicedAES64x16b + ByteSlicedAES64x32b + ByteSlicedAES64x64b + ByteSlicedAES64x128b ]); }; } diff --git a/crates/field/src/aes_field.rs b/crates/field/src/aes_field.rs index c69b665ab..74f7e76df 100644 --- a/crates/field/src/aes_field.rs +++ b/crates/field/src/aes_field.rs @@ -2,13 +2,16 @@ use std::{ any::TypeId, - array, fmt::{Debug, Display, Formatter}, iter::{Product, Sum}, marker::PhantomData, ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}, }; +use binius_utils::{ + bytes::{Buf, BufMut}, + DeserializeBytes, SerializationError, SerializationMode, SerializeBytes, +}; use bytemuck::{Pod, Zeroable}; use rand::RngCore; use subtle::{Choice, ConditionallySelectable, ConstantTimeEq}; @@ -79,6 +82,13 @@ impl_arithmetic_using_packed!(AESTowerField128b); impl TowerField for AESTowerField8b { type Canonical = BinaryField8b; + fn min_tower_level(self) -> usize { + match self { + Self::ZERO | Self::ONE => 0, + _ => 3, + } + } + fn mul_primitive(self, iota: usize) -> Result { match iota { 0..=1 => Ok(self * ISOMORPHIC_ALPHAS[iota]), @@ -158,8 +168,8 @@ impl Transformation for SubfieldTransformer>, - OEP: PackedExtension>, + IEP: PackedExtension, + OEP: PackedExtension, T: Transformation, PackedSubfield>, { fn transform(&self, input: &IEP) -> OEP { @@ -172,9 +182,7 @@ where pub fn make_aes_to_binary_packed_transformer() -> impl Transformation where IP: PackedExtension, - IP::Scalar: ExtensionField, OP: PackedExtension, - OP::Scalar: ExtensionField, PackedSubfield: PackedTransformationFactory>, { @@ -191,9 +199,7 @@ where pub fn make_binary_to_aes_packed_transformer() -> impl Transformation where IP: PackedExtension, - IP::Scalar: ExtensionField, OP: PackedExtension, - OP::Scalar: ExtensionField, PackedSubfield: PackedTransformationFactory>, { @@ -281,18 +287,61 @@ impl_tower_field_conversion!(AESTowerField32b, BinaryField32b); impl_tower_field_conversion!(AESTowerField64b, BinaryField64b); impl_tower_field_conversion!(AESTowerField128b, BinaryField128b); +macro_rules! serialize_deserialize_non_canonical { + ($field:ident, canonical=$canonical:ident) => { + impl SerializeBytes for $field { + fn serialize( + &self, + write_buf: impl BufMut, + mode: SerializationMode, + ) -> Result<(), SerializationError> { + match mode { + SerializationMode::Native => self.0.serialize(write_buf, mode), + SerializationMode::CanonicalTower => { + $canonical::from(*self).serialize(write_buf, mode) + } + } + } + } + + impl DeserializeBytes for $field { + fn deserialize( + read_buf: impl Buf, + mode: SerializationMode, + ) -> Result + where + Self: Sized, + { + match mode { + SerializationMode::Native => { + Ok(Self(DeserializeBytes::deserialize(read_buf, mode)?)) + } + SerializationMode::CanonicalTower => { + Ok(Self::from($canonical::deserialize(read_buf, mode)?)) + } + } + } + } + }; +} + +serialize_deserialize_non_canonical!(AESTowerField8b, canonical = BinaryField8b); +serialize_deserialize_non_canonical!(AESTowerField16b, canonical = BinaryField16b); +serialize_deserialize_non_canonical!(AESTowerField32b, canonical = BinaryField32b); +serialize_deserialize_non_canonical!(AESTowerField64b, canonical = BinaryField64b); +serialize_deserialize_non_canonical!(AESTowerField128b, canonical = BinaryField128b); + #[cfg(test)] mod tests { - use bytes::BytesMut; + use binius_utils::{bytes::BytesMut, SerializationMode, SerializeBytes}; use proptest::{arbitrary::any, proptest}; use rand::thread_rng; use super::*; use crate::{ - binary_field::tests::is_binary_field_valid_generator, deserialize_canonical, - serialize_canonical, underlier::WithUnderlier, PackedAESBinaryField16x32b, - PackedAESBinaryField4x32b, PackedAESBinaryField8x32b, PackedBinaryField16x32b, - PackedBinaryField4x32b, PackedBinaryField8x32b, + binary_field::tests::is_binary_field_valid_generator, underlier::WithUnderlier, + PackedAESBinaryField16x32b, PackedAESBinaryField4x32b, PackedAESBinaryField8x32b, + PackedBinaryField16x32b, PackedBinaryField4x32b, PackedBinaryField8x32b, }; fn check_square(f: impl Field) { @@ -591,28 +640,24 @@ mod tests { let aes64 = ::random(&mut rng); let aes128 = ::random(&mut rng); - serialize_canonical(aes8, &mut buffer).unwrap(); - serialize_canonical(aes16, &mut buffer).unwrap(); - serialize_canonical(aes32, &mut buffer).unwrap(); - serialize_canonical(aes64, &mut buffer).unwrap(); - serialize_canonical(aes128, &mut buffer).unwrap(); + let mode = SerializationMode::CanonicalTower; + + SerializeBytes::serialize(&aes8, &mut buffer, mode).unwrap(); + SerializeBytes::serialize(&aes16, &mut buffer, mode).unwrap(); + SerializeBytes::serialize(&aes32, &mut buffer, mode).unwrap(); + SerializeBytes::serialize(&aes64, &mut buffer, mode).unwrap(); + SerializeBytes::serialize(&aes128, &mut buffer, mode).unwrap(); - serialize_canonical(aes128, &mut buffer).unwrap(); + SerializeBytes::serialize(&aes128, &mut buffer, mode).unwrap(); let mut read_buffer = buffer.freeze(); - assert_eq!(deserialize_canonical::(&mut read_buffer).unwrap(), aes8); - assert_eq!(deserialize_canonical::(&mut read_buffer).unwrap(), aes16); - assert_eq!(deserialize_canonical::(&mut read_buffer).unwrap(), aes32); - assert_eq!(deserialize_canonical::(&mut read_buffer).unwrap(), aes64); - assert_eq!( - deserialize_canonical::(&mut read_buffer).unwrap(), - aes128 - ); - - assert_eq!( - deserialize_canonical::(&mut read_buffer).unwrap(), - aes128.into() - ) + assert_eq!(AESTowerField8b::deserialize(&mut read_buffer, mode).unwrap(), aes8); + assert_eq!(AESTowerField16b::deserialize(&mut read_buffer, mode).unwrap(), aes16); + assert_eq!(AESTowerField32b::deserialize(&mut read_buffer, mode).unwrap(), aes32); + assert_eq!(AESTowerField64b::deserialize(&mut read_buffer, mode).unwrap(), aes64); + assert_eq!(AESTowerField128b::deserialize(&mut read_buffer, mode).unwrap(), aes128); + + assert_eq!(BinaryField128b::deserialize(&mut read_buffer, mode).unwrap(), aes128.into()) } } diff --git a/crates/field/src/arch/aarch64/m128.rs b/crates/field/src/arch/aarch64/m128.rs index a32a82b7d..70c496606 100644 --- a/crates/field/src/arch/aarch64/m128.rs +++ b/crates/field/src/arch/aarch64/m128.rs @@ -19,8 +19,8 @@ use crate::{ arch::binary_utils::{as_array_mut, as_array_ref}, arithmetic_traits::Broadcast, underlier::{ - impl_divisible, impl_iteration, NumCast, Random, SmallU, UnderlierType, - UnderlierWithBitOps, WithUnderlier, U1, U2, U4, + impl_divisible, impl_iteration, unpack_lo_128b_fallback, NumCast, Random, SmallU, + UnderlierType, UnderlierWithBitOps, WithUnderlier, U1, U2, U4, }, BinaryField, }; @@ -337,6 +337,40 @@ impl UnderlierWithBitOps for M128 { _ => panic!("unsupported bit count"), } } + + #[inline(always)] + fn shl_128b_lanes(self, rhs: usize) -> Self { + Self(self.0 << rhs) + } + + #[inline(always)] + fn shr_128b_lanes(self, rhs: usize) -> Self { + Self(self.0 >> rhs) + } + + #[inline(always)] + fn unpack_lo_128b_lanes(self, rhs: Self, log_block_len: usize) -> Self { + match log_block_len { + 0..3 => unpack_lo_128b_fallback(self, rhs, log_block_len), + 3 => unsafe { vzip1q_u8(self.into(), rhs.into()).into() }, + 4 => unsafe { vzip1q_u16(self.into(), rhs.into()).into() }, + 5 => unsafe { vzip1q_u32(self.into(), rhs.into()).into() }, + 6 => unsafe { vzip1q_u64(self.into(), rhs.into()).into() }, + _ => panic!("Unsupported block length"), + } + } + + #[inline(always)] + fn unpack_hi_128b_lanes(self, rhs: Self, log_block_len: usize) -> Self { + match log_block_len { + 0..3 => unpack_lo_128b_fallback(self, rhs, log_block_len), + 3 => unsafe { vzip2q_u8(self.into(), rhs.into()).into() }, + 4 => unsafe { vzip2q_u16(self.into(), rhs.into()).into() }, + 5 => unsafe { vzip2q_u32(self.into(), rhs.into()).into() }, + 6 => unsafe { vzip2q_u64(self.into(), rhs.into()).into() }, + _ => panic!("Unsupported block length"), + } + } } impl UnderlierWithBitConstants for M128 { @@ -401,6 +435,37 @@ impl UnderlierWithBitConstants for M128 { } } } + + #[inline] + fn transpose(self, other: Self, log_block_len: usize) -> (Self, Self) { + unsafe { + match log_block_len { + 0..=3 => { + let (a, b) = (self.into(), other.into()); + let (mut a, mut b) = (Self::from(vuzp1q_u8(a, b)), Self::from(vuzp2q_u8(a, b))); + + for log_block_len in (log_block_len..3).rev() { + (a, b) = a.interleave(b, log_block_len); + } + + (a, b) + } + 4 => { + let (a, b) = (self.into(), other.into()); + (vuzp1q_u16(a, b).into(), vuzp2q_u16(a, b).into()) + } + 5 => { + let (a, b) = (self.into(), other.into()); + (vuzp1q_u32(a, b).into(), vuzp2q_u32(a, b).into()) + } + 6 => { + let (a, b) = (self.into(), other.into()); + (vuzp1q_u64(a, b).into(), vuzp2q_u64(a, b).into()) + } + _ => panic!("Unsupported block length"), + } + } + } } impl From for PackedPrimitiveType { diff --git a/crates/field/src/arch/portable/byte_sliced/invert.rs b/crates/field/src/arch/portable/byte_sliced/invert.rs index 3544cbbdf..8581e5669 100644 --- a/crates/field/src/arch/portable/byte_sliced/invert.rs +++ b/crates/field/src/arch/portable/byte_sliced/invert.rs @@ -6,25 +6,24 @@ use super::{ use crate::{ tower_levels::{TowerLevel, TowerLevelWithArithOps}, underlier::WithUnderlier, - AESTowerField8b, PackedAESBinaryField32x8b, PackedField, + AESTowerField8b, PackedField, }; #[inline(always)] -pub fn invert_or_zero>( - field_element: &Level::Data, - destination: &mut Level::Data, +pub fn invert_or_zero, Level: TowerLevel>( + field_element: &Level::Data

, + destination: &mut Level::Data

, ) { - let base_alpha = - PackedAESBinaryField32x8b::from_scalars([AESTowerField8b::from_underlier(0xd3); 32]); + let base_alpha = P::broadcast(AESTowerField8b::from_underlier(0xd3)); - inv_main::(field_element, destination, base_alpha); + inv_main::(field_element, destination, base_alpha); } #[inline(always)] -fn inv_main>( - field_element: &Level::Data, - destination: &mut Level::Data, - base_alpha: PackedAESBinaryField32x8b, +fn inv_main, Level: TowerLevel>( + field_element: &Level::Data

, + destination: &mut Level::Data

, + base_alpha: P, ) { if Level::WIDTH == 1 { destination.as_mut()[0] = field_element.as_ref()[0].invert_or_zero(); @@ -35,36 +34,30 @@ fn inv_main>( let (result0, result1) = Level::split_mut(destination); - let mut intermediate = <>::Base as TowerLevel< - PackedAESBinaryField32x8b, - >>::default(); + let mut intermediate = <::Base as TowerLevel>::default(); // intermediate = subfield_alpha*a1 - mul_alpha::(a1, &mut intermediate, base_alpha); + mul_alpha::(a1, &mut intermediate, base_alpha); // intermediate = a0 + subfield_alpha*a1 Level::Base::add_into(a0, &mut intermediate); - let mut delta = <>::Base as TowerLevel< - PackedAESBinaryField32x8b, - >>::default(); + let mut delta = <::Base as TowerLevel>::default(); // delta = intermediate * a0 - mul_main::(&intermediate, a0, &mut delta, base_alpha); + mul_main::(&intermediate, a0, &mut delta, base_alpha); // delta = intermediate * a0 + a1^2 - square_main::(a1, &mut delta, base_alpha); + square_main::(a1, &mut delta, base_alpha); - let mut delta_inv = <>::Base as TowerLevel< - PackedAESBinaryField32x8b, - >>::default(); + let mut delta_inv = <::Base as TowerLevel>::default(); // delta_inv = 1/delta - inv_main::(&delta, &mut delta_inv, base_alpha); + inv_main::(&delta, &mut delta_inv, base_alpha); // result0 = delta_inv*intermediate - mul_main::(&delta_inv, &intermediate, result0, base_alpha); + mul_main::(&delta_inv, &intermediate, result0, base_alpha); // result1 = delta_inv*intermediate - mul_main::(&delta_inv, a1, result1, base_alpha); + mul_main::(&delta_inv, a1, result1, base_alpha); } diff --git a/crates/field/src/arch/portable/byte_sliced/mod.rs b/crates/field/src/arch/portable/byte_sliced/mod.rs index 0c42539b9..0f29290a6 100644 --- a/crates/field/src/arch/portable/byte_sliced/mod.rs +++ b/crates/field/src/arch/portable/byte_sliced/mod.rs @@ -15,8 +15,8 @@ pub mod tests { use proptest::prelude::*; use crate::{$scalar_type, underlier::WithUnderlier, packed::PackedField, arch::byte_sliced::$name}; - fn scalar_array_strategy() -> impl Strategy { - any::<[<$scalar_type as WithUnderlier>::Underlier; 32]>().prop_map(|arr| arr.map(<$scalar_type>::from_underlier)) + fn scalar_array_strategy() -> impl Strategy::WIDTH]> { + any::<[<$scalar_type as WithUnderlier>::Underlier; <$name>::WIDTH]>().prop_map(|arr| arr.map(<$scalar_type>::from_underlier)) } proptest! { @@ -27,7 +27,7 @@ pub mod tests { let bytesliced_result = bytesliced_a + bytesliced_b; - for i in 0..32 { + for i in 0..<$name>::WIDTH { assert_eq!(scalar_elems_a[i] + scalar_elems_b[i], bytesliced_result.get(i)); } } @@ -39,7 +39,7 @@ pub mod tests { bytesliced_a += bytesliced_b; - for i in 0..32 { + for i in 0..<$name>::WIDTH { assert_eq!(scalar_elems_a[i] + scalar_elems_b[i], bytesliced_a.get(i)); } } @@ -51,7 +51,7 @@ pub mod tests { let bytesliced_result = bytesliced_a - bytesliced_b; - for i in 0..32 { + for i in 0..<$name>::WIDTH { assert_eq!(scalar_elems_a[i] - scalar_elems_b[i], bytesliced_result.get(i)); } } @@ -63,7 +63,7 @@ pub mod tests { bytesliced_a -= bytesliced_b; - for i in 0..32 { + for i in 0..<$name>::WIDTH { assert_eq!(scalar_elems_a[i] - scalar_elems_b[i], bytesliced_a.get(i)); } } @@ -75,7 +75,7 @@ pub mod tests { let bytesliced_result = bytesliced_a * bytesliced_b; - for i in 0..32 { + for i in 0..<$name>::WIDTH { assert_eq!(scalar_elems_a[i] * scalar_elems_b[i], bytesliced_result.get(i)); } } @@ -87,7 +87,7 @@ pub mod tests { bytesliced_a *= bytesliced_b; - for i in 0..32 { + for i in 0..<$name>::WIDTH { assert_eq!(scalar_elems_a[i] * scalar_elems_b[i], bytesliced_a.get(i)); } } @@ -118,9 +118,21 @@ pub mod tests { }; } + define_byte_sliced_test!(tests_16x128, ByteSlicedAES16x128b, AESTowerField128b); + define_byte_sliced_test!(tests_16x64, ByteSlicedAES16x64b, AESTowerField64b); + define_byte_sliced_test!(tests_16x32, ByteSlicedAES16x32b, AESTowerField32b); + define_byte_sliced_test!(tests_16x16, ByteSlicedAES16x16b, AESTowerField16b); + define_byte_sliced_test!(tests_16x8, ByteSlicedAES16x8b, AESTowerField8b); + define_byte_sliced_test!(tests_32x128, ByteSlicedAES32x128b, AESTowerField128b); define_byte_sliced_test!(tests_32x64, ByteSlicedAES32x64b, AESTowerField64b); define_byte_sliced_test!(tests_32x32, ByteSlicedAES32x32b, AESTowerField32b); define_byte_sliced_test!(tests_32x16, ByteSlicedAES32x16b, AESTowerField16b); define_byte_sliced_test!(tests_32x8, ByteSlicedAES32x8b, AESTowerField8b); + + define_byte_sliced_test!(tests_64x128, ByteSlicedAES64x128b, AESTowerField128b); + define_byte_sliced_test!(tests_64x64, ByteSlicedAES64x64b, AESTowerField64b); + define_byte_sliced_test!(tests_64x32, ByteSlicedAES64x32b, AESTowerField32b); + define_byte_sliced_test!(tests_64x16, ByteSlicedAES64x16b, AESTowerField16b); + define_byte_sliced_test!(tests_64x8, ByteSlicedAES64x8b, AESTowerField8b); } diff --git a/crates/field/src/arch/portable/byte_sliced/multiply.rs b/crates/field/src/arch/portable/byte_sliced/multiply.rs index cfaf7339b..c2037703a 100644 --- a/crates/field/src/arch/portable/byte_sliced/multiply.rs +++ b/crates/field/src/arch/portable/byte_sliced/multiply.rs @@ -2,25 +2,28 @@ use crate::{ tower_levels::{TowerLevel, TowerLevelWithArithOps}, underlier::WithUnderlier, - AESTowerField8b, PackedAESBinaryField32x8b, PackedField, + AESTowerField8b, PackedField, }; #[inline(always)] -pub fn mul>( - field_element_a: &Level::Data, - field_element_b: &Level::Data, - destination: &mut Level::Data, +pub fn mul, Level: TowerLevel>( + field_element_a: &Level::Data

, + field_element_b: &Level::Data

, + destination: &mut Level::Data

, ) { - let base_alpha = - PackedAESBinaryField32x8b::from_scalars([AESTowerField8b::from_underlier(0xd3); 32]); - mul_main::(field_element_a, field_element_b, destination, base_alpha); + let base_alpha = P::broadcast(AESTowerField8b::from_underlier(0xd3)); + mul_main::(field_element_a, field_element_b, destination, base_alpha); } #[inline(always)] -pub fn mul_alpha>( - field_element: &Level::Data, - destination: &mut Level::Data, - base_alpha: PackedAESBinaryField32x8b, +pub fn mul_alpha< + const WRITING_TO_ZEROS: bool, + P: PackedField, + Level: TowerLevel, +>( + field_element: &Level::Data

, + destination: &mut Level::Data

, + base_alpha: P, ) { if Level::WIDTH == 1 { if WRITING_TO_ZEROS { @@ -49,15 +52,19 @@ pub fn mul_alpha(a1, result1, base_alpha); + mul_alpha::(a1, result1, base_alpha); } #[inline(always)] -pub fn mul_main>( - field_element_a: &Level::Data, - field_element_b: &Level::Data, - destination: &mut Level::Data, - base_alpha: PackedAESBinaryField32x8b, +pub fn mul_main< + const WRITING_TO_ZEROS: bool, + P: PackedField, + Level: TowerLevel, +>( + field_element_a: &Level::Data

, + field_element_b: &Level::Data

, + destination: &mut Level::Data

, + base_alpha: P, ) { if Level::WIDTH == 1 { if WRITING_TO_ZEROS { @@ -78,21 +85,19 @@ pub fn mul_main>::Base as TowerLevel< - PackedAESBinaryField32x8b, - >>::default(); + let mut z2_z0 = <::Base as TowerLevel>::default(); // z2_z0 = z2 - mul_main::(a1, b1, &mut z2_z0, base_alpha); + mul_main::(a1, b1, &mut z2_z0, base_alpha); // result1 = z2 * alpha - mul_alpha::(&z2_z0, result1, base_alpha); + mul_alpha::(&z2_z0, result1, base_alpha); // z2_z0 = z2 + z0 - mul_main::(a0, b0, &mut z2_z0, base_alpha); + mul_main::(a0, b0, &mut z2_z0, base_alpha); // result1 = z1 + z2 * alpha - mul_main::(&xored_halves_a, &xored_halves_b, result1, base_alpha); + mul_main::(&xored_halves_a, &xored_halves_b, result1, base_alpha); // result1 = z2+ z0+ z1 + z2 * alpha Level::Base::add_into(&z2_z0, result1); diff --git a/crates/field/src/arch/portable/byte_sliced/packed_byte_sliced.rs b/crates/field/src/arch/portable/byte_sliced/packed_byte_sliced.rs index 7746795d1..40e9648aa 100644 --- a/crates/field/src/arch/portable/byte_sliced/packed_byte_sliced.rs +++ b/crates/field/src/arch/portable/byte_sliced/packed_byte_sliced.rs @@ -7,7 +7,7 @@ use std::{ ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign}, }; -use bytemuck::Zeroable; +use bytemuck::{Pod, Zeroable}; use super::{invert::invert_or_zero, multiply::mul, square::square}; use crate::{ @@ -15,7 +15,7 @@ use crate::{ tower_levels::*, underlier::{UnderlierWithBitOps, WithUnderlier}, AESTowerField128b, AESTowerField16b, AESTowerField32b, AESTowerField64b, AESTowerField8b, - PackedField, + PackedAESBinaryField16x8b, PackedAESBinaryField64x8b, PackedField, }; /// Represents 32 AES Tower Field elements in byte-sliced form backed by Packed 32x8b AES fields. @@ -24,16 +24,15 @@ use crate::{ /// multiplication circuit on GFNI machines, since multiplication of two 32x8b field elements is /// handled in one instruction. macro_rules! define_byte_sliced { - ($name:ident, $scalar_type:ty, $tower_level: ty) => { - #[derive(Default, Clone, Debug, Copy, PartialEq, Eq, Zeroable)] + ($name:ident, $scalar_type:ty, $packed_storage:ty, $tower_level: ty) => { + #[derive(Default, Clone, Debug, Copy, PartialEq, Eq, Pod, Zeroable)] + #[repr(transparent)] pub struct $name { - pub(super) data: [PackedAESBinaryField32x8b; - <$tower_level as TowerLevel>::WIDTH], + pub(super) data: [$packed_storage; <$tower_level as TowerLevel>::WIDTH], } impl $name { - pub const BYTES: usize = PackedAESBinaryField32x8b::WIDTH - * <$tower_level as TowerLevel>::WIDTH; + pub const BYTES: usize = <$packed_storage>::WIDTH * <$tower_level as TowerLevel>::WIDTH; /// Get the byte at the given index. /// @@ -41,11 +40,8 @@ macro_rules! define_byte_sliced { /// The caller must ensure that `byte_index` is less than `BYTES`. #[allow(clippy::modulo_one)] pub unsafe fn get_byte_unchecked(&self, byte_index: usize) -> u8 { - self.data - [byte_index % <$tower_level as TowerLevel>::WIDTH] - .get( - byte_index / <$tower_level as TowerLevel>::WIDTH, - ) + self.data[byte_index % <$tower_level as TowerLevel>::WIDTH] + .get(byte_index / <$tower_level as TowerLevel>::WIDTH) .to_underlier() } } @@ -53,28 +49,26 @@ macro_rules! define_byte_sliced { impl PackedField for $name { type Scalar = $scalar_type; - const LOG_WIDTH: usize = 5; + const LOG_WIDTH: usize = <$packed_storage>::LOG_WIDTH; + #[inline(always)] unsafe fn get_unchecked(&self, i: usize) -> Self::Scalar { - let mut result_underlier = 0; - for (byte_index, val) in self.data.iter().enumerate() { - // Safety: - // - `byte_index` is less than 16 - // - `i` must be less than 32 due to safety conditions of this method - unsafe { - result_underlier - .set_subvalue(byte_index, val.get_unchecked(i).to_underlier()) - } - } + let result_underlier = + ::Underlier::from_fn(|byte_index| unsafe { + self.data + .get_unchecked(byte_index) + .get_unchecked(i) + .to_underlier() + }); Self::Scalar::from_underlier(result_underlier) } + #[inline(always)] unsafe fn set_unchecked(&mut self, i: usize, scalar: Self::Scalar) { let underlier = scalar.to_underlier(); - for byte_index in 0..<$tower_level as TowerLevel>::WIDTH - { + for byte_index in 0..<$tower_level as TowerLevel>::WIDTH { self.data[byte_index].set_unchecked( i, AESTowerField8b::from_underlier(underlier.get_subvalue(byte_index)), @@ -86,16 +80,18 @@ macro_rules! define_byte_sliced { Self::from_scalars([Self::Scalar::random(rng); 32]) } + #[inline] fn broadcast(scalar: Self::Scalar) -> Self { Self { data: array::from_fn(|byte_index| { - PackedAESBinaryField32x8b::broadcast(AESTowerField8b::from_underlier( - unsafe { scalar.to_underlier().get_subvalue(byte_index) }, - )) + <$packed_storage>::broadcast(AESTowerField8b::from_underlier(unsafe { + scalar.to_underlier().get_subvalue(byte_index) + })) }), } } + #[inline] fn from_fn(mut f: impl FnMut(usize) -> Self::Scalar) -> Self { let mut result = Self::default(); @@ -107,30 +103,43 @@ macro_rules! define_byte_sliced { result } + #[inline] fn square(self) -> Self { let mut result = Self::default(); - square::<$tower_level>(&self.data, &mut result.data); + square::<$packed_storage, $tower_level>(&self.data, &mut result.data); result } + #[inline] fn invert_or_zero(self) -> Self { let mut result = Self::default(); - invert_or_zero::<$tower_level>(&self.data, &mut result.data); + invert_or_zero::<$packed_storage, $tower_level>(&self.data, &mut result.data); result } + #[inline] fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) { let mut result1 = Self::default(); let mut result2 = Self::default(); - for byte_num in 0..<$tower_level as TowerLevel>::WIDTH { - let (this_byte_result1, this_byte_result2) = + for byte_num in 0..<$tower_level as TowerLevel>::WIDTH { + (result1.data[byte_num], result2.data[byte_num]) = self.data[byte_num].interleave(other.data[byte_num], log_block_len); + } + + (result1, result2) + } - result1.data[byte_num] = this_byte_result1; - result2.data[byte_num] = this_byte_result2; + #[inline] + fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self) { + let mut result1 = Self::default(); + let mut result2 = Self::default(); + + for byte_num in 0..<$tower_level as TowerLevel>::WIDTH { + (result1.data[byte_num], result2.data[byte_num]) = + self.data[byte_num].unzip(other.data[byte_num], log_block_len); } (result1, result2) @@ -203,12 +212,9 @@ macro_rules! define_byte_sliced { type Output = Self; fn mul(self, rhs: Self) -> Self { - let mut result = $name { - data: [PackedAESBinaryField32x8b::default(); - <$tower_level as TowerLevel>::WIDTH], - }; + let mut result = Self::default(); - mul::<$tower_level>(&self.data, &rhs.data, &mut result.data); + mul::<$packed_storage, $tower_level>(&self.data, &rhs.data, &mut result.data); result } @@ -267,8 +273,38 @@ macro_rules! define_byte_sliced { }; } -define_byte_sliced!(ByteSlicedAES32x128b, AESTowerField128b, TowerLevel16); -define_byte_sliced!(ByteSlicedAES32x64b, AESTowerField64b, TowerLevel8); -define_byte_sliced!(ByteSlicedAES32x32b, AESTowerField32b, TowerLevel4); -define_byte_sliced!(ByteSlicedAES32x16b, AESTowerField16b, TowerLevel2); -define_byte_sliced!(ByteSlicedAES32x8b, AESTowerField8b, TowerLevel1); +// 128 bit +define_byte_sliced!( + ByteSlicedAES16x128b, + AESTowerField128b, + PackedAESBinaryField16x8b, + TowerLevel16 +); +define_byte_sliced!(ByteSlicedAES16x64b, AESTowerField64b, PackedAESBinaryField16x8b, TowerLevel8); +define_byte_sliced!(ByteSlicedAES16x32b, AESTowerField32b, PackedAESBinaryField16x8b, TowerLevel4); +define_byte_sliced!(ByteSlicedAES16x16b, AESTowerField16b, PackedAESBinaryField16x8b, TowerLevel2); +define_byte_sliced!(ByteSlicedAES16x8b, AESTowerField8b, PackedAESBinaryField16x8b, TowerLevel1); + +// 256 bit +define_byte_sliced!( + ByteSlicedAES32x128b, + AESTowerField128b, + PackedAESBinaryField32x8b, + TowerLevel16 +); +define_byte_sliced!(ByteSlicedAES32x64b, AESTowerField64b, PackedAESBinaryField32x8b, TowerLevel8); +define_byte_sliced!(ByteSlicedAES32x32b, AESTowerField32b, PackedAESBinaryField32x8b, TowerLevel4); +define_byte_sliced!(ByteSlicedAES32x16b, AESTowerField16b, PackedAESBinaryField32x8b, TowerLevel2); +define_byte_sliced!(ByteSlicedAES32x8b, AESTowerField8b, PackedAESBinaryField32x8b, TowerLevel1); + +// 512 bit +define_byte_sliced!( + ByteSlicedAES64x128b, + AESTowerField128b, + PackedAESBinaryField64x8b, + TowerLevel16 +); +define_byte_sliced!(ByteSlicedAES64x64b, AESTowerField64b, PackedAESBinaryField64x8b, TowerLevel8); +define_byte_sliced!(ByteSlicedAES64x32b, AESTowerField32b, PackedAESBinaryField64x8b, TowerLevel4); +define_byte_sliced!(ByteSlicedAES64x16b, AESTowerField16b, PackedAESBinaryField64x8b, TowerLevel2); +define_byte_sliced!(ByteSlicedAES64x8b, AESTowerField8b, PackedAESBinaryField64x8b, TowerLevel1); diff --git a/crates/field/src/arch/portable/byte_sliced/square.rs b/crates/field/src/arch/portable/byte_sliced/square.rs index bcd9514aa..0c0b6ab70 100644 --- a/crates/field/src/arch/portable/byte_sliced/square.rs +++ b/crates/field/src/arch/portable/byte_sliced/square.rs @@ -3,24 +3,27 @@ use super::multiply::mul_alpha; use crate::{ tower_levels::{TowerLevel, TowerLevelWithArithOps}, underlier::WithUnderlier, - AESTowerField8b, PackedAESBinaryField32x8b, PackedField, + AESTowerField8b, PackedField, }; #[inline(always)] -pub fn square>( - field_element: &Level::Data, - destination: &mut Level::Data, +pub fn square, Level: TowerLevel>( + field_element: &Level::Data

, + destination: &mut Level::Data

, ) { - let base_alpha = - PackedAESBinaryField32x8b::from_scalars([AESTowerField8b::from_underlier(0xd3); 32]); - square_main::(field_element, destination, base_alpha); + let base_alpha = P::broadcast(AESTowerField8b::from_underlier(0xd3)); + square_main::(field_element, destination, base_alpha); } #[inline(always)] -pub fn square_main>( - field_element: &Level::Data, - destination: &mut Level::Data, - base_alpha: PackedAESBinaryField32x8b, +pub fn square_main< + const WRITING_TO_ZEROS: bool, + P: PackedField, + Level: TowerLevel, +>( + field_element: &Level::Data

, + destination: &mut Level::Data

, + base_alpha: P, ) { if Level::WIDTH == 1 { if WRITING_TO_ZEROS { @@ -34,15 +37,13 @@ pub fn square_main>::Base as TowerLevel< - PackedAESBinaryField32x8b, - >>::default(); + let mut a1_squared = <::Base as TowerLevel>::default(); - square_main::(a1, &mut a1_squared, base_alpha); + square_main::(a1, &mut a1_squared, base_alpha); - mul_alpha::(&a1_squared, result1, base_alpha); + mul_alpha::(&a1_squared, result1, base_alpha); - square_main::(a0, result0, base_alpha); + square_main::(a0, result0, base_alpha); Level::Base::add_into(&a1_squared, result0); } diff --git a/crates/field/src/arch/portable/packed.rs b/crates/field/src/arch/portable/packed.rs index ad372dd9e..66bbb9650 100644 --- a/crates/field/src/arch/portable/packed.rs +++ b/crates/field/src/arch/portable/packed.rs @@ -336,6 +336,14 @@ where (c.into(), d.into()) } + #[inline] + fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self) { + assert!(log_block_len < Self::LOG_WIDTH); + let log_bit_len = Self::Scalar::N_BITS.ilog2() as usize; + let (c, d) = self.0.transpose(other.0, log_block_len + log_bit_len); + (c.into(), d.into()) + } + #[inline] unsafe fn spread_unchecked(self, log_block_len: usize, block_idx: usize) -> Self { debug_assert!(log_block_len <= Self::LOG_WIDTH, "{} <= {}", log_block_len, Self::LOG_WIDTH); diff --git a/crates/field/src/arch/portable/packed_arithmetic.rs b/crates/field/src/arch/portable/packed_arithmetic.rs index c15f65fab..98d0f1304 100644 --- a/crates/field/src/arch/portable/packed_arithmetic.rs +++ b/crates/field/src/arch/portable/packed_arithmetic.rs @@ -37,6 +37,18 @@ where (c, d) } + + /// Transpose with the given bit size + fn transpose(mut self, mut other: Self, log_block_len: usize) -> (Self, Self) { + // There are 2^7 = 128 bits in a u128 + assert!(log_block_len < Self::INTERLEAVE_EVEN_MASK.len()); + + for log_block_len in (log_block_len..Self::LOG_BITS).rev() { + (self, other) = self.interleave(other, log_block_len); + } + + (self, other) + } } /// Abstraction for a packed tower field of height greater than 0. @@ -331,7 +343,7 @@ where OP: PackedBinaryField, { pub fn new + Sync>( - transformation: FieldLinearTransformation, + transformation: &FieldLinearTransformation, ) -> Self { Self { bases: transformation @@ -387,7 +399,7 @@ where fn make_packed_transformation + Sync>( transformation: FieldLinearTransformation, ) -> Self::PackedTransformation { - PackedTransformation::new(transformation) + PackedTransformation::new(&transformation) } } diff --git a/crates/field/src/arch/portable/packed_scaled.rs b/crates/field/src/arch/portable/packed_scaled.rs index a4b55a2cf..8417939e5 100644 --- a/crates/field/src/arch/portable/packed_scaled.rs +++ b/crates/field/src/arch/portable/packed_scaled.rs @@ -216,14 +216,17 @@ where Self(array::from_fn(|_| PT::random(&mut rng))) } + #[inline] fn broadcast(scalar: Self::Scalar) -> Self { Self(array::from_fn(|_| PT::broadcast(scalar))) } + #[inline] fn square(self) -> Self { Self(self.0.map(|v| v.square())) } + #[inline] fn invert_or_zero(self) -> Self { Self(self.0.map(|v| v.invert_or_zero())) } @@ -253,6 +256,40 @@ where (Self(first), Self(second)) } + fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self) { + let mut first = [Default::default(); N]; + let mut second = [Default::default(); N]; + + if log_block_len >= PT::LOG_WIDTH { + let block_in_pts = 1 << (log_block_len - PT::LOG_WIDTH); + for i in (0..N / 2).step_by(block_in_pts) { + first[i..i + block_in_pts].copy_from_slice(&self.0[2 * i..2 * i + block_in_pts]); + + second[i..i + block_in_pts] + .copy_from_slice(&self.0[2 * i + block_in_pts..2 * (i + block_in_pts)]); + } + + for i in (0..N / 2).step_by(block_in_pts) { + first[i + N / 2..i + N / 2 + block_in_pts] + .copy_from_slice(&other.0[2 * i..2 * i + block_in_pts]); + + second[i + N / 2..i + N / 2 + block_in_pts] + .copy_from_slice(&other.0[2 * i + block_in_pts..2 * (i + block_in_pts)]); + } + } else { + for i in 0..N / 2 { + (first[i], second[i]) = self.0[2 * i].unzip(self.0[2 * i + 1], log_block_len); + } + + for i in 0..N / 2 { + (first[i + N / 2], second[i + N / 2]) = + other.0[2 * i].unzip(other.0[2 * i + 1], log_block_len); + } + } + + (Self(first), Self(second)) + } + #[inline] unsafe fn spread_unchecked(self, log_block_len: usize, block_idx: usize) -> Self { let log_n = checked_log_2(N); @@ -360,20 +397,23 @@ macro_rules! packed_scaled_field { impl std::ops::Add<<$inner as $crate::packed::PackedField>::Scalar> for $name { type Output = Self; - fn add(self, rhs: <$inner as $crate::packed::PackedField>::Scalar) -> Self { - let mut result = Self::default(); - for i in 0..Self::WIDTH_IN_PT { - result.0[i] = self.0[i] + rhs; + #[inline] + fn add(mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) -> Self { + let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs); + for v in self.0.iter_mut() { + *v += broadcast; } - result + self } } impl std::ops::AddAssign<<$inner as $crate::packed::PackedField>::Scalar> for $name { + #[inline] fn add_assign(&mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) { - for i in 0..Self::WIDTH_IN_PT { - self.0[i] += rhs; + let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs); + for v in self.0.iter_mut() { + *v += broadcast; } } } @@ -381,20 +421,23 @@ macro_rules! packed_scaled_field { impl std::ops::Sub<<$inner as $crate::packed::PackedField>::Scalar> for $name { type Output = Self; - fn sub(self, rhs: <$inner as $crate::packed::PackedField>::Scalar) -> Self { - let mut result = Self::default(); - for i in 0..Self::WIDTH_IN_PT { - result.0[i] = self.0[i] - rhs; + #[inline] + fn sub(mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) -> Self { + let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs); + for v in self.0.iter_mut() { + *v -= broadcast; } - result + self } } impl std::ops::SubAssign<<$inner as $crate::packed::PackedField>::Scalar> for $name { + #[inline] fn sub_assign(&mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) { - for i in 0..Self::WIDTH_IN_PT { - self.0[i] -= rhs; + let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs); + for v in self.0.iter_mut() { + *v -= broadcast; } } } @@ -402,20 +445,23 @@ macro_rules! packed_scaled_field { impl std::ops::Mul<<$inner as $crate::packed::PackedField>::Scalar> for $name { type Output = Self; - fn mul(self, rhs: <$inner as $crate::packed::PackedField>::Scalar) -> Self { - let mut result = Self::default(); - for i in 0..Self::WIDTH_IN_PT { - result.0[i] = self.0[i] * rhs; + #[inline] + fn mul(mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) -> Self { + let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs); + for v in self.0.iter_mut() { + *v *= broadcast; } - result + self } } impl std::ops::MulAssign<<$inner as $crate::packed::PackedField>::Scalar> for $name { + #[inline] fn mul_assign(&mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) { - for i in 0..Self::WIDTH_IN_PT { - self.0[i] *= rhs; + let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs); + for v in self.0.iter_mut() { + *v *= broadcast; } } } diff --git a/crates/field/src/arch/x86_64/m128.rs b/crates/field/src/arch/x86_64/m128.rs index f75d152d2..d1d715f6d 100644 --- a/crates/field/src/arch/x86_64/m128.rs +++ b/crates/field/src/arch/x86_64/m128.rs @@ -13,7 +13,7 @@ use subtle::{Choice, ConditionallySelectable, ConstantTimeEq}; use crate::{ arch::{ - binary_utils::{as_array_mut, make_func_to_i8}, + binary_utils::{as_array_mut, as_array_ref, make_func_to_i8}, portable::{ packed::{impl_pack_scalar, PackedPrimitiveType}, packed_arithmetic::{ @@ -23,8 +23,9 @@ use crate::{ }, arithmetic_traits::Broadcast, underlier::{ - impl_divisible, impl_iteration, spread_fallback, NumCast, Random, SmallU, SpreadToByte, - UnderlierType, UnderlierWithBitOps, WithUnderlier, U1, U2, U4, + impl_divisible, impl_iteration, spread_fallback, unpack_hi_128b_fallback, + unpack_lo_128b_fallback, NumCast, Random, SmallU, SpreadToByte, UnderlierType, + UnderlierWithBitOps, WithUnderlier, U1, U2, U4, }, BinaryField, }; @@ -181,7 +182,7 @@ impl Not for M128 { } /// `std::cmp::max` isn't const, so we need our own implementation -const fn max_i32(left: i32, right: i32) -> i32 { +pub(crate) const fn max_i32(left: i32, right: i32) -> i32 { if left > right { left } else { @@ -193,22 +194,37 @@ const fn max_i32(left: i32, right: i32) -> i32 { /// We have to use macro because parameter `count` in _mm_slli_epi64/_mm_srli_epi64 should be passed as constant /// and Rust currently doesn't allow passing expressions (`count - 64`) where variable is a generic constant parameter. /// Source: https://stackoverflow.com/questions/34478328/the-best-way-to-shift-a-m128i/34482688#34482688 -macro_rules! bitshift_right { - ($val:expr, $count:literal) => { +macro_rules! bitshift_128b { + ($val:expr, $shift:ident, $byte_shift:ident, $bit_shift_64:ident, $bit_shift_64_opposite:ident, $or:ident) => { unsafe { - let carry = _mm_bsrli_si128($val, 8); - if $count >= 64 { - _mm_srli_epi64(carry, max_i32($count - 64, 0)) - } else { - let carry = _mm_slli_epi64(carry, max_i32(64 - $count, 0)); - - let val = _mm_srli_epi64($val, $count); - _mm_or_si128(val, carry) - } + let carry = $byte_shift($val, 8); + seq!(N in 64..128 { + if $shift == N { + return $bit_shift_64( + carry, + crate::arch::x86_64::m128::max_i32((N - 64) as i32, 0) as _, + ).into(); + } + }); + seq!(N in 0..64 { + if $shift == N { + let carry = $bit_shift_64_opposite( + carry, + crate::arch::x86_64::m128::max_i32((64 - N) as i32, 0) as _, + ); + + let val = $bit_shift_64($val, N); + return $or(val, carry).into(); + } + }); + + return Default::default() } }; } +pub(crate) use bitshift_128b; + impl Shr for M128 { type Output = Self; @@ -216,32 +232,10 @@ impl Shr for M128 { fn shr(self, rhs: usize) -> Self::Output { // This implementation is effective when `rhs` is known at compile-time. // In our code this is always the case. - seq!(N in 0..128 { - if rhs == N { - return Self(bitshift_right!(self.0, N)); - } - }); - - Self::default() + bitshift_128b!(self.0, rhs, _mm_bsrli_si128, _mm_srli_epi64, _mm_slli_epi64, _mm_or_si128) } } -macro_rules! bitshift_left { - ($val:expr, $count:literal) => { - unsafe { - let carry = _mm_bslli_si128($val, 8); - if $count >= 64 { - _mm_slli_epi64(carry, max_i32($count - 64, 0)) - } else { - let carry = _mm_srli_epi64(carry, max_i32(64 - $count, 0)); - - let val = _mm_slli_epi64($val, $count); - _mm_or_si128(val, carry) - } - } - }; -} - impl Shl for M128 { type Output = Self; @@ -249,13 +243,7 @@ impl Shl for M128 { fn shl(self, rhs: usize) -> Self::Output { // This implementation is effective when `rhs` is known at compile-time. // In our code this is always the case. - seq!(N in 0..128 { - if rhs == N { - return Self(bitshift_left!(self.0, N)); - } - }); - - Self::default() + bitshift_128b!(self.0, rhs, _mm_bslli_si128, _mm_slli_epi64, _mm_srli_epi64, _mm_or_si128); } } @@ -420,39 +408,41 @@ impl UnderlierWithBitOps for M128 { #[inline(always)] unsafe fn get_subvalue(&self, i: usize) -> T where - T: WithUnderlier, - T::Underlier: NumCast, + T: UnderlierType + NumCast, { - match T::Underlier::BITS { - 1 | 2 | 4 | 8 | 16 | 32 | 64 => { - let elements_in_64 = 64 / T::Underlier::BITS; - let chunk_64 = unsafe { - if i >= elements_in_64 { - _mm_extract_epi64(self.0, 1) - } else { - _mm_extract_epi64(self.0, 0) - } - }; - - let result_64 = if T::Underlier::BITS == 64 { - chunk_64 - } else { - let ones = ((1u128 << T::Underlier::BITS) - 1) as u64; - let val_64 = (chunk_64 as u64) - >> (T::Underlier::BITS - * (if i >= elements_in_64 { - i - elements_in_64 - } else { - i - })) & ones; - - val_64 as i64 - }; - T::from_underlier(T::Underlier::num_cast_from(Self(unsafe { - _mm_set_epi64x(0, result_64) - }))) + match T::BITS { + 1 | 2 | 4 => { + let elements_in_8 = 8 / T::BITS; + let mut value_u8 = as_array_ref::<_, u8, 16, _>(self, |arr| unsafe { + *arr.get_unchecked(i / elements_in_8) + }); + + let shift = (i % elements_in_8) * T::BITS; + value_u8 >>= shift; + + T::from_underlier(T::num_cast_from(Self::from(value_u8))) + } + 8 => { + let value_u8 = + as_array_ref::<_, u8, 16, _>(self, |arr| unsafe { *arr.get_unchecked(i) }); + T::from_underlier(T::num_cast_from(Self::from(value_u8))) + } + 16 => { + let value_u16 = + as_array_ref::<_, u16, 8, _>(self, |arr| unsafe { *arr.get_unchecked(i) }); + T::from_underlier(T::num_cast_from(Self::from(value_u16))) + } + 32 => { + let value_u32 = + as_array_ref::<_, u32, 4, _>(self, |arr| unsafe { *arr.get_unchecked(i) }); + T::from_underlier(T::num_cast_from(Self::from(value_u32))) } - 128 => T::from_underlier(T::Underlier::num_cast_from(*self)), + 64 => { + let value_u64 = + as_array_ref::<_, u64, 2, _>(self, |arr| unsafe { *arr.get_unchecked(i) }); + T::from_underlier(T::num_cast_from(Self::from(value_u64))) + } + 128 => T::from_underlier(T::num_cast_from(*self)), _ => panic!("unsupported bit count"), } } @@ -471,23 +461,23 @@ impl UnderlierWithBitOps for M128 { let val = u8::num_cast_from(Self::from(val)) << shift; let mask = mask << shift; - as_array_mut::<_, u8, 16>(self, |array| { - let element = &mut array[i / elements_in_8]; + as_array_mut::<_, u8, 16>(self, |array| unsafe { + let element = array.get_unchecked_mut(i / elements_in_8); *element &= !mask; *element |= val; }); } - 8 => as_array_mut::<_, u8, 16>(self, |array| { - array[i] = u8::num_cast_from(Self::from(val)); + 8 => as_array_mut::<_, u8, 16>(self, |array| unsafe { + *array.get_unchecked_mut(i) = u8::num_cast_from(Self::from(val)); }), - 16 => as_array_mut::<_, u16, 8>(self, |array| { - array[i] = u16::num_cast_from(Self::from(val)); + 16 => as_array_mut::<_, u16, 8>(self, |array| unsafe { + *array.get_unchecked_mut(i) = u16::num_cast_from(Self::from(val)); }), - 32 => as_array_mut::<_, u32, 4>(self, |array| { - array[i] = u32::num_cast_from(Self::from(val)); + 32 => as_array_mut::<_, u32, 4>(self, |array| unsafe { + *array.get_unchecked_mut(i) = u32::num_cast_from(Self::from(val)); }), - 64 => as_array_mut::<_, u64, 2>(self, |array| { - array[i] = u64::num_cast_from(Self::from(val)); + 64 => as_array_mut::<_, u64, 2>(self, |array| unsafe { + *array.get_unchecked_mut(i) = u64::num_cast_from(Self::from(val)); }), 128 => { *self = Self::from(val); @@ -705,6 +695,40 @@ impl UnderlierWithBitOps for M128 { _ => panic!("unsupported bit length"), } } + + #[inline] + fn shl_128b_lanes(self, shift: usize) -> Self { + self << shift + } + + #[inline] + fn shr_128b_lanes(self, shift: usize) -> Self { + self >> shift + } + + #[inline] + fn unpack_lo_128b_lanes(self, other: Self, log_block_len: usize) -> Self { + match log_block_len { + 0..3 => unpack_lo_128b_fallback(self, other, log_block_len), + 3 => unsafe { _mm_unpacklo_epi8(self.0, other.0).into() }, + 4 => unsafe { _mm_unpacklo_epi16(self.0, other.0).into() }, + 5 => unsafe { _mm_unpacklo_epi32(self.0, other.0).into() }, + 6 => unsafe { _mm_unpacklo_epi64(self.0, other.0).into() }, + _ => panic!("unsupported block length"), + } + } + + #[inline] + fn unpack_hi_128b_lanes(self, other: Self, log_block_len: usize) -> Self { + match log_block_len { + 0..3 => unpack_hi_128b_fallback(self, other, log_block_len), + 3 => unsafe { _mm_unpackhi_epi8(self.0, other.0).into() }, + 4 => unsafe { _mm_unpackhi_epi16(self.0, other.0).into() }, + 5 => unsafe { _mm_unpackhi_epi32(self.0, other.0).into() }, + 6 => unsafe { _mm_unpackhi_epi64(self.0, other.0).into() }, + _ => panic!("unsupported block length"), + } + } } unsafe impl Zeroable for M128 {} @@ -922,6 +946,10 @@ mod tests { assert_eq!(M128::from(1u128), M128::ONE); } + fn get(value: M128, log_block_len: usize, index: usize) -> M128 { + (value >> (index << log_block_len)) & single_element_mask_bits::(1 << log_block_len) + } + proptest! { #[test] fn test_conversion(a in any::()) { @@ -955,15 +983,36 @@ mod tests { let (c, d) = unsafe {interleave_bits(a.0, b.0, height)}; let (c, d) = (M128::from(c), M128::from(d)); - let block_len = 1usize << height; - let get = |v, i| { - u128::num_cast_from((v >> (i * block_len)) & single_element_mask_bits::(1 << height)) - }; - for i in (0..128/block_len).step_by(2) { - assert_eq!(get(c, i), get(a, i)); - assert_eq!(get(c, i+1), get(b, i)); - assert_eq!(get(d, i), get(a, i+1)); - assert_eq!(get(d, i+1), get(b, i+1)); + for i in (0..128>>height).step_by(2) { + assert_eq!(get(c, height, i), get(a, height, i)); + assert_eq!(get(c, height, i+1), get(b, height, i)); + assert_eq!(get(d, height, i), get(a, height, i+1)); + assert_eq!(get(d, height, i+1), get(b, height, i+1)); + } + } + + #[test] + fn test_unpack_lo(a in any::(), b in any::(), height in 1usize..7) { + let a = M128::from(a); + let b = M128::from(b); + + let result = a.unpack_lo_128b_lanes(b, height); + for i in 0..128>>(height + 1) { + assert_eq!(get(result, height, 2*i), get(a, height, i)); + assert_eq!(get(result, height, 2*i+1), get(b, height, i)); + } + } + + #[test] + fn test_unpack_hi(a in any::(), b in any::(), height in 1usize..7) { + let a = M128::from(a); + let b = M128::from(b); + + let result = a.unpack_hi_128b_lanes(b, height); + let half_block_count = 128>>(height + 1); + for i in 0..half_block_count { + assert_eq!(get(result, height, 2*i), get(a, height, i + half_block_count)); + assert_eq!(get(result, height, 2*i+1), get(b, height, i + half_block_count)); } } } diff --git a/crates/field/src/arch/x86_64/m256.rs b/crates/field/src/arch/x86_64/m256.rs index 3c36827fc..a2fc71ab8 100644 --- a/crates/field/src/arch/x86_64/m256.rs +++ b/crates/field/src/arch/x86_64/m256.rs @@ -9,6 +9,7 @@ use std::{ use bytemuck::{must_cast, Pod, Zeroable}; use cfg_if::cfg_if; use rand::{Rng, RngCore}; +use seq_macro::seq; use subtle::{Choice, ConditionallySelectable, ConstantTimeEq}; use crate::{ @@ -20,11 +21,13 @@ use crate::{ interleave_mask_even, interleave_mask_odd, UnderlierWithBitConstants, }, }, + x86_64::m128::bitshift_128b, }, arithmetic_traits::Broadcast, underlier::{ get_block_values, get_spread_bytes, impl_divisible, impl_iteration, spread_fallback, - NumCast, Random, SmallU, UnderlierType, UnderlierWithBitOps, WithUnderlier, U1, U2, U4, + unpack_hi_128b_fallback, unpack_lo_128b_fallback, NumCast, Random, SmallU, UnderlierType, + UnderlierWithBitOps, WithUnderlier, U1, U2, U4, }, BinaryField, }; @@ -323,7 +326,7 @@ impl UnderlierType for M256 { impl UnderlierWithBitOps for M256 { const ZERO: Self = { Self(m256_from_u128s!(0, 0,)) }; - const ONE: Self = { Self(m256_from_u128s!(0, 1,)) }; + const ONE: Self = { Self(m256_from_u128s!(1, 0,)) }; const ONES: Self = { Self(m256_from_u128s!(u128::MAX, u128::MAX,)) }; #[inline] @@ -463,42 +466,41 @@ impl UnderlierWithBitOps for M256 { T: UnderlierType + NumCast, { match T::BITS { - 1 | 2 | 4 | 8 | 16 | 32 => { - let elements_in_64 = 64 / T::BITS; - let chunk_64 = unsafe { - match i / elements_in_64 { - 0 => _mm256_extract_epi64(self.0, 0), - 1 => _mm256_extract_epi64(self.0, 1), - 2 => _mm256_extract_epi64(self.0, 2), - _ => _mm256_extract_epi64(self.0, 3), - } - }; + 1 | 2 | 4 => { + let elements_in_8 = 8 / T::BITS; + let mut value_u8 = as_array_ref::<_, u8, 32, _>(self, |arr| unsafe { + *arr.get_unchecked(i / elements_in_8) + }); - let result_64 = if T::BITS == 64 { - chunk_64 - } else { - let ones = ((1u128 << T::BITS) - 1) as u64; - let val_64 = (chunk_64 as u64) >> (T::BITS * (i % elements_in_64)) & ones; + let shift = (i % elements_in_8) * T::BITS; + value_u8 >>= shift; - val_64 as i64 - }; - T::num_cast_from(Self(unsafe { _mm256_set_epi64x(0, 0, 0, result_64) })) + T::from_underlier(T::num_cast_from(Self::from(value_u8))) + } + 8 => { + let value_u8 = + as_array_ref::<_, u8, 32, _>(self, |arr| unsafe { *arr.get_unchecked(i) }); + T::from_underlier(T::num_cast_from(Self::from(value_u8))) + } + 16 => { + let value_u16 = + as_array_ref::<_, u16, 16, _>(self, |arr| unsafe { *arr.get_unchecked(i) }); + T::from_underlier(T::num_cast_from(Self::from(value_u16))) + } + 32 => { + let value_u32 = + as_array_ref::<_, u32, 8, _>(self, |arr| unsafe { *arr.get_unchecked(i) }); + T::from_underlier(T::num_cast_from(Self::from(value_u32))) } - // NOTE: benchmark show that this strategy is optimal for getting 64-bit subvalues from 256-bit register. - // However using similar code for 1..32 bits is slower than the version above. - // Also even getting `chunk_64` in the code above using this code shows worser benchmarks results. 64 => { - T::num_cast_from(as_array_ref::<_, u64, 4, _>(self, |array| Self::from(array[i]))) + let value_u64 = + as_array_ref::<_, u64, 4, _>(self, |arr| unsafe { *arr.get_unchecked(i) }); + T::from_underlier(T::num_cast_from(Self::from(value_u64))) } 128 => { - let chunk_128 = unsafe { - if i == 0 { - _mm256_extracti128_si256(self.0, 0) - } else { - _mm256_extracti128_si256(self.0, 1) - } - }; - T::num_cast_from(Self(unsafe { _mm256_set_m128i(_mm_setzero_si128(), chunk_128) })) + let value_u128 = + as_array_ref::<_, u128, 2, _>(self, |arr| unsafe { *arr.get_unchecked(i) }); + T::from_underlier(T::num_cast_from(Self::from(value_u128))) } _ => panic!("unsupported bit count"), } @@ -518,26 +520,26 @@ impl UnderlierWithBitOps for M256 { let val = u8::num_cast_from(Self::from(val)) << shift; let mask = mask << shift; - as_array_mut::<_, u8, 32>(self, |array| { - let element = &mut array[i / elements_in_8]; + as_array_mut::<_, u8, 32>(self, |array| unsafe { + let element = array.get_unchecked_mut(i / elements_in_8); *element &= !mask; *element |= val; }); } - 8 => as_array_mut::<_, u8, 32>(self, |array| { - array[i] = u8::num_cast_from(Self::from(val)); + 8 => as_array_mut::<_, u8, 32>(self, |array| unsafe { + *array.get_unchecked_mut(i) = u8::num_cast_from(Self::from(val)); }), - 16 => as_array_mut::<_, u16, 16>(self, |array| { - array[i] = u16::num_cast_from(Self::from(val)); + 16 => as_array_mut::<_, u16, 16>(self, |array| unsafe { + *array.get_unchecked_mut(i) = u16::num_cast_from(Self::from(val)); }), - 32 => as_array_mut::<_, u32, 8>(self, |array| { - array[i] = u32::num_cast_from(Self::from(val)); + 32 => as_array_mut::<_, u32, 8>(self, |array| unsafe { + *array.get_unchecked_mut(i) = u32::num_cast_from(Self::from(val)); }), - 64 => as_array_mut::<_, u64, 4>(self, |array| { - array[i] = u64::num_cast_from(Self::from(val)); + 64 => as_array_mut::<_, u64, 4>(self, |array| unsafe { + *array.get_unchecked_mut(i) = u64::num_cast_from(Self::from(val)); }), 128 => as_array_mut::<_, u128, 2>(self, |array| { - array[i] = u128::num_cast_from(Self::from(val)); + *array.get_unchecked_mut(i) = u128::num_cast_from(Self::from(val)); }), _ => panic!("unsupported bit count"), } @@ -835,6 +837,58 @@ impl UnderlierWithBitOps for M256 { _ => spread_fallback(self, log_block_len, block_idx), } } + + #[inline] + fn shr_128b_lanes(self, rhs: usize) -> Self { + // This implementation is effective when `rhs` is known at compile-time. + // In our code this is always the case. + bitshift_128b!( + self.0, + rhs, + _mm256_bsrli_epi128, + _mm256_srli_epi64, + _mm256_slli_epi64, + _mm256_or_si256 + ) + } + + #[inline] + fn shl_128b_lanes(self, rhs: usize) -> Self { + // This implementation is effective when `rhs` is known at compile-time. + // In our code this is always the case. + bitshift_128b!( + self.0, + rhs, + _mm256_bslli_epi128, + _mm256_slli_epi64, + _mm256_srli_epi64, + _mm256_or_si256 + ); + } + + #[inline] + fn unpack_lo_128b_lanes(self, other: Self, log_block_len: usize) -> Self { + match log_block_len { + 0..3 => unpack_lo_128b_fallback(self, other, log_block_len), + 3 => unsafe { _mm256_unpacklo_epi8(self.0, other.0).into() }, + 4 => unsafe { _mm256_unpacklo_epi16(self.0, other.0).into() }, + 5 => unsafe { _mm256_unpacklo_epi32(self.0, other.0).into() }, + 6 => unsafe { _mm256_unpacklo_epi64(self.0, other.0).into() }, + _ => panic!("unsupported block length"), + } + } + + #[inline] + fn unpack_hi_128b_lanes(self, other: Self, log_block_len: usize) -> Self { + match log_block_len { + 0..3 => unpack_hi_128b_fallback(self, other, log_block_len), + 3 => unsafe { _mm256_unpackhi_epi8(self.0, other.0).into() }, + 4 => unsafe { _mm256_unpackhi_epi16(self.0, other.0).into() }, + 5 => unsafe { _mm256_unpackhi_epi32(self.0, other.0).into() }, + 6 => unsafe { _mm256_unpackhi_epi64(self.0, other.0).into() }, + _ => panic!("unsupported block length"), + } + } } unsafe impl Zeroable for M256 {} @@ -910,6 +964,13 @@ impl UnderlierWithBitConstants for M256 { let (a, b) = unsafe { interleave_bits(self.0, other.0, log_block_len) }; (Self(a), Self(b)) } + + fn transpose(mut self, mut other: Self, log_block_len: usize) -> (Self, Self) { + let (a, b) = unsafe { transpose_bits(self.0, other.0, log_block_len) }; + self.0 = a; + other.0 = b; + (self, other) + } } #[inline] @@ -974,6 +1035,57 @@ unsafe fn interleave_bits(a: __m256i, b: __m256i, log_block_len: usize) -> (__m2 } } +#[inline] +unsafe fn transpose_bits(a: __m256i, b: __m256i, log_block_len: usize) -> (__m256i, __m256i) { + match log_block_len { + 0..=3 => { + let shuffle = _mm256_set_epi8( + 15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0, 15, 13, 11, 9, 7, 5, 3, 1, + 14, 12, 10, 8, 6, 4, 2, 0, + ); + let (mut a, mut b) = transpose_with_shuffle(a, b, shuffle); + for log_block_len in (log_block_len..3).rev() { + (a, b) = interleave_bits(a, b, log_block_len); + } + + (a, b) + } + 4 => { + let shuffle = _mm256_set_epi8( + 15, 14, 11, 10, 7, 6, 3, 2, 13, 12, 9, 8, 5, 4, 1, 0, 15, 14, 11, 10, 7, 6, 3, 2, + 13, 12, 9, 8, 5, 4, 1, 0, + ); + + transpose_with_shuffle(a, b, shuffle) + } + 5 => { + let shuffle = _mm256_set_epi8( + 15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0, 15, 14, 13, 12, 7, 6, 5, 4, + 11, 10, 9, 8, 3, 2, 1, 0, + ); + + transpose_with_shuffle(a, b, shuffle) + } + 6 => { + let (a, b) = (_mm256_unpacklo_epi64(a, b), _mm256_unpackhi_epi64(a, b)); + + (_mm256_permute4x64_epi64(a, 0b11011000), _mm256_permute4x64_epi64(b, 0b11011000)) + } + 7 => (_mm256_permute2x128_si256(a, b, 0x20), _mm256_permute2x128_si256(a, b, 0x31)), + _ => panic!("unsupported block length"), + } +} + +#[inline(always)] +unsafe fn transpose_with_shuffle(a: __m256i, b: __m256i, shuffle: __m256i) -> (__m256i, __m256i) { + let a = _mm256_shuffle_epi8(a, shuffle); + let b = _mm256_shuffle_epi8(b, shuffle); + + let (a, b) = (_mm256_unpacklo_epi64(a, b), _mm256_unpackhi_epi64(a, b)); + + (_mm256_permute4x64_epi64(a, 0b11011000), _mm256_permute4x64_epi64(b, 0b11011000)) +} + #[inline] unsafe fn interleave_bits_imm( a: __m256i, @@ -1075,7 +1187,7 @@ mod tests { fn test_constants() { assert_eq!(M256::default(), M256::ZERO); assert_eq!(M256::from(0u128), M256::ZERO); - assert_eq!(M256::from([0u128, 1u128]), M256::ONE); + assert_eq!(M256::from([1u128, 0u128]), M256::ONE); } #[derive(Default)] @@ -1139,6 +1251,10 @@ mod tests { } } + fn get(value: M256, log_block_len: usize, index: usize) -> M256 { + (value >> (index << log_block_len)) & single_element_mask_bits::(1 << log_block_len) + } + proptest! { #[allow(clippy::tuple_array_conversions)] // false positive #[test] @@ -1175,14 +1291,41 @@ mod tests { let (c, d) = (M256::from(c), M256::from(d)); let block_len = 1usize << height; - let get = |v, i| { - u128::num_cast_from((v >> (i * block_len)) & single_element_mask_bits::(1 << height)) - }; for i in (0..256/block_len).step_by(2) { - assert_eq!(get(c, i), get(a, i)); - assert_eq!(get(c, i+1), get(b, i)); - assert_eq!(get(d, i), get(a, i+1)); - assert_eq!(get(d, i+1), get(b, i+1)); + assert_eq!(get(c, height, i), get(a, height, i)); + assert_eq!(get(c, height, i+1), get(b, height, i)); + assert_eq!(get(d, height, i), get(a, height, i+1)); + assert_eq!(get(d, height, i+1), get(b, height, i+1)); + } + } + + #[test] + fn test_unpack_lo(a in any::<[u128; 2]>(), b in any::<[u128; 2]>(), height in 0usize..7) { + let a = M256::from(a); + let b = M256::from(b); + + let result = a.unpack_lo_128b_lanes(b, height); + let half_block_count = 128>>(height + 1); + for i in 0..half_block_count { + assert_eq!(get(result, height, 2*i), get(a, height, i)); + assert_eq!(get(result, height, 2*i+1), get(b, height, i)); + assert_eq!(get(result, height, 2*(i + half_block_count)), get(a, height, 2 * half_block_count + i)); + assert_eq!(get(result, height, 2*(i + half_block_count)+1), get(b, height, 2 * half_block_count + i)); + } + } + + #[test] + fn test_unpack_hi(a in any::<[u128; 2]>(), b in any::<[u128; 2]>(), height in 0usize..7) { + let a = M256::from(a); + let b = M256::from(b); + + let result = a.unpack_hi_128b_lanes(b, height); + let half_block_count = 128>>(height + 1); + for i in 0..half_block_count { + assert_eq!(get(result, height, 2*i), get(a, height, i + half_block_count)); + assert_eq!(get(result, height, 2*i+1), get(b, height, i + half_block_count)); + assert_eq!(get(result, height, 2*(half_block_count + i)), get(a, height, 3*half_block_count + i)); + assert_eq!(get(result, height, 2*(half_block_count + i) +1), get(b, height, 3*half_block_count + i)); } } } diff --git a/crates/field/src/arch/x86_64/m512.rs b/crates/field/src/arch/x86_64/m512.rs index caa821c62..8afc28251 100644 --- a/crates/field/src/arch/x86_64/m512.rs +++ b/crates/field/src/arch/x86_64/m512.rs @@ -8,23 +8,28 @@ use std::{ use bytemuck::{must_cast, Pod, Zeroable}; use rand::{Rng, RngCore}; +use seq_macro::seq; use subtle::{Choice, ConditionallySelectable, ConstantTimeEq}; use crate::{ arch::{ - binary_utils::{as_array_mut, make_func_to_i8}, + binary_utils::{as_array_mut, as_array_ref, make_func_to_i8}, portable::{ packed::{impl_pack_scalar, PackedPrimitiveType}, packed_arithmetic::{ interleave_mask_even, interleave_mask_odd, UnderlierWithBitConstants, }, }, - x86_64::{m128::M128, m256::M256}, + x86_64::{ + m128::{bitshift_128b, M128}, + m256::M256, + }, }, arithmetic_traits::Broadcast, underlier::{ get_block_values, get_spread_bytes, impl_divisible, impl_iteration, spread_fallback, - NumCast, Random, SmallU, UnderlierType, UnderlierWithBitOps, WithUnderlier, U1, U2, U4, + unpack_hi_128b_fallback, unpack_lo_128b_fallback, NumCast, Random, SmallU, UnderlierType, + UnderlierWithBitOps, WithUnderlier, U1, U2, U4, }, BinaryField, }; @@ -370,7 +375,7 @@ impl UnderlierType for M512 { impl UnderlierWithBitOps for M512 { const ZERO: Self = { Self(m512_from_u128s!(0, 0, 0, 0,)) }; - const ONE: Self = { Self(m512_from_u128s!(0, 0, 0, 1,)) }; + const ONE: Self = { Self(m512_from_u128s!(1, 0, 0, 0,)) }; const ONES: Self = { Self(m512_from_u128s!(u128::MAX, u128::MAX, u128::MAX, u128::MAX,)) }; #[inline(always)] @@ -617,31 +622,41 @@ impl UnderlierWithBitOps for M512 { T: UnderlierType + NumCast, { match T::BITS { - 1 | 2 | 4 | 8 | 16 | 32 | 64 => { - let elements_in_64 = 64 / T::BITS; - let shuffle = unsafe { _mm512_set1_epi64((i / elements_in_64) as i64) }; - let chunk_64 = - u64::num_cast_from(Self(unsafe { _mm512_permutexvar_epi64(shuffle, self.0) })); - - let result_64 = if T::BITS == 64 { - chunk_64 - } else { - let ones = ((1u128 << T::BITS) - 1) as u64; - (chunk_64 >> (T::BITS * (i % elements_in_64))) & ones - }; + 1 | 2 | 4 => { + let elements_in_8 = 8 / T::BITS; + let mut value_u8 = as_array_ref::<_, u8, 64, _>(self, |arr| unsafe { + *arr.get_unchecked(i / elements_in_8) + }); + + let shift = (i % elements_in_8) * T::BITS; + value_u8 >>= shift; - T::num_cast_from(Self::from(result_64)) + T::from_underlier(T::num_cast_from(Self::from(value_u8))) + } + 8 => { + let value_u8 = + as_array_ref::<_, u8, 64, _>(self, |arr| unsafe { *arr.get_unchecked(i) }); + T::from_underlier(T::num_cast_from(Self::from(value_u8))) + } + 16 => { + let value_u16 = + as_array_ref::<_, u16, 32, _>(self, |arr| unsafe { *arr.get_unchecked(i) }); + T::from_underlier(T::num_cast_from(Self::from(value_u16))) + } + 32 => { + let value_u32 = + as_array_ref::<_, u32, 16, _>(self, |arr| unsafe { *arr.get_unchecked(i) }); + T::from_underlier(T::num_cast_from(Self::from(value_u32))) + } + 64 => { + let value_u64 = + as_array_ref::<_, u64, 8, _>(self, |arr| unsafe { *arr.get_unchecked(i) }); + T::from_underlier(T::num_cast_from(Self::from(value_u64))) } 128 => { - let chunk_128 = unsafe { - match i { - 0 => _mm512_extracti32x4_epi32(self.0, 0), - 1 => _mm512_extracti32x4_epi32(self.0, 1), - 2 => _mm512_extracti32x4_epi32(self.0, 2), - _ => _mm512_extracti32x4_epi32(self.0, 3), - } - }; - T::num_cast_from(Self(unsafe { _mm512_castsi128_si512(chunk_128) })) + let value_u128 = + as_array_ref::<_, u128, 4, _>(self, |arr| unsafe { *arr.get_unchecked(i) }); + T::from_underlier(T::num_cast_from(Self::from(value_u128))) } _ => panic!("unsupported bit count"), } @@ -661,26 +676,26 @@ impl UnderlierWithBitOps for M512 { let val = u8::num_cast_from(Self::from(val)) << shift; let mask = mask << shift; - as_array_mut::<_, u8, 64>(self, |array| { - let element = &mut array[i / elements_in_8]; + as_array_mut::<_, u8, 64>(self, |array| unsafe { + let element = array.get_unchecked_mut(i / elements_in_8); *element &= !mask; *element |= val; }); } - 8 => as_array_mut::<_, u8, 64>(self, |array| { - array[i] = u8::num_cast_from(Self::from(val)); + 8 => as_array_mut::<_, u8, 64>(self, |array| unsafe { + *array.get_unchecked_mut(i) = u8::num_cast_from(Self::from(val)); }), - 16 => as_array_mut::<_, u16, 32>(self, |array| { - array[i] = u16::num_cast_from(Self::from(val)); + 16 => as_array_mut::<_, u16, 32>(self, |array| unsafe { + *array.get_unchecked_mut(i) = u16::num_cast_from(Self::from(val)); }), - 32 => as_array_mut::<_, u32, 16>(self, |array| { - array[i] = u32::num_cast_from(Self::from(val)); + 32 => as_array_mut::<_, u32, 16>(self, |array| unsafe { + *array.get_unchecked_mut(i) = u32::num_cast_from(Self::from(val)); }), - 64 => as_array_mut::<_, u64, 8>(self, |array| { - array[i] = u64::num_cast_from(Self::from(val)); + 64 => as_array_mut::<_, u64, 8>(self, |array| unsafe { + *array.get_unchecked_mut(i) = u64::num_cast_from(Self::from(val)); }), 128 => as_array_mut::<_, u128, 4>(self, |array| { - array[i] = u128::num_cast_from(Self::from(val)); + *array.get_unchecked_mut(i) = u128::num_cast_from(Self::from(val)); }), _ => panic!("unsupported bit count"), } @@ -855,6 +870,58 @@ impl UnderlierWithBitOps for M512 { _ => spread_fallback(self, log_block_len, block_idx), } } + + #[inline] + fn shr_128b_lanes(self, rhs: usize) -> Self { + // This implementation is effective when `rhs` is known at compile-time. + // In our code this is always the case. + bitshift_128b!( + self.0, + rhs, + _mm512_bsrli_epi128, + _mm512_srli_epi64, + _mm512_slli_epi64, + _mm512_or_si512 + ); + } + + #[inline] + fn shl_128b_lanes(self, rhs: usize) -> Self { + // This implementation is effective when `rhs` is known at compile-time. + // In our code this is always the case. + bitshift_128b!( + self.0, + rhs, + _mm512_bslli_epi128, + _mm512_slli_epi64, + _mm512_srli_epi64, + _mm512_or_si512 + ); + } + + #[inline] + fn unpack_lo_128b_lanes(self, other: Self, log_block_len: usize) -> Self { + match log_block_len { + 0..3 => unpack_lo_128b_fallback(self, other, log_block_len), + 3 => unsafe { _mm512_unpacklo_epi8(self.0, other.0).into() }, + 4 => unsafe { _mm512_unpacklo_epi16(self.0, other.0).into() }, + 5 => unsafe { _mm512_unpacklo_epi32(self.0, other.0).into() }, + 6 => unsafe { _mm512_unpacklo_epi64(self.0, other.0).into() }, + _ => panic!("unsupported block length"), + } + } + + #[inline] + fn unpack_hi_128b_lanes(self, other: Self, log_block_len: usize) -> Self { + match log_block_len { + 0..3 => unpack_hi_128b_fallback(self, other, log_block_len), + 3 => unsafe { _mm512_unpackhi_epi8(self.0, other.0).into() }, + 4 => unsafe { _mm512_unpackhi_epi16(self.0, other.0).into() }, + 5 => unsafe { _mm512_unpackhi_epi32(self.0, other.0).into() }, + 6 => unsafe { _mm512_unpackhi_epi64(self.0, other.0).into() }, + _ => panic!("unsupported block length"), + } + } } unsafe impl Zeroable for M512 {} @@ -932,6 +999,12 @@ impl UnderlierWithBitConstants for M512 { let (a, b) = unsafe { interleave_bits(self.0, other.0, log_block_len) }; (Self(a), Self(b)) } + + #[inline(always)] + fn transpose(self, other: Self, log_bit_len: usize) -> (Self, Self) { + let (a, b) = unsafe { transpose_bits(self.0, other.0, log_bit_len) }; + (Self(a), Self(b)) + } } #[inline] @@ -1103,6 +1176,95 @@ const fn precompute_spread_mask( m512_masks } +#[inline(always)] +unsafe fn transpose_bits(a: __m512i, b: __m512i, log_block_len: usize) -> (__m512i, __m512i) { + match log_block_len { + 0..=3 => { + let shuffle = _mm512_set_epi8( + 15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0, 15, 13, 11, 9, 7, 5, 3, 1, + 14, 12, 10, 8, 6, 4, 2, 0, 15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0, + 15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0, + ); + let (mut a, mut b) = transpose_with_shuffle(a, b, shuffle); + for log_block_len in (log_block_len..3).rev() { + (a, b) = interleave_bits(a, b, log_block_len); + } + + (a, b) + } + 4 => { + let shuffle = _mm512_set_epi8( + 15, 14, 11, 10, 7, 6, 3, 2, 13, 12, 9, 8, 5, 4, 1, 0, 15, 14, 11, 10, 7, 6, 3, 2, + 13, 12, 9, 8, 5, 4, 1, 0, 15, 14, 11, 10, 7, 6, 3, 2, 13, 12, 9, 8, 5, 4, 1, 0, 15, + 14, 11, 10, 7, 6, 3, 2, 13, 12, 9, 8, 5, 4, 1, 0, + ); + transpose_with_shuffle(a, b, shuffle) + } + 5 => { + let shuffle = _mm512_set_epi8( + 15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0, 15, 14, 13, 12, 7, 6, 5, 4, + 11, 10, 9, 8, 3, 2, 1, 0, 15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0, 15, + 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0, + ); + transpose_with_shuffle(a, b, shuffle) + } + 6 => ( + _mm512_permutex2var_epi64( + a, + _mm512_set_epi64(0b1110, 0b1100, 0b1010, 0b1000, 0b0110, 0b0100, 0b0010, 0b0000), + b, + ), + _mm512_permutex2var_epi64( + a, + _mm512_set_epi64(0b1111, 0b1101, 0b1011, 0b1001, 0b0111, 0b0101, 0b0011, 0b0001), + b, + ), + ), + 7 => ( + _mm512_permutex2var_epi64( + a, + _mm512_set_epi64(0b1101, 0b1100, 0b1001, 0b1000, 0b0101, 0b0100, 0b0001, 0b0000), + b, + ), + _mm512_permutex2var_epi64( + a, + _mm512_set_epi64(0b1111, 0b1110, 0b1011, 0b1010, 0b0111, 0b0110, 0b0011, 0b0010), + b, + ), + ), + 8 => ( + _mm512_permutex2var_epi64( + a, + _mm512_set_epi64(0b1011, 0b1010, 0b1001, 0b1000, 0b0011, 0b0010, 0b0001, 0b0000), + b, + ), + _mm512_permutex2var_epi64( + a, + _mm512_set_epi64(0b1111, 0b1110, 0b1101, 0b1100, 0b0111, 0b0110, 0b0101, 0b0100), + b, + ), + ), + _ => panic!("unsupported block length"), + } +} + +unsafe fn transpose_with_shuffle(a: __m512i, b: __m512i, shuffle: __m512i) -> (__m512i, __m512i) { + let (a, b) = (_mm512_shuffle_epi8(a, shuffle), _mm512_shuffle_epi8(b, shuffle)); + + ( + _mm512_permutex2var_epi64( + a, + _mm512_set_epi64(0b1110, 0b1100, 0b1010, 0b1000, 0b0110, 0b0100, 0b0010, 0b0000), + b, + ), + _mm512_permutex2var_epi64( + a, + _mm512_set_epi64(0b1111, 0b1101, 0b1011, 0b1001, 0b0111, 0b0101, 0b0011, 0b0001), + b, + ), + ) +} + impl_iteration!(M512, @strategy BitIterationStrategy, U1, @strategy FallbackStrategy, U2, U4, @@ -1128,7 +1290,7 @@ mod tests { fn test_constants() { assert_eq!(M512::default(), M512::ZERO); assert_eq!(M512::from(0u128), M512::ZERO); - assert_eq!(M512::from([0u128, 0u128, 0u128, 1u128]), M512::ONE); + assert_eq!(M512::from([1u128, 0u128, 0u128, 0u128]), M512::ONE); } #[derive(Default)] @@ -1192,6 +1354,10 @@ mod tests { } } + fn get(value: M512, log_block_len: usize, index: usize) -> M512 { + (value >> (index << log_block_len)) & single_element_mask_bits::(1 << log_block_len) + } + proptest! { #[test] fn test_conversion(a in any::<[u128; 4]>()) { @@ -1225,14 +1391,49 @@ mod tests { let (c, d) = (M512::from(c), M512::from(d)); let block_len = 1usize << height; - let get = |v, i| { - u128::num_cast_from((v >> (i * block_len)) & single_element_mask_bits::(1 << height)) - }; for i in (0..512/block_len).step_by(2) { - assert_eq!(get(c, i), get(a, i)); - assert_eq!(get(c, i+1), get(b, i)); - assert_eq!(get(d, i), get(a, i+1)); - assert_eq!(get(d, i+1), get(b, i+1)); + assert_eq!(get(c, height, i), get(a, height, i)); + assert_eq!(get(c, height, i+1), get(b, height, i)); + assert_eq!(get(d, height, i), get(a, height, i+1)); + assert_eq!(get(d, height, i+1), get(b, height, i+1)); + } + } + + #[test] + fn test_unpack_lo(a in any::<[u128; 4]>(), b in any::<[u128; 4]>(), height in 0usize..7) { + let a = M512::from(a); + let b = M512::from(b); + + let result = a.unpack_lo_128b_lanes(b, height); + let half_block_count = 128>>(height + 1); + for i in 0..half_block_count { + assert_eq!(get(result, height, 2*i), get(a, height, i)); + assert_eq!(get(result, height, 2*i+1), get(b, height, i)); + assert_eq!(get(result, height, 2*(i + half_block_count)), get(a, height, 2 * half_block_count + i)); + assert_eq!(get(result, height, 2*(i + half_block_count)+1), get(b, height, 2 * half_block_count + i)); + assert_eq!(get(result, height, 2*(i + 2*half_block_count)), get(a, height, 4 * half_block_count + i)); + assert_eq!(get(result, height, 2*(i + 2*half_block_count)+1), get(b, height, 4 * half_block_count + i)); + assert_eq!(get(result, height, 2*(i + 3*half_block_count)), get(a, height, 6 * half_block_count + i)); + assert_eq!(get(result, height, 2*(i + 3*half_block_count)+1), get(b, height, 6 * half_block_count + i)); + } + } + + #[test] + fn test_unpack_hi(a in any::<[u128; 4]>(), b in any::<[u128; 4]>(), height in 0usize..7) { + let a = M512::from(a); + let b = M512::from(b); + + let result = a.unpack_hi_128b_lanes(b, height); + let half_block_count = 128>>(height + 1); + for i in 0..half_block_count { + assert_eq!(get(result, height, 2*i), get(a, height, i + half_block_count)); + assert_eq!(get(result, height, 2*i+1), get(b, height, i + half_block_count)); + assert_eq!(get(result, height, 2*(half_block_count + i)), get(a, height, 3*half_block_count + i)); + assert_eq!(get(result, height, 2*(half_block_count + i) +1), get(b, height, 3*half_block_count + i)); + assert_eq!(get(result, height, 2*(2*half_block_count + i)), get(a, height, 5*half_block_count + i)); + assert_eq!(get(result, height, 2*(2*half_block_count + i) +1), get(b, height, 5*half_block_count + i)); + assert_eq!(get(result, height, 2*(3*half_block_count + i)), get(a, height, 7*half_block_count + i)); + assert_eq!(get(result, height, 2*(3*half_block_count + i) +1), get(b, height, 7*half_block_count + i)); } } } diff --git a/crates/field/src/binary_field.rs b/crates/field/src/binary_field.rs index be6f8ea80..31f30f32d 100644 --- a/crates/field/src/binary_field.rs +++ b/crates/field/src/binary_field.rs @@ -2,15 +2,16 @@ use std::{ any::TypeId, - array, fmt::{Debug, Display, Formatter}, iter::{Product, Sum}, ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}, }; -use binius_utils::serialization::{DeserializeBytes, Error as SerializationError, SerializeBytes}; +use binius_utils::{ + bytes::{Buf, BufMut}, + DeserializeBytes, SerializationError, SerializationMode, SerializeBytes, +}; use bytemuck::{Pod, Zeroable}; -use bytes::{Buf, BufMut}; use rand::RngCore; use subtle::{Choice, ConditionallySelectable, ConstantTimeEq}; @@ -18,7 +19,7 @@ use super::{ binary_field_arithmetic::TowerFieldArithmetic, error::Error, extension::ExtensionField, }; use crate::{ - underlier::{SmallU, U1, U2, U4}, + underlier::{U1, U2, U4}, Field, }; @@ -46,6 +47,13 @@ where /// Currently for every tower field, the canonical field is Fan-Paar's binary field of the same degree. type Canonical: TowerField + SerializeBytes + DeserializeBytes; + /// Returns the smallest valid `TOWER_LEVEL` in the tower that can fit the same value. + /// + /// Since which `TOWER_LEVEL` values are valid depends on the tower, + /// `F::Canonical::from(elem).min_tower_level()` can return a different result + /// from `elem.min_tower_level()`. + fn min_tower_level(self) -> usize; + fn basis(iota: usize, i: usize) -> Result { if iota > Self::TOWER_LEVEL { return Err(Error::ExtensionDegreeTooHigh); @@ -564,7 +572,6 @@ macro_rules! impl_field_extension { } impl ExtensionField<$subfield_name> for $name { - type Iterator = <[$subfield_name; 1 << $log_degree] as IntoIterator>::IntoIter; const LOG_DEGREE: usize = $log_degree; #[inline] @@ -597,15 +604,21 @@ macro_rules! impl_field_extension { } #[inline] - fn iter_bases(&self) -> Self::Iterator { - use $crate::underlier::NumCast; + fn iter_bases(&self) -> impl Iterator { + use $crate::underlier::{WithUnderlier, IterationMethods, IterationStrategy}; + use binius_utils::iter::IterExtensions; + + IterationMethods::<<$subfield_name as WithUnderlier>::Underlier, Self::Underlier>::ref_iter(&self.0) + .map_skippable($subfield_name::from) + } + + #[inline] + fn into_iter_bases(self) -> impl Iterator { + use $crate::underlier::{WithUnderlier, IterationMethods, IterationStrategy}; + use binius_utils::iter::IterExtensions; - let base_elems = array::from_fn(|i| { - <$subfield_name>::new(<$subfield_typ>::num_cast_from( - (self.0 >> (i * $subfield_name::N_BITS)), - )) - }); - base_elems.into_iter() + IterationMethods::<<$subfield_name as WithUnderlier>::Underlier, Self::Underlier>::value_iter(self.0) + .map_skippable($subfield_name::from) } } }; @@ -621,10 +634,16 @@ pub(super) trait MulPrimitive: Sized { #[macro_export] macro_rules! binary_tower { - ($subfield_name:ident($subfield_typ:ty $(, $canonical_subfield:ident)?) < $name:ident($typ:ty)) => { - binary_tower!($subfield_name($subfield_typ $(, $canonical_subfield)?) < $name($typ, $name)); + (BinaryField1b($subfield_typ:ty $(, $canonical_subfield:ident)?) < $name:ident($typ:ty $(, $canonical:ident)?) $(< $extfield_name:ident($extfield_typ:ty $(, $canonical_ext:ident)?))+) => { + binary_tower!([BinaryField1b::TOWER_LEVEL]; BinaryField1b($subfield_typ $(, $canonical_subfield)?) < $name($typ $(, $canonical)?) $(< $extfield_name($extfield_typ $(, $canonical_ext)?))+); }; - ($subfield_name:ident($subfield_typ:ty $(, $canonical_subfield:ident)?) < $name:ident($typ:ty, $canonical:ident)) => { + ($subfield_name:ident($subfield_typ:ty $(, $canonical_subfield:ident)?) < $name:ident($typ:ty $(, $canonical:ident)?) $(< $extfield_name:ident($extfield_typ:ty $(, $canonical_ext:ident)?))+) => { + binary_tower!([BinaryField1b::TOWER_LEVEL, $subfield_name::TOWER_LEVEL]; $subfield_name($subfield_typ $(, $canonical_subfield)?) < $name($typ $(, $canonical)?) $(< $extfield_name($extfield_typ $(, $canonical_ext)?))+); + }; + ([$($valid_tower_levels:tt)*]; $subfield_name:ident($subfield_typ:ty $(, $canonical_subfield:ident)?) < $name:ident($typ:ty)) => { + binary_tower!([$($valid_tower_levels)*]; $subfield_name($subfield_typ $(, $canonical_subfield)?) < $name($typ, $name)); + }; + ([$($valid_tower_levels:tt)*]; $subfield_name:ident($subfield_typ:ty $(, $canonical_subfield:ident)?) < $name:ident($typ:ty, $canonical:ident)) => { impl From<$name> for ($subfield_name, $subfield_name) { #[inline] fn from(src: $name) -> ($subfield_name, $subfield_name) { @@ -648,6 +667,16 @@ macro_rules! binary_tower { type Canonical = $canonical; + fn min_tower_level(self) -> usize { + let zero = <$typ as $crate::underlier::UnderlierWithBitOps>::ZERO; + for level in [$($valid_tower_levels)*] { + if self.0 >> (1 << level) == zero { + return level; + } + } + Self::TOWER_LEVEL + } + fn mul_primitive(self, iota: usize) -> Result { ::mul_primitive(self, iota) } @@ -659,14 +688,13 @@ macro_rules! binary_tower { binary_tower!($subfield_name($subfield_typ) < @1 => $name($typ)); }; - ($subfield_name:ident($subfield_typ:ty $(, $canonical_subfield:ident)?) < $name:ident($typ:ty $(, $canonical:ident)?) $(< $extfield_name:ident($extfield_typ:ty $(, $canonical_ext:ident)?))+) => { - binary_tower!($subfield_name($subfield_typ $(, $canonical_subfield)?) < $name($typ $(, $canonical)?)); - binary_tower!($name($typ $(, $canonical)?) $(< $extfield_name($extfield_typ $(, $canonical_ext)?))+); + ([$($valid_tower_levels:tt)*]; $subfield_name:ident($subfield_typ:ty $(, $canonical_subfield:ident)?) < $name:ident($typ:ty $(, $canonical:ident)?) $(< $extfield_name:ident($extfield_typ:ty $(, $canonical_ext:ident)?))+) => { + binary_tower!([$($valid_tower_levels)*]; $subfield_name($subfield_typ $(, $canonical_subfield)?) < $name($typ $(, $canonical)?)); + binary_tower!([$($valid_tower_levels)*, $name::TOWER_LEVEL]; $name($typ $(, $canonical)?) $(< $extfield_name($extfield_typ $(, $canonical_ext)?))+); binary_tower!($subfield_name($subfield_typ) < @2 => $($extfield_name($extfield_typ))<+); }; ($subfield_name:ident($subfield_typ:ty) < @$log_degree:expr => $name:ident($typ:ty)) => { $crate::binary_field::impl_field_extension!($subfield_name($subfield_typ) < @$log_degree => $name($typ)); - $crate::binary_field::binary_tower_subfield_mul!($subfield_name, $name); }; ($subfield_name:ident($subfield_typ:ty) < @$log_degree:expr => $name:ident($typ:ty) $(< $extfield_name:ident($extfield_typ:ty))+) => { @@ -707,76 +735,36 @@ pub fn is_canonical_tower() -> bool { } macro_rules! serialize_deserialize { - ($bin_type:ty, SmallU<$U:literal>) => { - impl SerializeBytes for $bin_type { - fn serialize(&self, mut write_buf: impl BufMut) -> Result<(), SerializationError> { - if write_buf.remaining_mut() < 1 { - ::binius_utils::bail!(SerializationError::WriteBufferFull); - } - let b = self.0.val(); - write_buf.put_u8(b); - Ok(()) - } - } - - impl DeserializeBytes for $bin_type { - fn deserialize(mut read_buf: impl Buf) -> Result { - if read_buf.remaining() < 1 { - ::binius_utils::bail!(SerializationError::NotEnoughBytes); - } - let b: u8 = read_buf.get_u8(); - Ok(Self(SmallU::<$U>::new(b))) - } - } - }; - ($bin_type:ty, $inner_type:ty) => { + ($bin_type:ty) => { impl SerializeBytes for $bin_type { - fn serialize(&self, mut write_buf: impl BufMut) -> Result<(), SerializationError> { - if write_buf.remaining_mut() < (<$inner_type>::BITS / 8) as usize { - ::binius_utils::bail!(SerializationError::WriteBufferFull); - } - write_buf.put_slice(&self.0.to_le_bytes()); - Ok(()) + fn serialize( + &self, + write_buf: impl BufMut, + mode: SerializationMode, + ) -> Result<(), SerializationError> { + self.0.serialize(write_buf, mode) } } impl DeserializeBytes for $bin_type { - fn deserialize(mut read_buf: impl Buf) -> Result { - let mut inner = <$inner_type>::default().to_le_bytes(); - if read_buf.remaining() < inner.len() { - ::binius_utils::bail!(SerializationError::NotEnoughBytes); - } - read_buf.copy_to_slice(&mut inner); - Ok(Self(<$inner_type>::from_le_bytes(inner))) + fn deserialize( + read_buf: impl Buf, + mode: SerializationMode, + ) -> Result { + Ok(Self(DeserializeBytes::deserialize(read_buf, mode)?)) } } }; } -serialize_deserialize!(BinaryField1b, SmallU<1>); -serialize_deserialize!(BinaryField2b, SmallU<2>); -serialize_deserialize!(BinaryField4b, SmallU<4>); -serialize_deserialize!(BinaryField8b, u8); -serialize_deserialize!(BinaryField16b, u16); -serialize_deserialize!(BinaryField32b, u32); -serialize_deserialize!(BinaryField64b, u64); -serialize_deserialize!(BinaryField128b, u128); - -/// Serializes a [`TowerField`] element to a byte buffer with a canonical encoding. -pub fn serialize_canonical( - elem: F, - mut writer: W, -) -> Result<(), SerializationError> { - F::Canonical::from(elem).serialize(&mut writer) -} - -/// Deserializes a [`TowerField`] element from a byte buffer with a canonical encoding. -pub fn deserialize_canonical( - mut reader: R, -) -> Result { - let as_canonical = F::Canonical::deserialize(&mut reader)?; - Ok(F::from(as_canonical)) -} +serialize_deserialize!(BinaryField1b); +serialize_deserialize!(BinaryField2b); +serialize_deserialize!(BinaryField4b); +serialize_deserialize!(BinaryField8b); +serialize_deserialize!(BinaryField16b); +serialize_deserialize!(BinaryField32b); +serialize_deserialize!(BinaryField64b); +serialize_deserialize!(BinaryField128b); impl From for Choice { fn from(val: BinaryField1b) -> Self { @@ -867,7 +855,7 @@ impl From for u8 { #[cfg(test)] pub(crate) mod tests { - use bytes::BytesMut; + use binius_utils::{bytes::BytesMut, SerializationMode}; use proptest::prelude::*; use super::{ @@ -1236,6 +1224,7 @@ pub(crate) mod tests { #[test] fn test_serialization() { + let mode = SerializationMode::CanonicalTower; let mut buffer = BytesMut::new(); let b1 = BinaryField1b::from(0x1); let b8 = BinaryField8b::new(0x12); @@ -1246,25 +1235,25 @@ pub(crate) mod tests { let b64 = BinaryField64b::new(0x13579BDF02468ACE); let b128 = BinaryField128b::new(0x147AD0369CF258BE8899AABBCCDDEEFF); - b1.serialize(&mut buffer).unwrap(); - b8.serialize(&mut buffer).unwrap(); - b2.serialize(&mut buffer).unwrap(); - b16.serialize(&mut buffer).unwrap(); - b32.serialize(&mut buffer).unwrap(); - b4.serialize(&mut buffer).unwrap(); - b64.serialize(&mut buffer).unwrap(); - b128.serialize(&mut buffer).unwrap(); + b1.serialize(&mut buffer, mode).unwrap(); + b8.serialize(&mut buffer, mode).unwrap(); + b2.serialize(&mut buffer, mode).unwrap(); + b16.serialize(&mut buffer, mode).unwrap(); + b32.serialize(&mut buffer, mode).unwrap(); + b4.serialize(&mut buffer, mode).unwrap(); + b64.serialize(&mut buffer, mode).unwrap(); + b128.serialize(&mut buffer, mode).unwrap(); let mut read_buffer = buffer.freeze(); - assert_eq!(BinaryField1b::deserialize(&mut read_buffer).unwrap(), b1); - assert_eq!(BinaryField8b::deserialize(&mut read_buffer).unwrap(), b8); - assert_eq!(BinaryField2b::deserialize(&mut read_buffer).unwrap(), b2); - assert_eq!(BinaryField16b::deserialize(&mut read_buffer).unwrap(), b16); - assert_eq!(BinaryField32b::deserialize(&mut read_buffer).unwrap(), b32); - assert_eq!(BinaryField4b::deserialize(&mut read_buffer).unwrap(), b4); - assert_eq!(BinaryField64b::deserialize(&mut read_buffer).unwrap(), b64); - assert_eq!(BinaryField128b::deserialize(&mut read_buffer).unwrap(), b128); + assert_eq!(BinaryField1b::deserialize(&mut read_buffer, mode).unwrap(), b1); + assert_eq!(BinaryField8b::deserialize(&mut read_buffer, mode).unwrap(), b8); + assert_eq!(BinaryField2b::deserialize(&mut read_buffer, mode).unwrap(), b2); + assert_eq!(BinaryField16b::deserialize(&mut read_buffer, mode).unwrap(), b16); + assert_eq!(BinaryField32b::deserialize(&mut read_buffer, mode).unwrap(), b32); + assert_eq!(BinaryField4b::deserialize(&mut read_buffer, mode).unwrap(), b4); + assert_eq!(BinaryField64b::deserialize(&mut read_buffer, mode).unwrap(), b64); + assert_eq!(BinaryField128b::deserialize(&mut read_buffer, mode).unwrap(), b128); } #[test] diff --git a/crates/field/src/binary_field_arithmetic.rs b/crates/field/src/binary_field_arithmetic.rs index 0e913f506..481b0af17 100644 --- a/crates/field/src/binary_field_arithmetic.rs +++ b/crates/field/src/binary_field_arithmetic.rs @@ -61,6 +61,10 @@ pub(crate) use impl_arithmetic_using_packed; impl TowerField for BinaryField1b { type Canonical = Self; + fn min_tower_level(self) -> usize { + 0 + } + #[inline] fn mul_primitive(self, _: usize) -> Result { Err(crate::Error::ExtensionDegreeMismatch) diff --git a/crates/field/src/byte_iteration.rs b/crates/field/src/byte_iteration.rs new file mode 100644 index 000000000..e7acda990 --- /dev/null +++ b/crates/field/src/byte_iteration.rs @@ -0,0 +1,440 @@ +// Copyright 2023-2025 Irreducible Inc. + +use std::any::TypeId; + +use bytemuck::Pod; + +use crate::{ + packed::get_packed_slice, AESTowerField128b, AESTowerField16b, AESTowerField32b, + AESTowerField64b, AESTowerField8b, BinaryField128b, BinaryField128bPolyval, BinaryField16b, + BinaryField32b, BinaryField64b, BinaryField8b, ByteSlicedAES32x128b, ByteSlicedAES32x16b, + ByteSlicedAES32x32b, ByteSlicedAES32x64b, ByteSlicedAES32x8b, Field, + PackedAESBinaryField16x16b, PackedAESBinaryField16x32b, PackedAESBinaryField16x8b, + PackedAESBinaryField1x128b, PackedAESBinaryField1x16b, PackedAESBinaryField1x32b, + PackedAESBinaryField1x64b, PackedAESBinaryField1x8b, PackedAESBinaryField2x128b, + PackedAESBinaryField2x16b, PackedAESBinaryField2x32b, PackedAESBinaryField2x64b, + PackedAESBinaryField2x8b, PackedAESBinaryField32x16b, PackedAESBinaryField32x8b, + PackedAESBinaryField4x128b, PackedAESBinaryField4x16b, PackedAESBinaryField4x32b, + PackedAESBinaryField4x64b, PackedAESBinaryField4x8b, PackedAESBinaryField64x8b, + PackedAESBinaryField8x16b, PackedAESBinaryField8x64b, PackedAESBinaryField8x8b, + PackedBinaryField128x1b, PackedBinaryField128x2b, PackedBinaryField128x4b, + PackedBinaryField16x16b, PackedBinaryField16x1b, PackedBinaryField16x2b, + PackedBinaryField16x32b, PackedBinaryField16x4b, PackedBinaryField16x8b, + PackedBinaryField1x128b, PackedBinaryField1x16b, PackedBinaryField1x32b, + PackedBinaryField1x64b, PackedBinaryField1x8b, PackedBinaryField256x1b, + PackedBinaryField256x2b, PackedBinaryField2x128b, PackedBinaryField2x16b, + PackedBinaryField2x32b, PackedBinaryField2x4b, PackedBinaryField2x64b, PackedBinaryField2x8b, + PackedBinaryField32x16b, PackedBinaryField32x1b, PackedBinaryField32x2b, + PackedBinaryField32x4b, PackedBinaryField32x8b, PackedBinaryField4x128b, + PackedBinaryField4x16b, PackedBinaryField4x2b, PackedBinaryField4x32b, PackedBinaryField4x4b, + PackedBinaryField4x64b, PackedBinaryField4x8b, PackedBinaryField512x1b, PackedBinaryField64x1b, + PackedBinaryField64x2b, PackedBinaryField64x4b, PackedBinaryField64x8b, PackedBinaryField8x16b, + PackedBinaryField8x1b, PackedBinaryField8x2b, PackedBinaryField8x32b, PackedBinaryField8x4b, + PackedBinaryField8x64b, PackedBinaryField8x8b, PackedBinaryPolyval1x128b, + PackedBinaryPolyval2x128b, PackedBinaryPolyval4x128b, PackedField, +}; + +/// A marker trait that the slice of packed values can be iterated as a sequence of bytes. +/// The order of the iteration by BinaryField1b subfield elements and bits within iterated bytes must +/// be the same. +/// +/// # Safety +/// The implementor must ensure that the cast of the slice of packed values to the slice of bytes +/// is safe and preserves the order of the 1-bit elements. +#[allow(unused)] +unsafe trait SequentialBytes: Pod {} + +unsafe impl SequentialBytes for BinaryField8b {} +unsafe impl SequentialBytes for BinaryField16b {} +unsafe impl SequentialBytes for BinaryField32b {} +unsafe impl SequentialBytes for BinaryField64b {} +unsafe impl SequentialBytes for BinaryField128b {} + +unsafe impl SequentialBytes for PackedBinaryField8x1b {} +unsafe impl SequentialBytes for PackedBinaryField16x1b {} +unsafe impl SequentialBytes for PackedBinaryField32x1b {} +unsafe impl SequentialBytes for PackedBinaryField64x1b {} +unsafe impl SequentialBytes for PackedBinaryField128x1b {} +unsafe impl SequentialBytes for PackedBinaryField256x1b {} +unsafe impl SequentialBytes for PackedBinaryField512x1b {} + +unsafe impl SequentialBytes for PackedBinaryField4x2b {} +unsafe impl SequentialBytes for PackedBinaryField8x2b {} +unsafe impl SequentialBytes for PackedBinaryField16x2b {} +unsafe impl SequentialBytes for PackedBinaryField32x2b {} +unsafe impl SequentialBytes for PackedBinaryField64x2b {} +unsafe impl SequentialBytes for PackedBinaryField128x2b {} +unsafe impl SequentialBytes for PackedBinaryField256x2b {} + +unsafe impl SequentialBytes for PackedBinaryField2x4b {} +unsafe impl SequentialBytes for PackedBinaryField4x4b {} +unsafe impl SequentialBytes for PackedBinaryField8x4b {} +unsafe impl SequentialBytes for PackedBinaryField16x4b {} +unsafe impl SequentialBytes for PackedBinaryField32x4b {} +unsafe impl SequentialBytes for PackedBinaryField64x4b {} +unsafe impl SequentialBytes for PackedBinaryField128x4b {} + +unsafe impl SequentialBytes for PackedBinaryField1x8b {} +unsafe impl SequentialBytes for PackedBinaryField2x8b {} +unsafe impl SequentialBytes for PackedBinaryField4x8b {} +unsafe impl SequentialBytes for PackedBinaryField8x8b {} +unsafe impl SequentialBytes for PackedBinaryField16x8b {} +unsafe impl SequentialBytes for PackedBinaryField32x8b {} +unsafe impl SequentialBytes for PackedBinaryField64x8b {} + +unsafe impl SequentialBytes for PackedBinaryField1x16b {} +unsafe impl SequentialBytes for PackedBinaryField2x16b {} +unsafe impl SequentialBytes for PackedBinaryField4x16b {} +unsafe impl SequentialBytes for PackedBinaryField8x16b {} +unsafe impl SequentialBytes for PackedBinaryField16x16b {} +unsafe impl SequentialBytes for PackedBinaryField32x16b {} + +unsafe impl SequentialBytes for PackedBinaryField1x32b {} +unsafe impl SequentialBytes for PackedBinaryField2x32b {} +unsafe impl SequentialBytes for PackedBinaryField4x32b {} +unsafe impl SequentialBytes for PackedBinaryField8x32b {} +unsafe impl SequentialBytes for PackedBinaryField16x32b {} + +unsafe impl SequentialBytes for PackedBinaryField1x64b {} +unsafe impl SequentialBytes for PackedBinaryField2x64b {} +unsafe impl SequentialBytes for PackedBinaryField4x64b {} +unsafe impl SequentialBytes for PackedBinaryField8x64b {} + +unsafe impl SequentialBytes for PackedBinaryField1x128b {} +unsafe impl SequentialBytes for PackedBinaryField2x128b {} +unsafe impl SequentialBytes for PackedBinaryField4x128b {} + +unsafe impl SequentialBytes for AESTowerField8b {} +unsafe impl SequentialBytes for AESTowerField16b {} +unsafe impl SequentialBytes for AESTowerField32b {} +unsafe impl SequentialBytes for AESTowerField64b {} +unsafe impl SequentialBytes for AESTowerField128b {} + +unsafe impl SequentialBytes for PackedAESBinaryField1x8b {} +unsafe impl SequentialBytes for PackedAESBinaryField2x8b {} +unsafe impl SequentialBytes for PackedAESBinaryField4x8b {} +unsafe impl SequentialBytes for PackedAESBinaryField8x8b {} +unsafe impl SequentialBytes for PackedAESBinaryField16x8b {} +unsafe impl SequentialBytes for PackedAESBinaryField32x8b {} +unsafe impl SequentialBytes for PackedAESBinaryField64x8b {} + +unsafe impl SequentialBytes for PackedAESBinaryField1x16b {} +unsafe impl SequentialBytes for PackedAESBinaryField2x16b {} +unsafe impl SequentialBytes for PackedAESBinaryField4x16b {} +unsafe impl SequentialBytes for PackedAESBinaryField8x16b {} +unsafe impl SequentialBytes for PackedAESBinaryField16x16b {} +unsafe impl SequentialBytes for PackedAESBinaryField32x16b {} + +unsafe impl SequentialBytes for PackedAESBinaryField1x32b {} +unsafe impl SequentialBytes for PackedAESBinaryField2x32b {} +unsafe impl SequentialBytes for PackedAESBinaryField4x32b {} +unsafe impl SequentialBytes for PackedAESBinaryField16x32b {} + +unsafe impl SequentialBytes for PackedAESBinaryField1x64b {} +unsafe impl SequentialBytes for PackedAESBinaryField2x64b {} +unsafe impl SequentialBytes for PackedAESBinaryField4x64b {} +unsafe impl SequentialBytes for PackedAESBinaryField8x64b {} + +unsafe impl SequentialBytes for PackedAESBinaryField1x128b {} +unsafe impl SequentialBytes for PackedAESBinaryField2x128b {} +unsafe impl SequentialBytes for PackedAESBinaryField4x128b {} + +unsafe impl SequentialBytes for BinaryField128bPolyval {} + +unsafe impl SequentialBytes for PackedBinaryPolyval1x128b {} +unsafe impl SequentialBytes for PackedBinaryPolyval2x128b {} +unsafe impl SequentialBytes for PackedBinaryPolyval4x128b {} + +/// Returns true if T implements `SequentialBytes` trait. +/// Use a hack that exploits that array copying is optimized for the `Copy` types. +/// Unfortunately there is no more proper way to perform this check this in Rust at runtime. +#[inline(always)] +#[allow(clippy::redundant_clone)] // this is intentional in this method +pub fn is_sequential_bytes() -> bool { + struct X(bool, std::marker::PhantomData); + + impl Clone for X { + fn clone(&self) -> Self { + Self(false, std::marker::PhantomData) + } + } + + impl Copy for X {} + + let value = [X::(true, std::marker::PhantomData)]; + let cloned = value.clone(); + + cloned[0].0 +} + +/// Returns if we can iterate over bytes, each representing 8 1-bit values. +pub fn can_iterate_bytes() -> bool { + // Packed fields with sequential byte order + if is_sequential_bytes::

() { + return true; + } + + // Byte-sliced fields + // Note: add more byte sliced types here as soon as they are added + match TypeId::of::

() { + x if x == TypeId::of::() => true, + x if x == TypeId::of::() => true, + x if x == TypeId::of::() => true, + x if x == TypeId::of::() => true, + x if x == TypeId::of::() => true, + _ => false, + } +} + +/// Helper macro to generate the iteration over bytes for byte-sliced types. +macro_rules! iterate_byte_sliced { + ($packed_type:ty, $data:ident, $callback:ident) => { + assert_eq!(TypeId::of::<$packed_type>(), TypeId::of::

()); + + // Safety: the cast is safe because the type is checked by arm statement + let data = unsafe { + std::slice::from_raw_parts($data.as_ptr() as *const $packed_type, $data.len()) + }; + let iter = data.iter().flat_map(|value| { + (0..<$packed_type>::BYTES).map(move |i| unsafe { value.get_byte_unchecked(i) }) + }); + + $callback.call(iter); + }; +} + +/// Callback for byte iteration. +/// We can't return different types from the `iterate_bytes` and Fn traits don't support associated types +/// that's why we use a callback with a generic function. +pub trait ByteIteratorCallback { + fn call(&mut self, iter: impl Iterator); +} + +/// Iterate over bytes of a slice of the packed values. +/// The method panics if the packed field doesn't support byte iteration, so use `can_iterate_bytes` to check it. +#[inline(always)] +pub fn iterate_bytes(data: &[P], callback: &mut impl ByteIteratorCallback) { + if is_sequential_bytes::

() { + // Safety: `P` implements `SequentialBytes` trait, so the following cast is safe + // and preserves the order. + let bytes = unsafe { + std::slice::from_raw_parts(data.as_ptr() as *const u8, std::mem::size_of_val(data)) + }; + callback.call(bytes.iter().copied()); + } else { + // Note: add more byte sliced types here as soon as they are added + match TypeId::of::

() { + x if x == TypeId::of::() => { + iterate_byte_sliced!(ByteSlicedAES32x128b, data, callback); + } + x if x == TypeId::of::() => { + iterate_byte_sliced!(ByteSlicedAES32x64b, data, callback); + } + x if x == TypeId::of::() => { + iterate_byte_sliced!(ByteSlicedAES32x32b, data, callback); + } + x if x == TypeId::of::() => { + iterate_byte_sliced!(ByteSlicedAES32x16b, data, callback); + } + x if x == TypeId::of::() => { + iterate_byte_sliced!(ByteSlicedAES32x8b, data, callback); + } + _ => unreachable!("packed field doesn't support byte iteration"), + } + } +} + +/// Scalars collection abstraction. +/// This trait is used to abstract over different types of collections of field elements. +pub trait ScalarsCollection { + fn len(&self) -> usize; + fn get(&self, i: usize) -> T; + fn is_empty(&self) -> bool { + self.len() == 0 + } +} + +impl ScalarsCollection for &[F] { + #[inline(always)] + fn len(&self) -> usize { + <[F]>::len(self) + } + + #[inline(always)] + fn get(&self, i: usize) -> F { + self[i] + } +} + +pub struct PackedSlice<'a, P: PackedField> { + slice: &'a [P], + len: usize, +} + +impl<'a, P: PackedField> PackedSlice<'a, P> { + #[inline(always)] + pub const fn new(slice: &'a [P], len: usize) -> Self { + Self { slice, len } + } +} + +impl ScalarsCollection for PackedSlice<'_, P> { + #[inline(always)] + fn len(&self) -> usize { + self.len + } + + #[inline(always)] + fn get(&self, i: usize) -> P::Scalar { + get_packed_slice(self.slice, i) + } +} + +/// Create a lookup table for partial sums of 8 consequent elements with coefficients corresponding to bits in a byte. +/// The lookup table has the following structure: +/// [ +/// partial_sum_chunk_0_7_byte_0, partial_sum_chunk_0_7_byte_1, ..., partial_sum_chunk_0_7_byte_255, +/// partial_sum_chunk_8_15_byte_0, partial_sum_chunk_8_15_byte_1, ..., partial_sum_chunk_8_15_byte_255, +/// ... +/// ] +pub fn create_partial_sums_lookup_tables( + values: impl ScalarsCollection

, +) -> Vec

{ + let len = values.len(); + assert!(len % 8 == 0); + + let mut result = Vec::with_capacity(len * 256 / 8); + for chunk_i in 0..len / 8 { + let offset = chunk_i * 8; + for i in 0..256 { + let mut sum = P::zero(); + for j in 0..8 { + if i & (1 << j) != 0 { + sum += values.get(offset + j); + } + } + result.push(sum); + } + } + + result +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{PackedBinaryField1x1b, PackedBinaryField2x1b, PackedBinaryField4x1b}; + + #[test] + fn test_sequential_bits() { + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + + assert!(is_sequential_bytes::()); + + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + + assert!(!is_sequential_bytes::()); + assert!(!is_sequential_bytes::()); + assert!(!is_sequential_bytes::()); + assert!(!is_sequential_bytes::()); + assert!(!is_sequential_bytes::()); + assert!(!is_sequential_bytes::()); + assert!(!is_sequential_bytes::()); + assert!(!is_sequential_bytes::()); + } +} diff --git a/crates/field/src/extension.rs b/crates/field/src/extension.rs index 45ca2e368..84a660aaa 100644 --- a/crates/field/src/extension.rs +++ b/crates/field/src/extension.rs @@ -18,9 +18,6 @@ pub trait ExtensionField: + SubAssign + MulAssign { - /// Iterator returned by `iter_bases`. - type Iterator: Iterator; - /// Base-2 logarithm of the extension degree. const LOG_DEGREE: usize; @@ -47,12 +44,13 @@ pub trait ExtensionField: fn from_bases_sparse(base_elems: &[F], log_stride: usize) -> Result; /// Iterator over base field elements. - fn iter_bases(&self) -> Self::Iterator; + fn iter_bases(&self) -> impl Iterator; + + /// Convert into an iterator over base field elements. + fn into_iter_bases(self) -> impl Iterator; } impl ExtensionField for F { - type Iterator = iter::Once; - const LOG_DEGREE: usize = 0; fn basis(i: usize) -> Result { @@ -74,7 +72,11 @@ impl ExtensionField for F { } } - fn iter_bases(&self) -> Self::Iterator { + fn iter_bases(&self) -> impl Iterator { iter::once(*self) } + + fn into_iter_bases(self) -> impl Iterator { + iter::once(self) + } } diff --git a/crates/field/src/field.rs b/crates/field/src/field.rs index 10395e9c2..4a5469516 100644 --- a/crates/field/src/field.rs +++ b/crates/field/src/field.rs @@ -7,6 +7,7 @@ use std::{ ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}, }; +use binius_utils::{DeserializeBytes, SerializeBytes}; use rand::RngCore; use crate::{ @@ -49,6 +50,8 @@ pub trait Field: + InvertOrZero // `Underlier: PackScalar` is an obvious property but it can't be deduced by the compiler so we are id here. + WithUnderlier> + + SerializeBytes + + DeserializeBytes { /// The zero element of the field, the additive identity. const ZERO: Self; diff --git a/crates/field/src/lib.rs b/crates/field/src/lib.rs index cae1070da..76414c038 100644 --- a/crates/field/src/lib.rs +++ b/crates/field/src/lib.rs @@ -20,6 +20,7 @@ pub mod arithmetic_traits; pub mod as_packed_field; pub mod binary_field; mod binary_field_arithmetic; +pub mod byte_iteration; pub mod error; pub mod extension; pub mod field; diff --git a/crates/field/src/packed.rs b/crates/field/src/packed.rs index 7c0b9ad8f..9139b67f0 100644 --- a/crates/field/src/packed.rs +++ b/crates/field/src/packed.rs @@ -20,8 +20,7 @@ use super::{ Error, }; use crate::{ - arithmetic_traits::InvertOrZero, underlier::WithUnderlier, BinaryField, ExtensionField, Field, - PackedExtension, + arithmetic_traits::InvertOrZero, underlier::WithUnderlier, BinaryField, Field, PackedExtension, }; /// A packed field represents a vector of underlying field elements. @@ -216,6 +215,20 @@ pub trait PackedField: /// * `log_block_len` must be strictly less than `LOG_WIDTH`. fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self); + /// Unzips interleaved blocks of this packed vector with another packed vector. + /// + /// Consider this example, where `LOG_WIDTH` is 3 and `log_block_len` is 1: + /// A = [a0, a1, b0, b1, a2, a3, b2, b3] + /// B = [a4, a5, b4, b5, a6, a7, b6, b7] + /// + /// The transposed result is + /// A' = [a0, a1, a2, a3, a4, a5, a6, a7] + /// B' = [b0, b1, b2, b3, b4, b5, b6, b7] + /// + /// ## Preconditions + /// * `log_block_len` must be strictly less than `LOG_WIDTH`. + fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self); + /// Spread takes a block of elements within a packed field and repeats them to the full packing /// width. /// @@ -357,11 +370,7 @@ pub const fn len_packed_slice(packed: &[P]) -> usize { } /// Multiply packed field element by a subfield scalar. -pub fn mul_by_subfield_scalar(val: P, multiplier: FS) -> P -where - P: PackedExtension>, - FS: Field, -{ +pub fn mul_by_subfield_scalar, FS: Field>(val: P, multiplier: FS) -> P { use crate::underlier::UnderlierType; // This is a workaround not to make the multiplication slower in certain cases. @@ -376,6 +385,14 @@ where } } +pub fn pack_slice(scalars: &[P::Scalar]) -> Vec

{ + let mut packed_slice = vec![P::default(); scalars.len() / P::WIDTH]; + for (i, scalar) in scalars.iter().enumerate() { + set_packed_slice(&mut packed_slice, i, *scalar); + } + packed_slice +} + impl Broadcast for F { fn broadcast(scalar: F) -> Self { scalar @@ -427,6 +444,10 @@ impl PackedField for F { panic!("cannot interleave when WIDTH = 1"); } + fn unzip(self, _other: Self, _log_block_len: usize) -> (Self, Self) { + panic!("cannot transpose when WIDTH = 1"); + } + fn broadcast(scalar: Self::Scalar) -> Self { scalar } @@ -495,7 +516,7 @@ mod tests { } /// Run the test for all the packed fields defined in this crate. - fn run_for_all_packed_fields(test: impl PackedFieldTest) { + fn run_for_all_packed_fields(test: &impl PackedFieldTest) { // canonical tower test.run::(); @@ -665,6 +686,6 @@ mod tests { #[test] fn test_iteration() { - run_for_all_packed_fields(PackedFieldIterationTest); + run_for_all_packed_fields(&PackedFieldIterationTest); } } diff --git a/crates/field/src/packed_binary_field.rs b/crates/field/src/packed_binary_field.rs index b22f48fef..152cbf6bd 100644 --- a/crates/field/src/packed_binary_field.rs +++ b/crates/field/src/packed_binary_field.rs @@ -807,6 +807,63 @@ pub mod test_utils { check_interleave::

(lhs, rhs, log_block_len); } } + + pub fn check_unzip( + lhs: P::Underlier, + rhs: P::Underlier, + log_block_len: usize, + ) { + let lhs = P::from_underlier(lhs); + let rhs = P::from_underlier(rhs); + let block_len = 1 << log_block_len; + let (a, b) = lhs.unzip(rhs, log_block_len); + for i in (0..P::WIDTH / 2).step_by(block_len) { + for j in 0..block_len { + assert_eq!( + a.get(i + j), + lhs.get(2 * i + j), + "i: {}, j: {}, log_block_len: {}, P: {:?}", + i, + j, + log_block_len, + P::zero() + ); + assert_eq!( + b.get(i + j), + lhs.get(2 * i + j + block_len), + "i: {}, j: {}, log_block_len: {}, P: {:?}", + i, + j, + log_block_len, + P::zero() + ); + } + } + + for i in (0..P::WIDTH / 2).step_by(block_len) { + for j in 0..block_len { + assert_eq!( + a.get(i + j + P::WIDTH / 2), + rhs.get(2 * i + j), + "i: {}, j: {}, log_block_len: {}, P: {:?}", + i, + j, + log_block_len, + P::zero() + ); + assert_eq!(b.get(i + j + P::WIDTH / 2), rhs.get(2 * i + j + block_len)); + } + } + } + + pub fn check_transpose_all_heights( + lhs: P::Underlier, + rhs: P::Underlier, + ) { + for log_block_len in 0..P::LOG_WIDTH { + check_unzip::

(lhs, rhs, log_block_len); + } + } } #[cfg(test)] @@ -831,6 +888,7 @@ mod tests { }, arithmetic_traits::MulAlpha, linear_transformation::PackedTransformationFactory, + test_utils::check_transpose_all_heights, underlier::{U2, U4}, Field, PackedField, PackedFieldIndexable, }; @@ -1206,5 +1264,93 @@ mod tests { check_interleave_all_heights::(a_val.into(), b_val.into()); check_interleave_all_heights::(a_val.into(), b_val.into()); } + + #[test] + fn check_transpose_2b(a_val in 0u8..3, b_val in 0u8..3) { + check_transpose_all_heights::(U2::new(a_val), U2::new(b_val)); + check_transpose_all_heights::(U2::new(a_val), U2::new(b_val)); + } + + #[test] + fn check_transpose_4b(a_val in 0u8..16, b_val in 0u8..16) { + check_transpose_all_heights::(U4::new(a_val), U4::new(b_val)); + check_transpose_all_heights::(U4::new(a_val), U4::new(b_val)); + check_transpose_all_heights::(U4::new(a_val), U4::new(b_val)); + } + + #[test] + fn check_transpose_8b(a_val in 0u8.., b_val in 0u8..) { + check_transpose_all_heights::(a_val, b_val); + check_transpose_all_heights::(a_val, b_val); + check_transpose_all_heights::(a_val, b_val); + check_transpose_all_heights::(a_val, b_val); + } + + #[test] + fn check_transpose_16b(a_val in 0u16.., b_val in 0u16..) { + check_transpose_all_heights::(a_val, b_val); + check_transpose_all_heights::(a_val, b_val); + check_transpose_all_heights::(a_val, b_val); + check_transpose_all_heights::(a_val, b_val); + check_transpose_all_heights::(a_val, b_val); + } + + #[test] + fn check_transpose_32b(a_val in 0u32.., b_val in 0u32..) { + check_transpose_all_heights::(a_val, b_val); + check_transpose_all_heights::(a_val, b_val); + check_transpose_all_heights::(a_val, b_val); + check_transpose_all_heights::(a_val, b_val); + check_transpose_all_heights::(a_val, b_val); + check_transpose_all_heights::(a_val, b_val); + } + + #[test] + fn check_transpose_64b(a_val in 0u64.., b_val in 0u64..) { + check_transpose_all_heights::(a_val, b_val); + check_transpose_all_heights::(a_val, b_val); + check_transpose_all_heights::(a_val, b_val); + check_transpose_all_heights::(a_val, b_val); + check_transpose_all_heights::(a_val, b_val); + check_transpose_all_heights::(a_val, b_val); + check_transpose_all_heights::(a_val, b_val); + } + + #[test] + #[allow(clippy::useless_conversion)] // this warning depends on the target platform + fn check_transpose_128b(a_val in 0u128.., b_val in 0u128..) { + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + } + + #[test] + fn check_transpose_256b(a_val in any::<[u128; 2]>(), b_val in any::<[u128; 2]>()) { + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + } + + #[test] + fn check_transpose_512b(a_val in any::<[u128; 4]>(), b_val in any::<[u128; 4]>()) { + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + } } } diff --git a/crates/field/src/packed_extension.rs b/crates/field/src/packed_extension.rs index 3a2b11a2a..8f2aaf028 100644 --- a/crates/field/src/packed_extension.rs +++ b/crates/field/src/packed_extension.rs @@ -54,12 +54,12 @@ where /// PE: PackedField>, /// F: Field, /// { -/// packed.iter().flat_map(|ext| ext.iter_bases()) +/// packed.iter().flat_map(|ext| ext.into_iter_bases()) /// } /// /// fn cast_then_iter<'a, F, PE>(packed: &'a PE) -> impl Iterator + 'a /// where -/// PE: PackedExtension>, +/// PE: PackedExtension, /// F: Field, /// { /// PE::cast_base_ref(packed).into_iter() @@ -71,10 +71,7 @@ where /// In order for the above relation to be guaranteed, the memory representation of /// `PackedExtensionField` element must be the same as a slice of the underlying `PackedField` /// element. -pub trait PackedExtension: PackedField -where - Self::Scalar: ExtensionField, -{ +pub trait PackedExtension: PackedField> { type PackedSubfield: PackedField; fn cast_bases(packed: &[Self]) -> &[Self::PackedSubfield]; @@ -187,9 +184,7 @@ where /// This trait is a shorthand for the case `PackedExtension` which is a /// quite common case in our codebase. pub trait RepackedExtension: - PackedExtension -where - Self::Scalar: ExtensionField, + PackedField> + PackedExtension { } @@ -202,10 +197,8 @@ where /// This trait adds shortcut methods for the case `PackedExtension` which is a /// quite common case in our codebase. -pub trait PackedExtensionIndexable: PackedExtension -where - Self::Scalar: ExtensionField, - Self::PackedSubfield: PackedFieldIndexable, +pub trait PackedExtensionIndexable: + PackedExtension + PackedField> { fn unpack_base_scalars(packed: &[Self]) -> &[F] { Self::PackedSubfield::unpack_scalars(Self::cast_bases(packed)) @@ -219,7 +212,7 @@ where impl PackedExtensionIndexable for PT where F: Field, - PT: PackedExtension, PackedSubfield: PackedFieldIndexable>, + PT: PackedExtension, { } diff --git a/crates/field/src/packed_extension_ops.rs b/crates/field/src/packed_extension_ops.rs index 6d2ede809..30fad681f 100644 --- a/crates/field/src/packed_extension_ops.rs +++ b/crates/field/src/packed_extension_ops.rs @@ -6,21 +6,17 @@ use binius_maybe_rayon::prelude::{ use crate::{Error, ExtensionField, Field, PackedExtension, PackedField}; -pub fn ext_base_mul(lhs: &mut [PE], rhs: &[PE::PackedSubfield]) -> Result<(), Error> -where - PE: PackedExtension, - PE::Scalar: ExtensionField, - F: Field, -{ +pub fn ext_base_mul, F: Field>( + lhs: &mut [PE], + rhs: &[PE::PackedSubfield], +) -> Result<(), Error> { ext_base_op(lhs, rhs, |_, lhs, broadcasted_rhs| PE::cast_ext(lhs.cast_base() * broadcasted_rhs)) } -pub fn ext_base_mul_par(lhs: &mut [PE], rhs: &[PE::PackedSubfield]) -> Result<(), Error> -where - PE: PackedExtension, - PE::Scalar: ExtensionField, - F: Field, -{ +pub fn ext_base_mul_par, F: Field>( + lhs: &mut [PE], + rhs: &[PE::PackedSubfield], +) -> Result<(), Error> { ext_base_op_par(lhs, rhs, |_, lhs, broadcasted_rhs| { PE::cast_ext(lhs.cast_base() * broadcasted_rhs) }) @@ -29,15 +25,10 @@ where /// # Safety /// /// Width of PackedSubfield is >= the width of the field implementing PackedExtension. -pub unsafe fn get_packed_subfields_at_pe_idx( +pub unsafe fn get_packed_subfields_at_pe_idx, F: Field>( packed_subfields: &[PE::PackedSubfield], i: usize, -) -> PE::PackedSubfield -where - PE: PackedExtension, - PE::Scalar: ExtensionField, - F: Field, -{ +) -> PE::PackedSubfield { let bottom_most_scalar_idx = i * PE::WIDTH; let bottom_most_scalar_idx_in_subfield_arr = bottom_most_scalar_idx / PE::PackedSubfield::WIDTH; let bottom_most_scalar_idx_within_packed_subfield = @@ -67,7 +58,6 @@ pub fn ext_base_op( ) -> Result<(), Error> where PE: PackedExtension, - PE::Scalar: ExtensionField, F: Field, Func: Fn(usize, PE, PE::PackedSubfield) -> PE, { @@ -93,7 +83,6 @@ pub fn ext_base_op_par( ) -> Result<(), Error> where PE: PackedExtension, - PE::Scalar: ExtensionField, F: Field, Func: Fn(usize, PE, PE::PackedSubfield) -> PE + std::marker::Sync, { @@ -117,10 +106,10 @@ mod tests { use crate::{ ext_base_mul, ext_base_mul_par, - packed::{get_packed_slice, set_packed_slice}, + packed::{get_packed_slice, pack_slice}, underlier::WithUnderlier, BinaryField128b, BinaryField16b, BinaryField8b, PackedBinaryField16x16b, - PackedBinaryField2x128b, PackedBinaryField32x8b, PackedField, + PackedBinaryField2x128b, PackedBinaryField32x8b, }; fn strategy_8b_scalars() -> impl Strategy { @@ -138,16 +127,6 @@ mod tests { .prop_map(|arr| arr.map(::from_underlier)) } - fn pack_slice(scalar_slice: &[P::Scalar]) -> Vec

{ - let mut packed_slice = vec![P::default(); scalar_slice.len() / P::WIDTH]; - - for (i, scalar) in scalar_slice.iter().enumerate() { - set_packed_slice(&mut packed_slice, i, *scalar); - } - - packed_slice - } - proptest! { #[test] fn test_base_ext_mul_8(base_scalars in strategy_8b_scalars(), ext_scalars in strategy_128b_scalars()){ diff --git a/crates/field/src/polyval.rs b/crates/field/src/polyval.rs index 989e12eec..cd3d944a6 100644 --- a/crates/field/src/polyval.rs +++ b/crates/field/src/polyval.rs @@ -4,12 +4,16 @@ use std::{ any::TypeId, - array, fmt::{self, Debug, Display, Formatter}, iter::{Product, Sum}, ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}, }; +use binius_utils::{ + bytes::{Buf, BufMut}, + iter::IterExtensions, + DeserializeBytes, SerializationError, SerializationMode, SerializeBytes, +}; use bytemuck::{Pod, TransparentWrapper, Zeroable}; use rand::{Rng, RngCore}; use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption}; @@ -29,7 +33,7 @@ use crate::{ invert_or_zero_using_packed, multiple_using_packed, square_using_packed, }, linear_transformation::{FieldLinearTransformation, Transformation}, - underlier::UnderlierWithBitOps, + underlier::{IterationMethods, IterationStrategy, UnderlierWithBitOps, U1}, Field, }; @@ -414,7 +418,6 @@ impl Mul for BinaryField1b { } impl ExtensionField for BinaryField128bPolyval { - type Iterator = <[BinaryField1b; 128] as IntoIterator>::IntoIter; const LOG_DEGREE: usize = 7; #[inline] @@ -439,9 +442,44 @@ impl ExtensionField for BinaryField128bPolyval { } #[inline] - fn iter_bases(&self) -> Self::Iterator { - let base_elems = array::from_fn(|i| BinaryField1b::from((self.0 >> i) as u8)); - base_elems.into_iter() + fn iter_bases(&self) -> impl Iterator { + IterationMethods::::value_iter(self.0) + .map_skippable(BinaryField1b::from) + } + + #[inline] + fn into_iter_bases(self) -> impl Iterator { + IterationMethods::::value_iter(self.0) + .map_skippable(BinaryField1b::from) + } +} + +impl SerializeBytes for BinaryField128bPolyval { + fn serialize( + &self, + write_buf: impl BufMut, + mode: SerializationMode, + ) -> Result<(), SerializationError> { + match mode { + SerializationMode::Native => self.0.serialize(write_buf, mode), + SerializationMode::CanonicalTower => { + BinaryField128b::from(*self).serialize(write_buf, mode) + } + } + } +} + +impl DeserializeBytes for BinaryField128bPolyval { + fn deserialize(read_buf: impl Buf, mode: SerializationMode) -> Result + where + Self: Sized, + { + match mode { + SerializationMode::Native => Ok(Self(DeserializeBytes::deserialize(read_buf, mode)?)), + SerializationMode::CanonicalTower => { + Ok(Self::from(BinaryField128b::deserialize(read_buf, mode)?)) + } + } } } @@ -452,6 +490,13 @@ impl BinaryField for BinaryField128bPolyval { impl TowerField for BinaryField128bPolyval { type Canonical = BinaryField128b; + fn min_tower_level(self) -> usize { + match self { + Self::ZERO | Self::ONE => 0, + _ => 7, + } + } + fn mul_primitive(self, _iota: usize) -> Result { // This method could be implemented by multiplying by isomorphic alpha value // But it's not being used as for now @@ -1019,7 +1064,7 @@ pub fn is_polyval_tower() -> bool { #[cfg(test)] mod tests { - use bytes::BytesMut; + use binius_utils::{bytes::BytesMut, SerializationMode, SerializeBytes}; use proptest::prelude::*; use rand::thread_rng; @@ -1030,11 +1075,10 @@ mod tests { packed_polyval_512::PackedBinaryPolyval4x128b, }, binary_field::tests::is_binary_field_valid_generator, - deserialize_canonical, linear_transformation::PackedTransformationFactory, - serialize_canonical, AESTowerField128b, PackedAESBinaryField1x128b, - PackedAESBinaryField2x128b, PackedAESBinaryField4x128b, PackedBinaryField1x128b, - PackedBinaryField2x128b, PackedBinaryField4x128b, PackedField, + AESTowerField128b, PackedAESBinaryField1x128b, PackedAESBinaryField2x128b, + PackedAESBinaryField4x128b, PackedBinaryField1x128b, PackedBinaryField2x128b, + PackedBinaryField4x128b, PackedField, }; #[test] @@ -1177,25 +1221,25 @@ mod tests { #[test] fn test_canonical_serialization() { + let mode = SerializationMode::CanonicalTower; let mut buffer = BytesMut::new(); let mut rng = thread_rng(); let b128_poly1 = ::random(&mut rng); let b128_poly2 = ::random(&mut rng); - serialize_canonical(b128_poly1, &mut buffer).unwrap(); - serialize_canonical(b128_poly2, &mut buffer).unwrap(); + SerializeBytes::serialize(&b128_poly1, &mut buffer, mode).unwrap(); + SerializeBytes::serialize(&b128_poly2, &mut buffer, mode).unwrap(); + let mode = SerializationMode::CanonicalTower; let mut read_buffer = buffer.freeze(); assert_eq!( - deserialize_canonical::(&mut read_buffer).unwrap(), + BinaryField128bPolyval::deserialize(&mut read_buffer, mode).unwrap(), b128_poly1 ); assert_eq!( - BinaryField128bPolyval::from( - deserialize_canonical::(&mut read_buffer).unwrap() - ), + BinaryField128bPolyval::deserialize(&mut read_buffer, mode).unwrap(), b128_poly2 ); } diff --git a/crates/field/src/tower_levels.rs b/crates/field/src/tower_levels.rs index bdac60d06..bfddced06 100644 --- a/crates/field/src/tower_levels.rs +++ b/crates/field/src/tower_levels.rs @@ -16,110 +16,104 @@ use std::{ /// These separate implementations are necessary to overcome the limitations of const generics in Rust. /// These implementations eliminate costly bounds checking that would otherwise be imposed by the compiler /// and allow easy inlining of recursive functions. -pub trait TowerLevel -where - T: Default + Copy, -{ +pub trait TowerLevel { // WIDTH is ALWAYS a power of 2 const WIDTH: usize; // The underlying Data should ALWAYS be a fixed-width array of T's - type Data: AsMut<[T]> + type Data: AsMut<[T]> + AsRef<[T]> + Sized + Index + IndexMut; - type Base: TowerLevel; + type Base: TowerLevel; - // Split something of type Self::Data into two equal halves + // Split something of type Self::Datainto two equal halves #[allow(clippy::type_complexity)] - fn split( - data: &Self::Data, - ) -> (&>::Data, &>::Data); + fn split( + data: &Self::Data, + ) -> (&::Data, &::Data); - // Split something of type Self::Data into two equal mutable halves + // Split something of type Self::Datainto two equal mutable halves #[allow(clippy::type_complexity)] - fn split_mut( - data: &mut Self::Data, - ) -> (&mut >::Data, &mut >::Data); + fn split_mut( + data: &mut Self::Data, + ) -> (&mut ::Data, &mut ::Data); // Join two equal-length arrays (the reverse of split) #[allow(clippy::type_complexity)] - fn join( - first: &>::Data, - second: &>::Data, - ) -> Self::Data; + fn join( + first: &::Data, + second: &::Data, + ) -> Self::Data; // Fills an array of T's containing WIDTH elements - fn from_fn(f: impl Fn(usize) -> T) -> Self::Data; + fn from_fn(f: impl FnMut(usize) -> T) -> Self::Data; // Fills an array of T's containing WIDTH elements with T::default() - fn default() -> Self::Data { + fn default() -> Self::Data { Self::from_fn(|_| T::default()) } } -pub trait TowerLevelWithArithOps: TowerLevel -where - T: Default + Add + AddAssign + Copy, -{ +pub trait TowerLevelWithArithOps: TowerLevel { #[inline(always)] - fn add_into(field_element: &Self::Data, destination: &mut Self::Data) { + fn add_into( + field_element: &Self::Data, + destination: &mut Self::Data, + ) { for i in 0..Self::WIDTH { destination[i] += field_element[i]; } } #[inline(always)] - fn copy_into(field_element: &Self::Data, destination: &mut Self::Data) { + fn copy_into(field_element: &Self::Data, destination: &mut Self::Data) { for i in 0..Self::WIDTH { destination[i] = field_element[i]; } } #[inline(always)] - fn sum(field_element_a: &Self::Data, field_element_b: &Self::Data) -> Self::Data { + fn sum>( + field_element_a: &Self::Data, + field_element_b: &Self::Data, + ) -> Self::Data { Self::from_fn(|i| field_element_a[i] + field_element_b[i]) } } -impl> TowerLevelWithArithOps for U where - T: Default + Add + AddAssign + Copy -{ -} +impl TowerLevelWithArithOps for T {} pub struct TowerLevel64; -impl TowerLevel for TowerLevel64 -where - T: Default + Copy, -{ +impl TowerLevel for TowerLevel64 { const WIDTH: usize = 64; - type Data = [T; 64]; + type Data = [T; 64]; type Base = TowerLevel32; #[inline(always)] - fn split( - data: &Self::Data, - ) -> (&>::Data, &>::Data) { + fn split( + data: &Self::Data, + ) -> (&::Data, &::Data) { ((data[0..32].try_into().unwrap()), (data[32..64].try_into().unwrap())) } #[inline(always)] - fn split_mut( - data: &mut Self::Data, - ) -> (&mut >::Data, &mut >::Data) { + fn split_mut( + data: &mut Self::Data, + ) -> (&mut ::Data, &mut ::Data) { let (chunk_1, chunk_2) = data.split_at_mut(32); ((chunk_1.try_into().unwrap()), (chunk_2.try_into().unwrap())) } #[inline(always)] - fn join<'a>( - left: &>::Data, - right: &>::Data, - ) -> Self::Data { + fn join( + left: &::Data, + right: &::Data, + ) -> Self::Data { let mut result = [T::default(); 64]; result[..32].copy_from_slice(left); result[32..].copy_from_slice(right); @@ -127,43 +121,40 @@ where } #[inline(always)] - fn from_fn(f: impl Fn(usize) -> T) -> Self::Data { + fn from_fn(f: impl FnMut(usize) -> T) -> Self::Data { array::from_fn(f) } } pub struct TowerLevel32; -impl TowerLevel for TowerLevel32 -where - T: Default + Copy, -{ +impl TowerLevel for TowerLevel32 { const WIDTH: usize = 32; - type Data = [T; 32]; + type Data = [T; 32]; type Base = TowerLevel16; #[inline(always)] - fn split( - data: &Self::Data, - ) -> (&>::Data, &>::Data) { + fn split( + data: &Self::Data, + ) -> (&::Data, &::Data) { ((data[0..16].try_into().unwrap()), (data[16..32].try_into().unwrap())) } #[inline(always)] - fn split_mut( - data: &mut Self::Data, - ) -> (&mut >::Data, &mut >::Data) { + fn split_mut( + data: &mut Self::Data, + ) -> (&mut ::Data, &mut ::Data) { let (chunk_1, chunk_2) = data.split_at_mut(16); ((chunk_1.try_into().unwrap()), (chunk_2.try_into().unwrap())) } #[inline(always)] - fn join<'a>( - left: &>::Data, - right: &>::Data, - ) -> Self::Data { + fn join( + left: &::Data, + right: &::Data, + ) -> Self::Data { let mut result = [T::default(); 32]; result[..16].copy_from_slice(left); result[16..].copy_from_slice(right); @@ -171,43 +162,40 @@ where } #[inline(always)] - fn from_fn(f: impl Fn(usize) -> T) -> Self::Data { + fn from_fn(f: impl FnMut(usize) -> T) -> Self::Data { array::from_fn(f) } } pub struct TowerLevel16; -impl TowerLevel for TowerLevel16 -where - T: Default + Copy, -{ +impl TowerLevel for TowerLevel16 { const WIDTH: usize = 16; - type Data = [T; 16]; + type Data = [T; 16]; type Base = TowerLevel8; #[inline(always)] - fn split( - data: &Self::Data, - ) -> (&>::Data, &>::Data) { + fn split( + data: &Self::Data, + ) -> (&::Data, &::Data) { ((data[0..8].try_into().unwrap()), (data[8..16].try_into().unwrap())) } #[inline(always)] - fn split_mut( - data: &mut Self::Data, - ) -> (&mut >::Data, &mut >::Data) { + fn split_mut( + data: &mut Self::Data, + ) -> (&mut ::Data, &mut ::Data) { let (chunk_1, chunk_2) = data.split_at_mut(8); ((chunk_1.try_into().unwrap()), (chunk_2.try_into().unwrap())) } #[inline(always)] - fn join<'a>( - left: &>::Data, - right: &>::Data, - ) -> Self::Data { + fn join( + left: &::Data, + right: &::Data, + ) -> Self::Data { let mut result = [T::default(); 16]; result[..8].copy_from_slice(left); result[8..].copy_from_slice(right); @@ -215,43 +203,40 @@ where } #[inline(always)] - fn from_fn(f: impl Fn(usize) -> T) -> Self::Data { + fn from_fn(f: impl FnMut(usize) -> T) -> Self::Data { array::from_fn(f) } } pub struct TowerLevel8; -impl TowerLevel for TowerLevel8 -where - T: Default + Copy, -{ +impl TowerLevel for TowerLevel8 { const WIDTH: usize = 8; - type Data = [T; 8]; + type Data = [T; 8]; type Base = TowerLevel4; #[inline(always)] - fn split( - data: &Self::Data, - ) -> (&>::Data, &>::Data) { + fn split( + data: &Self::Data, + ) -> (&::Data, &::Data) { ((data[0..4].try_into().unwrap()), (data[4..8].try_into().unwrap())) } #[inline(always)] - fn split_mut( - data: &mut Self::Data, - ) -> (&mut >::Data, &mut >::Data) { + fn split_mut( + data: &mut Self::Data, + ) -> (&mut ::Data, &mut ::Data) { let (chunk_1, chunk_2) = data.split_at_mut(4); ((chunk_1.try_into().unwrap()), (chunk_2.try_into().unwrap())) } #[inline(always)] - fn join<'a>( - left: &>::Data, - right: &>::Data, - ) -> Self::Data { + fn join( + left: &::Data, + right: &::Data, + ) -> Self::Data { let mut result = [T::default(); 8]; result[..4].copy_from_slice(left); result[4..].copy_from_slice(right); @@ -259,43 +244,40 @@ where } #[inline(always)] - fn from_fn(f: impl Fn(usize) -> T) -> Self::Data { + fn from_fn(f: impl FnMut(usize) -> T) -> Self::Data { array::from_fn(f) } } pub struct TowerLevel4; -impl TowerLevel for TowerLevel4 -where - T: Default + Copy, -{ +impl TowerLevel for TowerLevel4 { const WIDTH: usize = 4; - type Data = [T; 4]; + type Data = [T; 4]; type Base = TowerLevel2; #[inline(always)] - fn split( - data: &Self::Data, - ) -> (&>::Data, &>::Data) { + fn split( + data: &Self::Data, + ) -> (&::Data, &::Data) { ((data[0..2].try_into().unwrap()), (data[2..4].try_into().unwrap())) } #[inline(always)] - fn split_mut( - data: &mut Self::Data, - ) -> (&mut >::Data, &mut >::Data) { + fn split_mut( + data: &mut Self::Data, + ) -> (&mut ::Data, &mut ::Data) { let (chunk_1, chunk_2) = data.split_at_mut(2); ((chunk_1.try_into().unwrap()), (chunk_2.try_into().unwrap())) } #[inline(always)] - fn join<'a>( - left: &>::Data, - right: &>::Data, - ) -> Self::Data { + fn join( + left: &::Data, + right: &::Data, + ) -> Self::Data { let mut result = [T::default(); 4]; result[..2].copy_from_slice(left); result[2..].copy_from_slice(right); @@ -303,43 +285,40 @@ where } #[inline(always)] - fn from_fn(f: impl Fn(usize) -> T) -> Self::Data { + fn from_fn(f: impl FnMut(usize) -> T) -> Self::Data { array::from_fn(f) } } pub struct TowerLevel2; -impl TowerLevel for TowerLevel2 -where - T: Default + Copy, -{ +impl TowerLevel for TowerLevel2 { const WIDTH: usize = 2; - type Data = [T; 2]; + type Data = [T; 2]; type Base = TowerLevel1; #[inline(always)] - fn split( - data: &Self::Data, - ) -> (&>::Data, &>::Data) { + fn split( + data: &Self::Data, + ) -> (&::Data, &::Data) { ((data[0..1].try_into().unwrap()), (data[1..2].try_into().unwrap())) } #[inline(always)] - fn split_mut( - data: &mut Self::Data, - ) -> (&mut >::Data, &mut >::Data) { + fn split_mut( + data: &mut Self::Data, + ) -> (&mut ::Data, &mut ::Data) { let (chunk_1, chunk_2) = data.split_at_mut(1); ((chunk_1.try_into().unwrap()), (chunk_2.try_into().unwrap())) } #[inline(always)] - fn join<'a>( - left: &>::Data, - right: &>::Data, - ) -> Self::Data { + fn join( + left: &::Data, + right: &::Data, + ) -> Self::Data { let mut result = [T::default(); 2]; result[..1].copy_from_slice(left); result[1..].copy_from_slice(right); @@ -347,48 +326,45 @@ where } #[inline(always)] - fn from_fn(f: impl Fn(usize) -> T) -> Self::Data { + fn from_fn(f: impl FnMut(usize) -> T) -> Self::Data { array::from_fn(f) } } pub struct TowerLevel1; -impl TowerLevel for TowerLevel1 -where - T: Default + Copy, -{ +impl TowerLevel for TowerLevel1 { const WIDTH: usize = 1; - type Data = [T; 1]; + type Data = [T; 1]; type Base = Self; // Level 1 is the atomic unit of backing data and must not be split. #[inline(always)] - fn split( - _data: &Self::Data, - ) -> (&>::Data, &>::Data) { + fn split( + _data: &Self::Data, + ) -> (&::Data, &::Data) { unreachable!() } #[inline(always)] - fn split_mut( - _data: &mut Self::Data, - ) -> (&mut >::Data, &mut >::Data) { + fn split_mut( + _data: &mut Self::Data, + ) -> (&mut ::Data, &mut ::Data) { unreachable!() } #[inline(always)] - fn join<'a>( - _left: &>::Data, - _right: &>::Data, - ) -> Self::Data { + fn join( + _left: &::Data, + _right: &::Data, + ) -> Self::Data { unreachable!() } #[inline(always)] - fn from_fn(f: impl Fn(usize) -> T) -> Self::Data { + fn from_fn(f: impl FnMut(usize) -> T) -> Self::Data { array::from_fn(f) } } diff --git a/crates/field/src/transpose.rs b/crates/field/src/transpose.rs index efabbfc60..b3838b57a 100644 --- a/crates/field/src/transpose.rs +++ b/crates/field/src/transpose.rs @@ -2,7 +2,7 @@ use binius_utils::checked_arithmetics::log2_strict_usize; -use super::{packed::PackedField, ExtensionField, PackedFieldIndexable, RepackedExtension}; +use super::{packed::PackedField, Field, PackedFieldIndexable, RepackedExtension}; /// Error thrown when a transpose operation fails. #[derive(Clone, thiserror::Error, Debug)] @@ -76,7 +76,7 @@ pub fn square_transpose(log_n: usize, elems: &mut [P]) -> Result pub fn transpose_scalars(src: &[PE], dst: &mut [P]) -> Result<(), Error> where P: PackedField, - FE: ExtensionField, + FE: Field, PE: PackedFieldIndexable + RepackedExtension

, { let len = src.len(); diff --git a/crates/field/src/underlier/scaled.rs b/crates/field/src/underlier/scaled.rs index 4210bc653..40cabc5c2 100644 --- a/crates/field/src/underlier/scaled.rs +++ b/crates/field/src/underlier/scaled.rs @@ -1,13 +1,16 @@ // Copyright 2024-2025 Irreducible Inc. -use std::array; +use std::{ + array, + ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, Shr}, +}; use binius_utils::checked_arithmetics::checked_log_2; use bytemuck::{must_cast_mut, must_cast_ref, NoUninit, Pod, Zeroable}; use rand::RngCore; use subtle::{Choice, ConstantTimeEq}; -use super::{Divisible, Random, UnderlierType}; +use super::{Divisible, Random, UnderlierType, UnderlierWithBitOps}; /// A type that represents a pair of elements of the same underlier type. /// We use it as an underlier for the `ScaledPAckedField` type. @@ -104,3 +107,189 @@ where must_cast_mut::(self) } } + +impl + Copy, const N: usize> BitAnd for ScaledUnderlier { + type Output = Self; + + fn bitand(self, rhs: Self) -> Self::Output { + Self(array::from_fn(|i| self.0[i] & rhs.0[i])) + } +} + +impl BitAndAssign for ScaledUnderlier { + fn bitand_assign(&mut self, rhs: Self) { + for i in 0..N { + self.0[i] &= rhs.0[i]; + } + } +} + +impl + Copy, const N: usize> BitOr for ScaledUnderlier { + type Output = Self; + + fn bitor(self, rhs: Self) -> Self::Output { + Self(array::from_fn(|i| self.0[i] | rhs.0[i])) + } +} + +impl BitOrAssign for ScaledUnderlier { + fn bitor_assign(&mut self, rhs: Self) { + for i in 0..N { + self.0[i] |= rhs.0[i]; + } + } +} + +impl + Copy, const N: usize> BitXor for ScaledUnderlier { + type Output = Self; + + fn bitxor(self, rhs: Self) -> Self::Output { + Self(array::from_fn(|i| self.0[i] ^ rhs.0[i])) + } +} + +impl BitXorAssign for ScaledUnderlier { + fn bitxor_assign(&mut self, rhs: Self) { + for i in 0..N { + self.0[i] ^= rhs.0[i]; + } + } +} + +impl Shr for ScaledUnderlier { + type Output = Self; + + fn shr(self, rhs: usize) -> Self::Output { + let mut result = Self::default(); + + let shift_in_items = rhs / U::BITS; + for i in 0..N.saturating_sub(shift_in_items.saturating_sub(1)) { + if i + shift_in_items < N { + result.0[i] |= self.0[i + shift_in_items] >> (rhs % U::BITS); + } + if i + shift_in_items + 1 < N && rhs % U::BITS != 0 { + result.0[i] |= self.0[i + shift_in_items + 1] << (U::BITS - (rhs % U::BITS)); + } + } + + result + } +} + +impl Shl for ScaledUnderlier { + type Output = Self; + + fn shl(self, rhs: usize) -> Self::Output { + let mut result = Self::default(); + + let shift_in_items = rhs / U::BITS; + for i in shift_in_items.saturating_sub(1)..N { + if i >= shift_in_items { + result.0[i] |= self.0[i - shift_in_items] << (rhs % U::BITS); + } + if i > shift_in_items && rhs % U::BITS != 0 { + result.0[i] |= self.0[i - shift_in_items - 1] >> (U::BITS - (rhs % U::BITS)); + } + } + + result + } +} + +impl, const N: usize> Not for ScaledUnderlier { + type Output = Self; + + fn not(self) -> Self::Output { + Self(self.0.map(U::not)) + } +} + +impl UnderlierWithBitOps for ScaledUnderlier { + const ZERO: Self = Self([U::ZERO; N]); + const ONE: Self = { + let mut arr = [U::ZERO; N]; + arr[0] = U::ONE; + Self(arr) + }; + const ONES: Self = Self([U::ONES; N]); + + #[inline] + fn fill_with_bit(val: u8) -> Self { + Self(array::from_fn(|_| U::fill_with_bit(val))) + } + + #[inline] + fn shl_128b_lanes(self, rhs: usize) -> Self { + // We assume that the underlier type has at least 128 bits as the current implementation + // is valid for this case only. + // On practice, we don't use scaled underliers with underlier types that have less than 128 bits. + assert!(U::BITS >= 128); + + Self(self.0.map(|x| x.shl_128b_lanes(rhs))) + } + + #[inline] + fn shr_128b_lanes(self, rhs: usize) -> Self { + // We assume that the underlier type has at least 128 bits as the current implementation + // is valid for this case only. + // On practice, we don't use scaled underliers with underlier types that have less than 128 bits. + assert!(U::BITS >= 128); + + Self(self.0.map(|x| x.shr_128b_lanes(rhs))) + } + + #[inline] + fn unpack_lo_128b_lanes(self, other: Self, log_block_len: usize) -> Self { + // We assume that the underlier type has at least 128 bits as the current implementation + // is valid for this case only. + // On practice, we don't use scaled underliers with underlier types that have less than 128 bits. + assert!(U::BITS >= 128); + + Self(array::from_fn(|i| self.0[i].unpack_lo_128b_lanes(other.0[i], log_block_len))) + } + + #[inline] + fn unpack_hi_128b_lanes(self, other: Self, log_block_len: usize) -> Self { + // We assume that the underlier type has at least 128 bits as the current implementation + // is valid for this case only. + // On practice, we don't use scaled underliers with underlier types that have less than 128 bits. + assert!(U::BITS >= 128); + + Self(array::from_fn(|i| self.0[i].unpack_hi_128b_lanes(other.0[i], log_block_len))) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_shr() { + let val = ScaledUnderlier::([0, 1, 2, 3]); + assert_eq!( + val >> 1, + ScaledUnderlier::([0b10000000, 0b00000000, 0b10000001, 0b00000001]) + ); + assert_eq!( + val >> 2, + ScaledUnderlier::([0b01000000, 0b10000000, 0b11000000, 0b00000000]) + ); + assert_eq!( + val >> 8, + ScaledUnderlier::([0b00000001, 0b00000010, 0b00000011, 0b00000000]) + ); + assert_eq!( + val >> 9, + ScaledUnderlier::([0b00000000, 0b10000001, 0b00000001, 0b00000000]) + ); + } + + #[test] + fn test_shl() { + let val = ScaledUnderlier::([0, 1, 2, 3]); + assert_eq!(val << 1, ScaledUnderlier::([0, 2, 4, 6])); + assert_eq!(val << 2, ScaledUnderlier::([0, 4, 8, 12])); + assert_eq!(val << 8, ScaledUnderlier::([0, 0, 1, 2])); + assert_eq!(val << 9, ScaledUnderlier::([0, 0, 2, 4])); + } +} diff --git a/crates/field/src/underlier/small_uint.rs b/crates/field/src/underlier/small_uint.rs index 2e99f211c..3413f5cab 100644 --- a/crates/field/src/underlier/small_uint.rs +++ b/crates/field/src/underlier/small_uint.rs @@ -6,7 +6,12 @@ use std::{ ops::{Not, Shl, Shr}, }; -use binius_utils::checked_arithmetics::checked_log_2; +use binius_utils::{ + bytes::{Buf, BufMut}, + checked_arithmetics::checked_log_2, + serialization::DeserializeBytes, + SerializationError, SerializationMode, SerializeBytes, +}; use bytemuck::{NoUninit, Zeroable}; use derive_more::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign}; use rand::{ @@ -154,6 +159,14 @@ impl UnderlierWithBitOps for SmallU { fn fill_with_bit(val: u8) -> Self { Self(u8::fill_with_bit(val)) & Self::ONES } + + fn shl_128b_lanes(self, rhs: usize) -> Self { + self << rhs + } + + fn shr_128b_lanes(self, rhs: usize) -> Self { + self >> rhs + } } impl From> for u8 { @@ -222,3 +235,22 @@ impl From> for SmallU<4> { pub type U1 = SmallU<1>; pub type U2 = SmallU<2>; pub type U4 = SmallU<4>; + +impl SerializeBytes for SmallU { + fn serialize( + &self, + write_buf: impl BufMut, + mode: SerializationMode, + ) -> Result<(), SerializationError> { + self.val().serialize(write_buf, mode) + } +} + +impl DeserializeBytes for SmallU { + fn deserialize(read_buf: impl Buf, mode: SerializationMode) -> Result + where + Self: Sized, + { + Ok(Self::new(DeserializeBytes::deserialize(read_buf, mode)?)) + } +} diff --git a/crates/field/src/underlier/underlier_impls.rs b/crates/field/src/underlier/underlier_impls.rs index bb52b689a..e28e27dca 100644 --- a/crates/field/src/underlier/underlier_impls.rs +++ b/crates/field/src/underlier/underlier_impls.rs @@ -23,6 +23,16 @@ macro_rules! impl_underlier_type { debug_assert!(val == 0 || val == 1); (val as Self).wrapping_neg() } + + #[inline(always)] + fn shl_128b_lanes(self, rhs: usize) -> Self { + self << rhs + } + + #[inline(always)] + fn shr_128b_lanes(self, rhs: usize) -> Self { + self >> rhs + } } }; () => {}; diff --git a/crates/field/src/underlier/underlier_with_bit_ops.rs b/crates/field/src/underlier/underlier_with_bit_ops.rs index e151f5010..8cff6d187 100644 --- a/crates/field/src/underlier/underlier_with_bit_ops.rs +++ b/crates/field/src/underlier/underlier_with_bit_ops.rs @@ -118,6 +118,38 @@ pub trait UnderlierWithBitOps: { spread_fallback(self, log_block_len, block_idx) } + + /// Left shift within 128-bit lanes. + /// This can be more efficient than the full `Shl` implementation. + fn shl_128b_lanes(self, shift: usize) -> Self; + + /// Right shift within 128-bit lanes. + /// This can be more efficient than the full `Shr` implementation. + fn shr_128b_lanes(self, shift: usize) -> Self; + + /// Unpacks `1 << log_block_len`-bit values from low parts of `self` and `other` within 128-bit lanes. + /// + /// Example: + /// self: [a_0, a_1, a_2, a_3, a_4, a_5, a_6, a_7] + /// other: [b_0, b_1, b_2, b_3, b_4, b_5, b_6, b_7] + /// log_block_len: 1 + /// + /// result: [a_0, a_0, b_0, b_1, a_2, a_3, b_2, b_3] + fn unpack_lo_128b_lanes(self, other: Self, log_block_len: usize) -> Self { + unpack_lo_128b_fallback(self, other, log_block_len) + } + + /// Unpacks `1 << log_block_len`-bit values from high parts of `self` and `other` within 128-bit lanes. + /// + /// Example: + /// self: [a_0, a_1, a_2, a_3, a_4, a_5, a_6, a_7] + /// other: [b_0, b_1, b_2, b_3, b_4, b_5, b_6, b_7] + /// log_block_len: 1 + /// + /// result: [a_4, a_5, b_4, b_5, a_6, a_7, b_6, b_7] + fn unpack_hi_128b_lanes(self, other: Self, log_block_len: usize) -> Self { + unpack_hi_128b_fallback(self, other, log_block_len) + } } /// Returns a bit mask for a single `T` element inside underlier type. @@ -171,6 +203,55 @@ where result } +#[inline(always)] +fn single_element_mask_bits_128b_lanes(log_block_len: usize) -> T { + let mut mask = single_element_mask_bits(1 << log_block_len); + for i in 1..T::BITS / 128 { + mask |= mask << (i * 128); + } + + mask +} + +pub(crate) fn unpack_lo_128b_fallback( + lhs: T, + rhs: T, + log_block_len: usize, +) -> T { + assert!(log_block_len <= 6); + + let mask = single_element_mask_bits_128b_lanes::(log_block_len); + + let mut result = T::ZERO; + for i in 0..1 << (6 - log_block_len) { + result |= ((lhs.shr_128b_lanes(i << log_block_len)) & mask) + .shl_128b_lanes(i << (log_block_len + 1)); + result |= ((rhs.shr_128b_lanes(i << log_block_len)) & mask) + .shl_128b_lanes((2 * i + 1) << log_block_len); + } + + result +} + +pub(crate) fn unpack_hi_128b_fallback( + lhs: T, + rhs: T, + log_block_len: usize, +) -> T { + assert!(log_block_len <= 6); + + let mask = single_element_mask_bits_128b_lanes::(log_block_len); + let mut result = T::ZERO; + for i in 0..1 << (6 - log_block_len) { + result |= ((lhs.shr_128b_lanes(64 + (i << log_block_len))) & mask) + .shl_128b_lanes(i << (log_block_len + 1)); + result |= ((rhs.shr_128b_lanes(64 + (i << log_block_len))) & mask) + .shl_128b_lanes((2 * i + 1) << log_block_len); + } + + result +} + pub(crate) fn single_element_mask_bits(bits_count: usize) -> T { if bits_count == T::BITS { !T::ZERO diff --git a/crates/field/src/util.rs b/crates/field/src/util.rs index 3bc6ec295..12f99a256 100644 --- a/crates/field/src/util.rs +++ b/crates/field/src/util.rs @@ -16,7 +16,7 @@ where F: Field, FE: ExtensionField, { - iter::zip(a, b).map(|(a_i, b_i)| a_i * b_i).sum::() + iter::zip(a, b).map(|(a_i, b_i)| a_i * b_i).sum() } /// Calculate inner product for potentially big slices of xs and ys. @@ -38,9 +38,8 @@ where return inner_product_unchecked(PackedField::iter_slice(xs), PackedField::iter_slice(ys)); } - let calc_product_by_ys = |x_offset, ys: &[PY]| { + let calc_product_by_ys = |xs: &[PX], ys: &[PY]| { let mut result = FX::ZERO; - let xs = &xs[x_offset..]; for (j, y) in ys.iter().enumerate() { for (k, y) in y.iter().enumerate() { @@ -56,14 +55,14 @@ where // For different field sizes, the numbers may need to be adjusted. const CHUNK_SIZE: usize = 64; if ys.len() < 16 * CHUNK_SIZE { - calc_product_by_ys(0, ys) + calc_product_by_ys(xs, ys) } else { // According to benchmark results iterating by chunks here is more efficient than using `par_iter` with `min_length` directly. ys.par_chunks(CHUNK_SIZE) .enumerate() .map(|(i, ys)| { let offset = i * checked_int_div(CHUNK_SIZE * PY::WIDTH, PX::WIDTH); - calc_product_by_ys(offset, ys) + calc_product_by_ys(&xs[offset..], ys) }) .sum() } @@ -79,3 +78,91 @@ pub fn eq(x: F, y: F) -> F { pub fn powers(val: F) -> impl Iterator { iter::successors(Some(F::ONE), move |&power| Some(power * val)) } + +#[cfg(test)] +mod tests { + use super::*; + use crate::PackedBinaryField4x32b; + + type P = PackedBinaryField4x32b; + type F =

::Scalar; + + #[test] + fn test_inner_product_par_equal_length() { + // xs and ys have the same number of packed elements + let xs1 = F::new(1); + let xs2 = F::new(2); + let xs = vec![P::set_single(xs1), P::set_single(xs2)]; + let ys1 = F::new(3); + let ys2 = F::new(4); + let ys = vec![P::set_single(ys1), P::set_single(ys2)]; + + let result = inner_product_par::(&xs, &ys); + let expected = xs1 * ys1 + xs2 * ys2; + + assert_eq!(result, expected); + } + + #[test] + fn test_inner_product_par_unequal_length() { + // ys is larger than xs due to packing differences + let xs1 = F::new(1); + let xs = vec![P::set_single(xs1)]; + let ys1 = F::new(2); + let ys2 = F::new(3); + let ys = vec![P::set_single(ys1), P::set_single(ys2)]; + + let result = inner_product_par::(&xs, &ys); + let expected = xs1 * ys1; + + assert_eq!(result, expected); + } + + #[test] + fn test_inner_product_par_large_input_single_threaded() { + // Large input but not enough to trigger parallel execution + let size = 256; + let xs: Vec

= (0..size).map(|i| P::set_single(F::new(i as u32))).collect(); + let ys: Vec

= (0..size) + .map(|i| P::set_single(F::new((i + 1) as u32))) + .collect(); + + let result = inner_product_par::(&xs, &ys); + + let expected = (0..size) + .map(|i| F::new(i as u32) * F::new((i + 1) as u32)) + .sum::(); + + assert_eq!(result, expected); + } + + #[test] + fn test_inner_product_par_large_input_par() { + // Large input to test parallel execution + let size = 2000; + let xs: Vec

= (0..size).map(|i| P::set_single(F::new(i as u32))).collect(); + let ys: Vec

= (0..size) + .map(|i| P::set_single(F::new((i + 1) as u32))) + .collect(); + + let result = inner_product_par::(&xs, &ys); + + let expected = (0..size) + .map(|i| F::new(i as u32) * F::new((i + 1) as u32)) + .sum::(); + + assert_eq!(result, expected); + } + + #[test] + fn test_inner_product_par_empty() { + // Case: Empty input should return 0 + let xs: Vec

= vec![]; + let ys: Vec

= vec![]; + + let result = inner_product_par::(&xs, &ys); + let expected = F::ZERO; + + assert_eq!(result, expected); + } +} diff --git a/crates/hal/src/backend.rs b/crates/hal/src/backend.rs index 05bd2e87a..8d9dac1d2 100644 --- a/crates/hal/src/backend.rs +++ b/crates/hal/src/backend.rs @@ -5,9 +5,9 @@ use std::{ ops::{Deref, DerefMut}, }; -use binius_field::{ExtensionField, Field, PackedExtension, PackedField}; +use binius_field::{Field, PackedExtension, PackedField}; use binius_math::{ - CompositionPolyOS, MultilinearExtension, MultilinearPoly, MultilinearQuery, MultilinearQueryRef, + CompositionPoly, MultilinearExtension, MultilinearPoly, MultilinearQuery, MultilinearQueryRef, }; use binius_maybe_rayon::iter::FromParallelIterator; use tracing::instrument; @@ -42,28 +42,8 @@ pub trait ComputationBackend: Send + Sync + Debug { query: &[P::Scalar], ) -> Result, Error>; - /// Calculate the accumulated evaluations for the first round of zerocheck. - fn sumcheck_compute_first_round_evals( - &self, - n_vars: usize, - multilinears: &[SumcheckMultilinear], - evaluators: &[Evaluator], - evaluation_points: &[FDomain], - ) -> Result>, Error> - where - FDomain: Field, - FBase: ExtensionField, - F: Field + ExtensionField + ExtensionField, - P: PackedField - + PackedExtension - + PackedExtension - + PackedExtension, - M: MultilinearPoly

+ Send + Sync, - Evaluator: SumcheckEvaluator + Sync, - Composition: CompositionPolyOS

; - /// Calculate the accumulated evaluations for an arbitrary round of zerocheck. - fn sumcheck_compute_later_round_evals( + fn sumcheck_compute_round_evals( &self, n_vars: usize, tensor_query: Option>, @@ -73,13 +53,10 @@ pub trait ComputationBackend: Send + Sync + Debug { ) -> Result>, Error> where FDomain: Field, - F: Field + ExtensionField, - P: PackedField - + PackedExtension - + PackedExtension, + P: PackedExtension, M: MultilinearPoly

+ Send + Sync, - Evaluator: SumcheckEvaluator + Sync, - Composition: CompositionPolyOS

; + Evaluator: SumcheckEvaluator + Sync, + Composition: CompositionPoly

; /// Partially evaluate the polynomial with assignment to the high-indexed variables. fn evaluate_partial_high( @@ -108,35 +85,7 @@ where T::tensor_product_full_query(self, query) } - fn sumcheck_compute_first_round_evals( - &self, - n_vars: usize, - multilinears: &[SumcheckMultilinear], - evaluators: &[Evaluator], - evaluation_points: &[FDomain], - ) -> Result>, Error> - where - FDomain: Field, - FBase: ExtensionField, - F: Field + ExtensionField + ExtensionField, - P: PackedField - + PackedExtension - + PackedExtension - + PackedExtension, - M: MultilinearPoly

+ Send + Sync, - Evaluator: SumcheckEvaluator + Sync, - Composition: CompositionPolyOS

, - { - T::sumcheck_compute_first_round_evals::<_, FBase, _, _, _, _, _>( - self, - n_vars, - multilinears, - evaluators, - evaluation_points, - ) - } - - fn sumcheck_compute_later_round_evals( + fn sumcheck_compute_round_evals( &self, n_vars: usize, tensor_query: Option>, @@ -146,15 +95,12 @@ where ) -> Result>, Error> where FDomain: Field, - F: Field + ExtensionField, - P: PackedField - + PackedExtension - + PackedExtension, + P: PackedExtension, M: MultilinearPoly

+ Send + Sync, - Evaluator: SumcheckEvaluator + Sync, - Composition: CompositionPolyOS

, + Evaluator: SumcheckEvaluator + Sync, + Composition: CompositionPoly

, { - T::sumcheck_compute_later_round_evals( + T::sumcheck_compute_round_evals( self, n_vars, tensor_query, diff --git a/crates/hal/src/cpu.rs b/crates/hal/src/cpu.rs index 3d9eac8fe..acd17b225 100644 --- a/crates/hal/src/cpu.rs +++ b/crates/hal/src/cpu.rs @@ -2,16 +2,16 @@ use std::fmt::Debug; -use binius_field::{ExtensionField, Field, PackedExtension, PackedField}; +use binius_field::{Field, PackedExtension, PackedField}; use binius_math::{ - eq_ind_partial_eval, CompositionPolyOS, MultilinearExtension, MultilinearPoly, + eq_ind_partial_eval, CompositionPoly, MultilinearExtension, MultilinearPoly, MultilinearQueryRef, }; use tracing::instrument; use crate::{ - sumcheck_round_calculator::{calculate_first_round_evals, calculate_later_round_evals}, - ComputationBackend, Error, RoundEvals, SumcheckEvaluator, SumcheckMultilinear, + sumcheck_round_calculator::calculate_round_evals, ComputationBackend, Error, RoundEvals, + SumcheckEvaluator, SumcheckMultilinear, }; /// Implementation of ComputationBackend for the default Backend that uses the CPU for all computations. @@ -37,31 +37,7 @@ impl ComputationBackend for CpuBackend { Ok(eq_ind_partial_eval(query)) } - fn sumcheck_compute_first_round_evals( - &self, - n_vars: usize, - multilinears: &[SumcheckMultilinear], - evaluators: &[Evaluator], - evaluation_points: &[FDomain], - ) -> Result>, Error> - where - FDomain: Field, - FBase: ExtensionField, - F: Field + ExtensionField + ExtensionField, - P: PackedField + PackedExtension + PackedExtension, - M: MultilinearPoly

+ Send + Sync, - Evaluator: SumcheckEvaluator + Sync, - Composition: CompositionPolyOS

, - { - calculate_first_round_evals::<_, FBase, _, _, _, _, _>( - n_vars, - multilinears, - evaluators, - evaluation_points, - ) - } - - fn sumcheck_compute_later_round_evals( + fn sumcheck_compute_round_evals( &self, n_vars: usize, tensor_query: Option>, @@ -71,21 +47,12 @@ impl ComputationBackend for CpuBackend { ) -> Result>, Error> where FDomain: Field, - F: Field + ExtensionField, - P: PackedField - + PackedExtension - + PackedExtension, + P: PackedExtension, M: MultilinearPoly

+ Send + Sync, - Evaluator: SumcheckEvaluator + Sync, - Composition: CompositionPolyOS

, + Evaluator: SumcheckEvaluator + Sync, + Composition: CompositionPoly

, { - calculate_later_round_evals( - n_vars, - tensor_query, - multilinears, - evaluators, - evaluation_points, - ) + calculate_round_evals(n_vars, tensor_query, multilinears, evaluators, evaluation_points) } #[instrument(skip_all, name = "CpuBackend::evaluate_partial_high")] diff --git a/crates/hal/src/sumcheck_evaluator.rs b/crates/hal/src/sumcheck_evaluator.rs index 982b9c6c7..31e0721a2 100644 --- a/crates/hal/src/sumcheck_evaluator.rs +++ b/crates/hal/src/sumcheck_evaluator.rs @@ -2,17 +2,13 @@ use std::ops::Range; -use binius_field::{ExtensionField, Field, PackedExtension, PackedField, PackedSubfield}; +use binius_field::{Field, PackedField}; /// Evaluations of a polynomial at a set of evaluation points. #[derive(Debug, Clone)] pub struct RoundEvals(pub Vec); -pub trait SumcheckEvaluator -where - FBase: Field, - P: PackedField> + PackedExtension, -{ +pub trait SumcheckEvaluator { /// The range of eval point indices over which composition evaluation and summation should happen. /// Returned range must equal the result of `n_round_evals()` in length. fn eval_point_indices(&self) -> Range; @@ -27,7 +23,7 @@ where &self, subcube_vars: usize, subcube_index: usize, - batch_query: &[&[PackedSubfield]], + batch_query: &[&[P]], ) -> P; /// Returns the composition evaluated by this object. diff --git a/crates/hal/src/sumcheck_round_calculator.rs b/crates/hal/src/sumcheck_round_calculator.rs index f71d1f846..d14b1594b 100644 --- a/crates/hal/src/sumcheck_round_calculator.rs +++ b/crates/hal/src/sumcheck_round_calculator.rs @@ -4,19 +4,16 @@ //! //! This is one of the core computational tasks in the sumcheck proving algorithm. -use std::{iter, marker::PhantomData}; +use std::iter; -use binius_field::{ - recast_packed, ExtensionField, Field, PackedExtension, PackedField, PackedSubfield, - RepackedExtension, -}; +use binius_field::{Field, PackedExtension, PackedField, PackedSubfield}; use binius_math::{ - deinterleave, extrapolate_lines, CompositionPolyOS, MultilinearPoly, MultilinearQuery, + deinterleave, extrapolate_lines, CompositionPoly, MultilinearPoly, MultilinearQuery, MultilinearQueryRef, }; use binius_maybe_rayon::prelude::*; use bytemuck::zeroed_vec; -use itertools::izip; +use itertools::{izip, Itertools}; use stackalloc::stackalloc_with_iter; use crate::{Error, RoundEvals, SumcheckEvaluator, SumcheckMultilinear}; @@ -44,39 +41,11 @@ trait SumcheckMultilinearAccess { ) -> Result<(), Error>; } -/// Calculate the accumulated evaluations for the first sumcheck round. -pub(crate) fn calculate_first_round_evals( - n_vars: usize, - multilinears: &[SumcheckMultilinear], - evaluators: &[Evaluator], - evaluation_points: &[FDomain], -) -> Result>, Error> -where - FDomain: Field, - FBase: ExtensionField, - F: Field + ExtensionField + ExtensionField, - P: PackedField + PackedExtension + PackedExtension, - M: MultilinearPoly

+ Send + Sync, - Evaluator: SumcheckEvaluator + Sync, - Composition: CompositionPolyOS

, -{ - let accesses = multilinears - .iter() - .map(FirstRoundAccess::new) - .collect::>(); - calculate_round_evals::<_, FBase, _, _, _, _, _>( - n_vars, - &accesses, - evaluators, - evaluation_points, - ) -} - /// Calculate the accumulated evaluations for an arbitrary sumcheck round. /// /// See [`calculate_first_round_evals`] for an optimized version of this method /// that works over small fields in the first round. -pub(crate) fn calculate_later_round_evals( +pub(crate) fn calculate_round_evals( n_vars: usize, tensor_query: Option>, multilinears: &[SumcheckMultilinear], @@ -85,26 +54,27 @@ pub(crate) fn calculate_later_round_evals Result>, Error> where FDomain: Field, - F: Field + ExtensionField, - P: PackedField + PackedExtension + PackedExtension, + F: Field, + P: PackedField + PackedExtension, M: MultilinearPoly

+ Send + Sync, - Evaluator: SumcheckEvaluator + Sync, - Composition: CompositionPolyOS

, + Evaluator: SumcheckEvaluator + Sync, + Composition: CompositionPoly

, { let empty_query = MultilinearQuery::with_capacity(0); - let query = tensor_query.unwrap_or_else(|| empty_query.to_ref()); + let tensor_query = tensor_query.unwrap_or_else(|| empty_query.to_ref()); - let accesses = multilinears + let later_rounds_accesses = multilinears .iter() - .map(|multilinear| LaterRoundAccess { + .map(|multilinear| LargeFieldAccess { multilinear, - tensor_query: query, + tensor_query, }) - .collect::>(); - calculate_round_evals::<_, F, _, _, _, _, _>(n_vars, &accesses, evaluators, evaluation_points) + .collect_vec(); + + calculate_round_evals_with_access(n_vars, &later_rounds_accesses, evaluators, evaluation_points) } -fn calculate_round_evals( +fn calculate_round_evals_with_access( n_vars: usize, multilinears: &[Access], evaluators: &[Evaluator], @@ -112,12 +82,11 @@ fn calculate_round_evals( ) -> Result>, Error> where FDomain: Field, - FBase: ExtensionField, - F: Field + ExtensionField + ExtensionField, - P: PackedField + PackedExtension + PackedExtension, - Evaluator: SumcheckEvaluator + Sync, - Access: SumcheckMultilinearAccess> + Sync, - Composition: CompositionPolyOS

, + F: Field, + P: PackedField + PackedExtension, + Evaluator: SumcheckEvaluator + Sync, + Access: SumcheckMultilinearAccess

+ Sync, + Composition: CompositionPoly

, { let n_multilinears = multilinears.len(); let n_round_evals = evaluators @@ -182,9 +151,9 @@ where // `binius_math::univariate::extrapolate_line`, except that we do // not repeat the broadcast of the subfield element to a packed // subfield. - *eval_z = recast_packed::(extrapolate_lines( - recast_packed::(eval_0), - recast_packed::(eval_1), + *eval_z = P::cast_ext(extrapolate_lines( + P::cast_base(eval_0), + P::cast_base(eval_1), eval_point_broadcast, )); } @@ -316,57 +285,7 @@ impl ParFoldStates { } #[derive(Debug)] -struct FirstRoundAccess<'a, PBase, P, M> -where - P: PackedField, - M: MultilinearPoly

+ Send + Sync, -{ - multilinear: &'a SumcheckMultilinear, - _marker: PhantomData, -} - -impl<'a, PBase, P, M> FirstRoundAccess<'a, PBase, P, M> -where - P: PackedField, - M: MultilinearPoly

+ Send + Sync, -{ - const fn new(multilinear: &'a SumcheckMultilinear) -> Self { - Self { - multilinear, - _marker: PhantomData, - } - } -} - -impl SumcheckMultilinearAccess for FirstRoundAccess<'_, PBase, P, M> -where - PBase: PackedField, - P: RepackedExtension, - P::Scalar: ExtensionField, - M: MultilinearPoly

+ Send + Sync, -{ - fn subcube_evaluations( - &self, - subcube_vars: usize, - subcube_index: usize, - evals: &mut [PBase], - ) -> Result<(), Error> { - if let SumcheckMultilinear::Transparent { multilinear, .. } = self.multilinear { - let evals =

>::cast_exts_mut(evals); - Ok(multilinear.subcube_evals( - subcube_vars, - subcube_index, - >::LOG_DEGREE, - evals, - )?) - } else { - panic!("precondition: no folded multilinears in the first round"); - } - } -} - -#[derive(Debug)] -struct LaterRoundAccess<'a, P, M> +struct LargeFieldAccess<'a, P, M> where P: PackedField, M: MultilinearPoly

+ Send + Sync, @@ -375,7 +294,7 @@ where tensor_query: MultilinearQueryRef<'a, P>, } -impl SumcheckMultilinearAccess

for LaterRoundAccess<'_, P, M> +impl SumcheckMultilinearAccess

for LargeFieldAccess<'_, P, M> where P: PackedField, M: MultilinearPoly

+ Send + Sync, @@ -388,28 +307,28 @@ where ) -> Result<(), Error> { match self.multilinear { SumcheckMultilinear::Transparent { multilinear, .. } => { - // TODO: Stop using LaterRoundAccess for first round in RegularSumcheckProver and - // GPASumcheckProver, then remove this conditional. if self.tensor_query.n_vars() == 0 { - Ok(multilinear.subcube_evals(subcube_vars, subcube_index, 0, evals)?) + multilinear.subcube_evals(subcube_vars, subcube_index, 0, evals)? } else { - Ok(multilinear.subcube_inner_products( + multilinear.subcube_inner_products( self.tensor_query, subcube_vars, subcube_index, evals, - )?) + )? } } SumcheckMultilinear::Folded { large_field_folded_multilinear, - } => Ok(large_field_folded_multilinear.subcube_evals( + } => large_field_folded_multilinear.subcube_evals( subcube_vars, subcube_index, 0, evals, - )?), + )?, } + + Ok(()) } } diff --git a/crates/hash/src/groestl/hasher.rs b/crates/hash/src/groestl/hasher.rs index 74f9e49c9..d6ee087ce 100644 --- a/crates/hash/src/groestl/hasher.rs +++ b/crates/hash/src/groestl/hasher.rs @@ -153,7 +153,6 @@ impl Hasher

for Groestl256 where F: BinaryField + From + Into, P: PackedExtension, - P::Scalar: ExtensionField, OptimalUnderlier256b: PackScalar + Divisible, Self: UpdateOverSlice, { @@ -300,7 +299,7 @@ mod tests { } #[test] - fn test_aes_binary_convertion() { + fn test_aes_binary_conversion() { let mut rng = thread_rng(); let input_aes: [PackedAESBinaryField32x8b; 90] = array::from_fn(|_| PackedAESBinaryField32x8b::random(&mut rng)); diff --git a/crates/hash/src/serialization.rs b/crates/hash/src/serialization.rs index ff64c9756..f1292c646 100644 --- a/crates/hash/src/serialization.rs +++ b/crates/hash/src/serialization.rs @@ -2,7 +2,7 @@ use std::{borrow::Borrow, cmp::min}; -use binius_utils::serialization::SerializeBytes; +use binius_utils::{SerializationMode, SerializeBytes}; use bytes::{buf::UninitSlice, BufMut}; use digest::{ core_api::{Block, BlockSizeUser}, @@ -75,7 +75,7 @@ where let mut buffer = HashBuffer::new(&mut hasher); for item in items { item.borrow() - .serialize(&mut buffer) + .serialize(&mut buffer, SerializationMode::CanonicalTower) .expect("HashBuffer has infinite capacity"); } } diff --git a/crates/hash/src/vision.rs b/crates/hash/src/vision.rs index 45940192a..750db41b8 100644 --- a/crates/hash/src/vision.rs +++ b/crates/hash/src/vision.rs @@ -221,7 +221,6 @@ where U: PackScalar + Divisible, F: BinaryField + From + Into, P: PackedExtension, - P::Scalar: ExtensionField, PackedAESBinaryField8x32b: WithUnderlier, { type Digest = PackedType; diff --git a/crates/macros/Cargo.toml b/crates/macros/Cargo.toml index c0293a3cc..c3b8defb1 100644 --- a/crates/macros/Cargo.toml +++ b/crates/macros/Cargo.toml @@ -16,6 +16,7 @@ proc-macro2.workspace = true binius_core = { path = "../core" } binius_field = { path = "../field" } binius_math = { path = "../math" } +binius_utils = { path = "../utils" } paste.workspace = true rand.workspace = true diff --git a/crates/macros/src/arith_circuit_poly.rs b/crates/macros/src/arith_circuit_poly.rs index 0f8f1f6ad..93ba74a79 100644 --- a/crates/macros/src/arith_circuit_poly.rs +++ b/crates/macros/src/arith_circuit_poly.rs @@ -3,55 +3,22 @@ use quote::{quote, ToTokens}; use syn::{bracketed, parse::Parse, parse_quote, spanned::Spanned, Token}; -use crate::composition_poly::CompositionPolyItem; - #[derive(Debug)] pub(crate) struct ArithCircuitPolyItem { poly: syn::Expr, - /// We create a composition poly to cache the efficient evaluation implementations - /// for the known packed field types. - composition_poly: CompositionPolyItem, field_name: syn::Ident, } impl ToTokens for ArithCircuitPolyItem { fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { - let Self { - poly, - composition_poly, - field_name, - } = self; - - let mut register_cached_impls = proc_macro2::TokenStream::new(); - let packed_extensions = get_packed_extensions(field_name); - if packed_extensions.is_empty() { - register_cached_impls.extend(quote! { result }); - } else { - register_cached_impls.extend(quote! ( - let mut cached = binius_core::polynomial::CachedPoly::new(composition); - - )); - - for packed_extension in get_packed_extensions(field_name) { - register_cached_impls.extend(quote! { - cached.register::(composition.clone()); - }); - } - - register_cached_impls.extend(quote! { - cached - }); - } + let Self { poly, field_name } = self; tokens.extend(quote! { { use binius_field::Field; use binius_math::ArithExpr as Expr; - let mut result = binius_core::polynomial::ArithCircuitPoly::::new(#poly); - let composition = #composition_poly; - - #register_cached_impls + binius_core::polynomial::ArithCircuitPoly::::new(#poly) } }); } @@ -59,7 +26,6 @@ impl ToTokens for ArithCircuitPolyItem { impl Parse for ArithCircuitPolyItem { fn parse(input: syn::parse::ParseStream) -> syn::Result { - let original_tokens = input.fork(); let vars: Vec = { let content; bracketed!(content in input); @@ -73,14 +39,8 @@ impl Parse for ArithCircuitPolyItem { input.parse::()?; let field_name = input.parse()?; - // Here we assume that the `composition_poly` shares the expression syntax with the `arithmetic_circuit_poly`. - let composition_poly = CompositionPolyItem::parse(&original_tokens)?; - Ok(Self { - poly, - composition_poly, - field_name, - }) + Ok(Self { poly, field_name }) } } @@ -120,350 +80,3 @@ fn flatten_expr(expr: &syn::Expr, vars: &[syn::Ident]) -> Result Vec { - match ident.to_string().as_str() { - "BinaryField1b" => vec![ - parse_quote!(PackedBinaryField1x1b), - parse_quote!(PackedBinaryField2x1b), - parse_quote!(PackedBinaryField4x1b), - parse_quote!(PackedBinaryField8x1b), - parse_quote!(PackedBinaryField16x1b), - parse_quote!(PackedBinaryField32x1b), - parse_quote!(PackedBinaryField64x1b), - parse_quote!(PackedBinaryField128x1b), - parse_quote!(PackedBinaryField256x1b), - parse_quote!(PackedBinaryField512x1b), - parse_quote!(PackedBinaryField1x2b), - parse_quote!(PackedBinaryField2x2b), - parse_quote!(PackedBinaryField4x2b), - parse_quote!(PackedBinaryField8x2b), - parse_quote!(PackedBinaryField16x2b), - parse_quote!(PackedBinaryField32x2b), - parse_quote!(PackedBinaryField64x2b), - parse_quote!(PackedBinaryField128x2b), - parse_quote!(PackedBinaryField256x2b), - parse_quote!(PackedBinaryField1x4b), - parse_quote!(PackedBinaryField2x4b), - parse_quote!(PackedBinaryField4x4b), - parse_quote!(PackedBinaryField8x4b), - parse_quote!(PackedBinaryField16x4b), - parse_quote!(PackedBinaryField32x4b), - parse_quote!(PackedBinaryField64x4b), - parse_quote!(PackedBinaryField128x4b), - parse_quote!(PackedBinaryField1x8b), - parse_quote!(PackedBinaryField2x8b), - parse_quote!(PackedBinaryField4x8b), - parse_quote!(PackedBinaryField8x8b), - parse_quote!(PackedBinaryField16x8b), - parse_quote!(PackedBinaryField32x8b), - parse_quote!(PackedBinaryField64x8b), - parse_quote!(PackedBinaryField1x16b), - parse_quote!(PackedBinaryField2x16b), - parse_quote!(PackedBinaryField4x16b), - parse_quote!(PackedBinaryField8x16b), - parse_quote!(PackedBinaryField16x16b), - parse_quote!(PackedBinaryField32x16b), - parse_quote!(PackedBinaryField1x32b), - parse_quote!(PackedBinaryField2x32b), - parse_quote!(PackedBinaryField4x32b), - parse_quote!(PackedBinaryField8x32b), - parse_quote!(PackedBinaryField16x32b), - parse_quote!(PackedBinaryField1x64b), - parse_quote!(PackedBinaryField2x64b), - parse_quote!(PackedBinaryField4x64b), - parse_quote!(PackedBinaryField8x64b), - parse_quote!(PackedBinaryField1x128b), - parse_quote!(PackedBinaryField2x128b), - parse_quote!(PackedBinaryField4x128b), - parse_quote!(PackedAESBinaryField1x8b), - parse_quote!(PackedAESBinaryField2x8b), - parse_quote!(PackedAESBinaryField4x8b), - parse_quote!(PackedAESBinaryField8x8b), - parse_quote!(PackedAESBinaryField16x8b), - parse_quote!(PackedAESBinaryField32x8b), - parse_quote!(PackedAESBinaryField64x8b), - parse_quote!(PackedAESBinaryField1x16b), - parse_quote!(PackedAESBinaryField2x16b), - parse_quote!(PackedAESBinaryField4x16b), - parse_quote!(PackedAESBinaryField8x16b), - parse_quote!(PackedAESBinaryField16x16b), - parse_quote!(PackedAESBinaryField32x16b), - parse_quote!(PackedAESBinaryField1x32b), - parse_quote!(PackedAESBinaryField2x32b), - parse_quote!(PackedAESBinaryField4x32b), - parse_quote!(PackedAESBinaryField8x32b), - parse_quote!(PackedAESBinaryField16x32b), - parse_quote!(PackedAESBinaryField1x64b), - parse_quote!(PackedAESBinaryField2x64b), - parse_quote!(PackedAESBinaryField4x64b), - parse_quote!(PackedAESBinaryField8x64b), - parse_quote!(PackedAESBinaryField1x128b), - parse_quote!(PackedAESBinaryField2x128b), - parse_quote!(PackedAESBinaryField4x128b), - parse_quote!(PackedBinaryPolyval1x128b), - parse_quote!(PackedBinaryPolyval2x128b), - parse_quote!(PackedBinaryPolyval4x128b), - ], - "BinaryField2b" => { - vec![ - parse_quote!(PackedBinaryField1x2b), - parse_quote!(PackedBinaryField2x2b), - parse_quote!(PackedBinaryField4x2b), - parse_quote!(PackedBinaryField8x2b), - parse_quote!(PackedBinaryField16x2b), - parse_quote!(PackedBinaryField32x2b), - parse_quote!(PackedBinaryField64x2b), - parse_quote!(PackedBinaryField128x2b), - parse_quote!(PackedBinaryField256x2b), - parse_quote!(PackedBinaryField1x4b), - parse_quote!(PackedBinaryField2x4b), - parse_quote!(PackedBinaryField4x4b), - parse_quote!(PackedBinaryField8x4b), - parse_quote!(PackedBinaryField16x4b), - parse_quote!(PackedBinaryField32x4b), - parse_quote!(PackedBinaryField64x4b), - parse_quote!(PackedBinaryField128x4b), - parse_quote!(PackedBinaryField1x8b), - parse_quote!(PackedBinaryField2x8b), - parse_quote!(PackedBinaryField4x8b), - parse_quote!(PackedBinaryField8x8b), - parse_quote!(PackedBinaryField16x8b), - parse_quote!(PackedBinaryField32x8b), - parse_quote!(PackedBinaryField64x8b), - parse_quote!(PackedBinaryField1x16b), - parse_quote!(PackedBinaryField2x16b), - parse_quote!(PackedBinaryField4x16b), - parse_quote!(PackedBinaryField8x16b), - parse_quote!(PackedBinaryField16x16b), - parse_quote!(PackedBinaryField32x16b), - parse_quote!(PackedBinaryField1x32b), - parse_quote!(PackedBinaryField2x32b), - parse_quote!(PackedBinaryField4x32b), - parse_quote!(PackedBinaryField8x32b), - parse_quote!(PackedBinaryField16x32b), - parse_quote!(PackedBinaryField1x64b), - parse_quote!(PackedBinaryField2x64b), - parse_quote!(PackedBinaryField4x64b), - parse_quote!(PackedBinaryField8x64b), - parse_quote!(PackedBinaryField1x128b), - parse_quote!(PackedBinaryField2x128b), - parse_quote!(PackedBinaryField4x128b), - ] - } - "BinaryField4b" => { - vec![ - parse_quote!(PackedBinaryField1x4b), - parse_quote!(PackedBinaryField2x4b), - parse_quote!(PackedBinaryField4x4b), - parse_quote!(PackedBinaryField8x4b), - parse_quote!(PackedBinaryField16x4b), - parse_quote!(PackedBinaryField32x4b), - parse_quote!(PackedBinaryField64x4b), - parse_quote!(PackedBinaryField128x4b), - parse_quote!(PackedBinaryField1x8b), - parse_quote!(PackedBinaryField2x8b), - parse_quote!(PackedBinaryField4x8b), - parse_quote!(PackedBinaryField8x8b), - parse_quote!(PackedBinaryField16x8b), - parse_quote!(PackedBinaryField32x8b), - parse_quote!(PackedBinaryField64x8b), - parse_quote!(PackedBinaryField1x16b), - parse_quote!(PackedBinaryField2x16b), - parse_quote!(PackedBinaryField4x16b), - parse_quote!(PackedBinaryField8x16b), - parse_quote!(PackedBinaryField16x16b), - parse_quote!(PackedBinaryField32x16b), - parse_quote!(PackedBinaryField1x32b), - parse_quote!(PackedBinaryField2x32b), - parse_quote!(PackedBinaryField4x32b), - parse_quote!(PackedBinaryField8x32b), - parse_quote!(PackedBinaryField16x32b), - parse_quote!(PackedBinaryField1x64b), - parse_quote!(PackedBinaryField2x64b), - parse_quote!(PackedBinaryField4x64b), - parse_quote!(PackedBinaryField8x64b), - parse_quote!(PackedBinaryField1x128b), - parse_quote!(PackedBinaryField2x128b), - parse_quote!(PackedBinaryField4x128b), - ] - } - "BinaryField8b" => { - vec![ - parse_quote!(PackedBinaryField1x8b), - parse_quote!(PackedBinaryField2x8b), - parse_quote!(PackedBinaryField4x8b), - parse_quote!(PackedBinaryField8x8b), - parse_quote!(PackedBinaryField16x8b), - parse_quote!(PackedBinaryField32x8b), - parse_quote!(PackedBinaryField64x8b), - parse_quote!(PackedBinaryField1x16b), - parse_quote!(PackedBinaryField2x16b), - parse_quote!(PackedBinaryField4x16b), - parse_quote!(PackedBinaryField8x16b), - parse_quote!(PackedBinaryField16x16b), - parse_quote!(PackedBinaryField32x16b), - parse_quote!(PackedBinaryField1x32b), - parse_quote!(PackedBinaryField2x32b), - parse_quote!(PackedBinaryField4x32b), - parse_quote!(PackedBinaryField8x32b), - parse_quote!(PackedBinaryField16x32b), - parse_quote!(PackedBinaryField1x64b), - parse_quote!(PackedBinaryField2x64b), - parse_quote!(PackedBinaryField4x64b), - parse_quote!(PackedBinaryField8x64b), - parse_quote!(PackedBinaryField1x128b), - parse_quote!(PackedBinaryField2x128b), - parse_quote!(PackedBinaryField4x128b), - ] - } - "BinaryField16b" => { - vec![ - parse_quote!(PackedBinaryField1x16b), - parse_quote!(PackedBinaryField2x16b), - parse_quote!(PackedBinaryField4x16b), - parse_quote!(PackedBinaryField8x16b), - parse_quote!(PackedBinaryField16x16b), - parse_quote!(PackedBinaryField32x16b), - parse_quote!(PackedBinaryField1x32b), - parse_quote!(PackedBinaryField2x32b), - parse_quote!(PackedBinaryField4x32b), - parse_quote!(PackedBinaryField8x32b), - parse_quote!(PackedBinaryField16x32b), - parse_quote!(PackedBinaryField1x64b), - parse_quote!(PackedBinaryField2x64b), - parse_quote!(PackedBinaryField4x64b), - parse_quote!(PackedBinaryField8x64b), - parse_quote!(PackedBinaryField1x128b), - parse_quote!(PackedBinaryField2x128b), - parse_quote!(PackedBinaryField4x128b), - ] - } - "BinaryField32b" => { - vec![ - parse_quote!(PackedBinaryField1x32b), - parse_quote!(PackedBinaryField2x32b), - parse_quote!(PackedBinaryField4x32b), - parse_quote!(PackedBinaryField8x32b), - parse_quote!(PackedBinaryField16x32b), - parse_quote!(PackedBinaryField1x64b), - parse_quote!(PackedBinaryField2x64b), - parse_quote!(PackedBinaryField4x64b), - parse_quote!(PackedBinaryField8x64b), - parse_quote!(PackedBinaryField1x128b), - parse_quote!(PackedBinaryField2x128b), - parse_quote!(PackedBinaryField4x128b), - ] - } - "BinaryField64b" => { - vec![ - parse_quote!(PackedBinaryField1x64b), - parse_quote!(PackedBinaryField2x64b), - parse_quote!(PackedBinaryField4x64b), - parse_quote!(PackedBinaryField8x64b), - parse_quote!(PackedBinaryField1x128b), - parse_quote!(PackedBinaryField2x128b), - parse_quote!(PackedBinaryField4x128b), - ] - } - - "BinaryField128b" => { - vec![ - parse_quote!(PackedBinaryField1x128b), - parse_quote!(PackedBinaryField2x128b), - parse_quote!(PackedBinaryField4x128b), - ] - } - - "AESTowerField8b" => { - vec![ - parse_quote!(PackedAESBinaryField1x8b), - parse_quote!(PackedAESBinaryField2x8b), - parse_quote!(PackedAESBinaryField4x8b), - parse_quote!(PackedAESBinaryField8x8b), - parse_quote!(PackedAESBinaryField16x8b), - parse_quote!(PackedAESBinaryField32x8b), - parse_quote!(PackedAESBinaryField64x8b), - parse_quote!(PackedAESBinaryField1x16b), - parse_quote!(PackedAESBinaryField2x16b), - parse_quote!(PackedAESBinaryField4x16b), - parse_quote!(PackedAESBinaryField8x16b), - parse_quote!(PackedAESBinaryField16x16b), - parse_quote!(PackedAESBinaryField32x16b), - parse_quote!(PackedAESBinaryField1x32b), - parse_quote!(PackedAESBinaryField2x32b), - parse_quote!(PackedAESBinaryField4x32b), - parse_quote!(PackedAESBinaryField8x32b), - parse_quote!(PackedAESBinaryField16x32b), - parse_quote!(PackedAESBinaryField1x64b), - parse_quote!(PackedAESBinaryField2x64b), - parse_quote!(PackedAESBinaryField4x64b), - parse_quote!(PackedAESBinaryField8x64b), - parse_quote!(PackedAESBinaryField1x128b), - parse_quote!(PackedAESBinaryField2x128b), - parse_quote!(PackedAESBinaryField4x128b), - parse_quote!(ByteSlicedAES32x128b), - ] - } - "AESTowerField16b" => { - vec![ - parse_quote!(PackedAESBinaryField1x16b), - parse_quote!(PackedAESBinaryField2x16b), - parse_quote!(PackedAESBinaryField4x16b), - parse_quote!(PackedAESBinaryField8x16b), - parse_quote!(PackedAESBinaryField16x16b), - parse_quote!(PackedAESBinaryField32x16b), - parse_quote!(PackedAESBinaryField1x32b), - parse_quote!(PackedAESBinaryField2x32b), - parse_quote!(PackedAESBinaryField4x32b), - parse_quote!(PackedAESBinaryField8x32b), - parse_quote!(PackedAESBinaryField16x32b), - parse_quote!(PackedAESBinaryField1x64b), - parse_quote!(PackedAESBinaryField2x64b), - parse_quote!(PackedAESBinaryField4x64b), - parse_quote!(PackedAESBinaryField8x64b), - parse_quote!(PackedAESBinaryField1x128b), - parse_quote!(PackedAESBinaryField2x128b), - parse_quote!(PackedAESBinaryField4x128b), - ] - } - "AESTowerField32b" => { - vec![ - parse_quote!(PackedAESBinaryField1x32b), - parse_quote!(PackedAESBinaryField2x32b), - parse_quote!(PackedAESBinaryField4x32b), - parse_quote!(PackedAESBinaryField8x32b), - parse_quote!(PackedAESBinaryField16x32b), - parse_quote!(PackedAESBinaryField1x64b), - parse_quote!(PackedAESBinaryField2x64b), - parse_quote!(PackedAESBinaryField4x64b), - parse_quote!(PackedAESBinaryField8x64b), - parse_quote!(PackedAESBinaryField1x128b), - parse_quote!(PackedAESBinaryField2x128b), - parse_quote!(PackedAESBinaryField4x128b), - ] - } - "AESTowerField64b" => { - vec![ - parse_quote!(PackedAESBinaryField1x64b), - parse_quote!(PackedAESBinaryField2x64b), - parse_quote!(PackedAESBinaryField4x64b), - parse_quote!(PackedAESBinaryField8x64b), - parse_quote!(PackedAESBinaryField1x128b), - parse_quote!(PackedAESBinaryField2x128b), - parse_quote!(PackedAESBinaryField4x128b), - ] - } - "AESTowerField128b" => { - vec![ - parse_quote!(PackedAESBinaryField1x128b), - parse_quote!(PackedAESBinaryField2x128b), - parse_quote!(PackedAESBinaryField4x128b), - ] - } - - _ => vec![], - } -} diff --git a/crates/macros/src/composition_poly.rs b/crates/macros/src/composition_poly.rs index 36200a762..0472fc328 100644 --- a/crates/macros/src/composition_poly.rs +++ b/crates/macros/src/composition_poly.rs @@ -43,7 +43,10 @@ impl ToTokens for CompositionPolyItem { #[derive(Debug, Clone, Copy)] struct #name; - impl binius_math::CompositionPoly<#scalar_type> for #name { + impl

binius_math::CompositionPoly

for #name + where + P: binius_field::PackedField>, + { fn n_vars(&self) -> usize { #n_vars } @@ -56,18 +59,18 @@ impl ToTokens for CompositionPolyItem { 0 } - fn expression>(&self) -> binius_math::ArithExpr { + fn expression(&self) -> binius_math::ArithExpr { (#expr).convert_field() } - fn evaluate>>(&self, query: &[P]) -> Result { + fn evaluate(&self, query: &[P]) -> Result { if query.len() != #n_vars { return Err(binius_math::Error::IncorrectQuerySize { expected: #n_vars }); } Ok(#eval_single) } - fn batch_evaluate>>( + fn batch_evaluate( &self, batch_query: &[&[P]], evals: &mut [P], @@ -89,36 +92,6 @@ impl ToTokens for CompositionPolyItem { Ok(()) } } - - impl

binius_math::CompositionPolyOS

for #name - where - P: binius_field::PackedField>, - { - fn n_vars(&self) -> usize { - >::n_vars(self) - } - - fn degree(&self) -> usize { - >::degree(self) - } - - fn binary_tower_level(&self) -> usize { - >::binary_tower_level(self) - } - - fn expression(&self) -> binius_math::ArithExpr { - >::expression(self) - } - - fn evaluate(&self, query: &[P]) -> Result { - >::evaluate(self, query) - } - - fn batch_evaluate(&self, batch_query: &[&[P]], evals: &mut [P]) -> Result<(), binius_math::Error> { - >::batch_evaluate(self, batch_query, evals) - } - } - }; if *is_anonymous { diff --git a/crates/macros/src/lib.rs b/crates/macros/src/lib.rs index c2d8f069e..08fa219a2 100644 --- a/crates/macros/src/lib.rs +++ b/crates/macros/src/lib.rs @@ -5,26 +5,24 @@ mod arith_circuit_poly; mod arith_expr; mod composition_poly; -use std::collections::BTreeSet; - use proc_macro::TokenStream; use quote::{quote, ToTokens}; -use syn::{parse_macro_input, Data, DeriveInput, Fields}; +use syn::{parse_macro_input, parse_quote, spanned::Spanned, Data, DeriveInput, Fields, ItemImpl}; use crate::{ arith_circuit_poly::ArithCircuitPolyItem, arith_expr::ArithExprItem, composition_poly::CompositionPolyItem, }; -/// Useful for concisely creating structs that implement CompositionPolyOS. +/// Useful for concisely creating structs that implement CompositionPoly. /// This currently only supports creating composition polynomials of tower level 0. /// /// ``` /// use binius_macros::composition_poly; -/// use binius_math::CompositionPolyOS; +/// use binius_math::CompositionPoly; /// use binius_field::{Field, BinaryField1b as F}; /// -/// // Defines named struct without any fields that implements CompositionPolyOS +/// // Defines named struct without any fields that implements CompositionPoly /// composition_poly!(MyComposition[x, y, z] = x + y * z); /// assert_eq!( /// MyComposition.evaluate(&[F::ONE, F::ONE, F::ONE]).unwrap(), @@ -76,156 +74,261 @@ pub fn arith_circuit_poly(input: TokenStream) -> TokenStream { .into() } -/// Implements `pub fn iter_oracles(&self) -> impl Iterator`. +/// Derives the trait binius_utils::DeserializeBytes for a struct or enum /// -/// Detects and includes fields with type `OracleId`, `[OracleId; N]` +/// See the DeserializeBytes derive macro docs for examples/tests +#[proc_macro_derive(SerializeBytes)] +pub fn derive_serialize_bytes(input: TokenStream) -> TokenStream { + let input: DeriveInput = parse_macro_input!(input); + let span = input.span(); + let name = input.ident; + let mut generics = input.generics.clone(); + generics.type_params_mut().for_each(|type_param| { + type_param + .bounds + .push(parse_quote!(binius_utils::SerializeBytes)) + }); + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + let body = match input.data { + Data::Union(_) => syn::Error::new(span, "Unions are not supported").into_compile_error(), + Data::Struct(data) => { + let fields = field_names(data.fields, None); + quote! { + #(binius_utils::SerializeBytes::serialize(&self.#fields, &mut write_buf, mode)?;)* + } + } + Data::Enum(data) => { + let variants = data + .variants + .into_iter() + .enumerate() + .map(|(i, variant)| { + let variant_ident = &variant.ident; + let variant_index = i as u8; + let fields = field_names(variant.fields.clone(), Some("field_")); + let serialize_variant = quote! { + binius_utils::SerializeBytes::serialize(&#variant_index, &mut write_buf, mode)?; + #(binius_utils::SerializeBytes::serialize(#fields, &mut write_buf, mode)?;)* + }; + match variant.fields { + Fields::Named(_) => quote! { + Self::#variant_ident { #(#fields),* } => { + #serialize_variant + } + }, + Fields::Unnamed(_) => quote! { + Self::#variant_ident(#(#fields),*) => { + #serialize_variant + } + }, + Fields::Unit => quote! { + Self::#variant_ident => { + #serialize_variant + } + }, + } + }) + .collect::>(); + + quote! { + match self { + #(#variants)* + } + } + } + }; + quote! { + impl #impl_generics binius_utils::SerializeBytes for #name #ty_generics #where_clause { + fn serialize(&self, mut write_buf: impl binius_utils::bytes::BufMut, mode: binius_utils::SerializationMode) -> Result<(), binius_utils::SerializationError> { + #body + Ok(()) + } + } + }.into() +} + +/// Derives the trait binius_utils::DeserializeBytes for a struct or enum /// /// ``` -/// use binius_macros::IterOracles; -/// type OracleId = usize; -/// type BatchId = usize; -/// -/// #[derive(IterOracles)] -/// struct Oracle { -/// x: OracleId, -/// y: [OracleId; 5], -/// z: [OracleId; 5*2], -/// ignored_field1: usize, -/// ignored_field2: BatchId, -/// ignored_field3: [[OracleId; 5]; 2], +/// use binius_field::BinaryField128b; +/// use binius_utils::{SerializeBytes, DeserializeBytes, SerializationMode}; +/// use binius_macros::{SerializeBytes, DeserializeBytes}; +/// +/// #[derive(Debug, PartialEq, SerializeBytes, DeserializeBytes)] +/// enum MyEnum { +/// A(usize), +/// B { x: u32, y: u32 }, +/// C /// } +/// +/// +/// let mut buf = vec![]; +/// let value = MyEnum::B { x: 42, y: 1337 }; +/// MyEnum::serialize(&value, &mut buf, SerializationMode::Native).unwrap(); +/// assert_eq!( +/// MyEnum::deserialize(buf.as_slice(), SerializationMode::Native).unwrap(), +/// value +/// ); +/// +/// +/// #[derive(Debug, PartialEq, SerializeBytes, DeserializeBytes)] +/// struct MyStruct { +/// data: Vec +/// } +/// +/// let mut buf = vec![]; +/// let value = MyStruct { +/// data: vec![BinaryField128b::new(1234), BinaryField128b::new(5678)] +/// }; +/// MyStruct::serialize(&value, &mut buf, SerializationMode::CanonicalTower).unwrap(); +/// assert_eq!( +/// MyStruct::::deserialize(buf.as_slice(), SerializationMode::CanonicalTower).unwrap(), +/// value +/// ); /// ``` -#[proc_macro_derive(IterOracles)] -pub fn iter_oracle_derive(input: TokenStream) -> TokenStream { - let input = parse_macro_input!(input as DeriveInput); - let Data::Struct(data) = &input.data else { - panic!("#[derive(IterOracles)] is only defined for structs with named fields"); - }; - let Fields::Named(fields) = &data.fields else { - panic!("#[derive(IterOracles)] is only defined for structs with named fields"); +#[proc_macro_derive(DeserializeBytes)] +pub fn derive_deserialize_bytes(input: TokenStream) -> TokenStream { + let input: DeriveInput = parse_macro_input!(input); + let span = input.span(); + let name = input.ident; + let mut generics = input.generics.clone(); + generics.type_params_mut().for_each(|type_param| { + type_param + .bounds + .push(parse_quote!(binius_utils::DeserializeBytes)) + }); + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + let deserialize_value = quote! { + binius_utils::DeserializeBytes::deserialize(&mut read_buf, mode)? }; + let body = match input.data { + Data::Union(_) => syn::Error::new(span, "Unions are not supported").into_compile_error(), + Data::Struct(data) => { + let fields = field_names(data.fields, None); + quote! { + Ok(Self { + #(#fields: #deserialize_value,)* + }) + } + } + Data::Enum(data) => { + let variants = data + .variants + .into_iter() + .enumerate() + .map(|(i, variant)| { + let variant_ident = &variant.ident; + let variant_index: u8 = i as u8; + match variant.fields { + Fields::Named(fields) => { + let fields = fields + .named + .into_iter() + .map(|field| field.ident) + .map(|field_name| quote!(#field_name: #deserialize_value)) + .collect::>(); - let name = &input.ident; - let (impl_generics, ty_generics, where_clause) = &input.generics.split_for_impl(); + quote! { + #variant_index => Self::#variant_ident { #(#fields,)* } + } + } + Fields::Unnamed(fields) => { + let fields = fields + .unnamed + .into_iter() + .map(|_| quote!(#deserialize_value)) + .collect::>(); - let oracles = fields - .named - .iter() - .filter_map(|f| { - let name = f.ident.clone(); - match &f.ty { - syn::Type::Path(type_path) if type_path.path.is_ident("OracleId") => { - Some(quote!(std::iter::once(self.#name))) - } - syn::Type::Array(array) => { - if let syn::Type::Path(type_path) = *array.elem.clone() { - type_path - .path - .is_ident("OracleId") - .then(|| quote!(self.#name.into_iter())) - } else { - None + quote! { + #variant_index => Self::#variant_ident(#(#fields,)*) + } + } + Fields::Unit => quote! { + #variant_index => Self::#variant_ident + }, } - } - _ => None, - } - }) - .collect::>(); + }) + .collect::>(); + let name = name.to_string(); + quote! { + let variant_index: u8 = #deserialize_value; + Ok(match variant_index { + #(#variants,)* + _ => { + return Err(binius_utils::SerializationError::UnknownEnumVariant { + name: #name, + index: variant_index + }) + } + }) + } + } + }; quote! { - impl #impl_generics #name #ty_generics #where_clause { - pub fn iter_oracles(&self) -> impl Iterator { - std::iter::empty() - #(.chain(#oracles))* + impl #impl_generics binius_utils::DeserializeBytes for #name #ty_generics #where_clause { + fn deserialize(mut read_buf: impl binius_utils::bytes::Buf, mode: binius_utils::SerializationMode) -> Result + where + Self: Sized + { + #body } } } .into() } -/// Implements `pub fn iter_polys(&self) -> impl Iterator>`. +/// Use on an impl block for MultivariatePoly, to automatically implement erased_serialize_bytes. /// -/// Supports `Vec

`, `[Vec

; N]`. Currently doesn't filter out fields from the struct, so you can't add any other fields. +/// Importantly, this will serialize the concrete instance, prefixed by the identifier of the data type. /// -/// ``` -/// use binius_macros::IterPolys; -/// use binius_field::PackedField; -/// -/// #[derive(IterPolys)] -/// struct Witness { -/// x: Vec

, -/// y: [Vec

; 5], -/// z: [Vec

; 5*2], -/// } -/// ``` -#[proc_macro_derive(IterPolys)] -pub fn iter_witness_derive(input: TokenStream) -> TokenStream { - let input = parse_macro_input!(input as DeriveInput); - let Data::Struct(data) = &input.data else { - panic!("#[derive(IterPolys)] is only defined for structs with named fields"); - }; - let Fields::Named(fields) = &data.fields else { - panic!("#[derive(IterPolys)] is only defined for structs with named fields"); +/// This prefix can be used to figure out which concrete data type it should use for deserialization later. +#[proc_macro_attribute] +pub fn erased_serialize_bytes(_attr: TokenStream, item: TokenStream) -> TokenStream { + let mut item_impl: ItemImpl = parse_macro_input!(item); + let syn::Type::Path(p) = &*item_impl.self_ty else { + return syn::Error::new( + item_impl.span(), + "#[erased_serialize_bytes] can only be used on an impl for a concrete type", + ) + .into_compile_error() + .into(); }; - - let name = &input.ident; - let witnesses = fields - .named - .iter() - .map(|f| { - let name = f.ident.clone(); - match &f.ty { - syn::Type::Array(_) => quote!(self.#name.iter()), - _ => quote!(std::iter::once(&self.#name)), - } - }) - .collect::>(); - - let packed_field_vars = generic_vars_with_trait(&input.generics, "PackedField"); - assert_eq!(packed_field_vars.len(), 1, "Only a single packed field is supported for now"); - let p = packed_field_vars.first(); - let (impl_generics, ty_generics, where_clause) = &input.generics.split_for_impl(); - quote! { - impl #impl_generics #name #ty_generics #where_clause { - pub fn iter_polys(&self) -> impl Iterator> { - std::iter::empty() - #(.chain(#witnesses))* - .map(|values| binius_math::MultilinearExtension::from_values_slice(values.as_slice()).unwrap()) - } + let name = p.path.segments.last().unwrap().ident.to_string(); + item_impl.items.push(syn::ImplItem::Fn(parse_quote! { + fn erased_serialize( + &self, + write_buf: &mut dyn binius_utils::bytes::BufMut, + mode: binius_utils::SerializationMode, + ) -> Result<(), binius_utils::SerializationError> { + binius_utils::SerializeBytes::serialize(&#name, &mut *write_buf, mode)?; + binius_utils::SerializeBytes::serialize(self, &mut *write_buf, mode) } + })); + quote! { + #item_impl } .into() } -/// This will accept the generics definition of a struct (relevant for derive macros), -/// and return all the generic vars that are constrained by a specific trait identifier. -/// ``` -/// use binius_field::{PackedField, Field}; -/// struct Example(A, B, C); -/// ``` -/// In the above example, when matching against the trait_name "PackedField", -/// the identifiers A and B will be returned, but not C -pub(crate) fn generic_vars_with_trait( - vars: &syn::Generics, - trait_name: &str, -) -> BTreeSet { - vars.params - .iter() - .filter_map(|param| match param { - syn::GenericParam::Type(type_param) => { - let is_bounded_by_trait_name = type_param.bounds.iter().any(|bound| match bound { - syn::TypeParamBound::Trait(trait_bound) => { - if let Some(last_segment) = trait_bound.path.segments.last() { - last_segment.ident == trait_name - } else { - false - } - } - _ => false, - }); - is_bounded_by_trait_name.then(|| type_param.ident.clone()) - } - syn::GenericParam::Const(_) | syn::GenericParam::Lifetime(_) => None, - }) - .collect() +fn field_names(fields: Fields, positional_prefix: Option<&str>) -> Vec { + match fields { + Fields::Named(fields) => fields + .named + .into_iter() + .map(|field| field.ident.into_token_stream()) + .collect(), + Fields::Unnamed(fields) => fields + .unnamed + .into_iter() + .enumerate() + .map(|(i, _)| match positional_prefix { + Some(prefix) => { + quote::format_ident!("{}{}", prefix, syn::Index::from(i)).into_token_stream() + } + None => syn::Index::from(i).into_token_stream(), + }) + .collect(), + Fields::Unit => vec![], + } } diff --git a/crates/macros/tests/arithmetic_circuit.rs b/crates/macros/tests/arithmetic_circuit.rs index 895f6003b..c840e3b96 100644 --- a/crates/macros/tests/arithmetic_circuit.rs +++ b/crates/macros/tests/arithmetic_circuit.rs @@ -2,7 +2,7 @@ use binius_field::*; use binius_macros::arith_circuit_poly; -use binius_math::CompositionPolyOS; +use binius_math::CompositionPoly; use paste::paste; use rand::{rngs::StdRng, SeedableRng}; diff --git a/crates/math/Cargo.toml b/crates/math/Cargo.toml index d077173ab..d37e65281 100644 --- a/crates/math/Cargo.toml +++ b/crates/math/Cargo.toml @@ -9,6 +9,7 @@ workspace = true [dependencies] binius_field = { path = "../field" } +binius_macros = { path = "../macros" } binius_maybe_rayon = { path = "../maybe_rayon", default-features = false } binius_utils = { path = "../utils", default-features = false } auto_impl.workspace = true diff --git a/crates/math/src/arith_expr.rs b/crates/math/src/arith_expr.rs index 6607901f2..84cb05cc7 100644 --- a/crates/math/src/arith_expr.rs +++ b/crates/math/src/arith_expr.rs @@ -7,7 +7,8 @@ use std::{ ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign}, }; -use binius_field::{Field, PackedField}; +use binius_field::{Field, PackedField, TowerField}; +use binius_macros::{DeserializeBytes, SerializeBytes}; use super::error::Error; @@ -16,7 +17,7 @@ use super::error::Error; /// Arithmetic expressions are trees, where the leaves are either constants or variables, and the /// non-leaf nodes are arithmetic operations, such as addition, multiplication, etc. They are /// specific representations of multivariate polynomials. -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, SerializeBytes, DeserializeBytes)] pub enum ArithExpr { Const(F), Var(usize), @@ -136,7 +137,7 @@ impl ArithExpr { &self, ) -> Result, >::Error> { Ok(match self { - Self::Const(val) => ArithExpr::Const((*val).try_into()?), + Self::Const(val) => ArithExpr::Const(FTgt::try_from(*val)?), Self::Var(index) => ArithExpr::Var(*index), Self::Add(left, right) => { let new_left = left.try_convert_field()?; @@ -197,6 +198,19 @@ impl ArithExpr { } } +impl ArithExpr { + pub fn binary_tower_level(&self) -> usize { + match self { + Self::Const(value) => value.min_tower_level(), + Self::Var(_) => 0, + Self::Add(left, right) | Self::Mul(left, right) => { + max(left.binary_tower_level(), right.binary_tower_level()) + } + Self::Pow(base, _) => base.binary_tower_level(), + } + } +} + impl Default for ArithExpr where F: Field, diff --git a/crates/math/src/composition_poly.rs b/crates/math/src/composition_poly.rs index 100c9d31c..8fc883f06 100644 --- a/crates/math/src/composition_poly.rs +++ b/crates/math/src/composition_poly.rs @@ -3,16 +3,14 @@ use std::fmt::Debug; use auto_impl::auto_impl; -use binius_field::{ExtensionField, Field, PackedField}; +use binius_field::PackedField; use stackalloc::stackalloc_with_default; use crate::{ArithExpr, Error}; /// A multivariate polynomial that is used as a composition of several multilinear polynomials. -/// -/// This is an object-safe version of the [`CompositionPoly`] trait. #[auto_impl(Arc, &)] -pub trait CompositionPolyOS

: Debug + Send + Sync +pub trait CompositionPoly

: Debug + Send + Sync where P: PackedField, { @@ -65,23 +63,3 @@ where }) } } - -/// A generic version of the `CompositionPolyOS` trait that is not object-safe. -#[auto_impl(&)] -pub trait CompositionPoly: Debug + Send + Sync { - fn n_vars(&self) -> usize; - - fn degree(&self) -> usize; - - fn binary_tower_level(&self) -> usize; - - fn expression>(&self) -> ArithExpr; - - fn evaluate>>(&self, query: &[P]) -> Result; - - fn batch_evaluate>>( - &self, - batch_query: &[&[P]], - evals: &mut [P], - ) -> Result<(), Error>; -} diff --git a/crates/math/src/deinterleave.rs b/crates/math/src/deinterleave.rs index 4f01b437d..1a5972906 100644 --- a/crates/math/src/deinterleave.rs +++ b/crates/math/src/deinterleave.rs @@ -30,14 +30,11 @@ pub fn deinterleave( } let deinterleaved = (0..1 << (log_scalar_count - P::LOG_WIDTH)).map(|i| { - let mut even = interleaved[2 * i]; - let mut odd = interleaved[2 * i + 1]; - - for log_block_len in (0..P::LOG_WIDTH).rev() { - let (even_interleaved, odd_interleaved) = even.interleave(odd, log_block_len); - even = even_interleaved; - odd = odd_interleaved; - } + let (even, odd) = if P::LOG_WIDTH > 0 { + P::unzip(interleaved[2 * i], interleaved[2 * i + 1], 0) + } else { + (interleaved[2 * i], interleaved[2 * i + 1]) + }; (i, even, odd) }); diff --git a/crates/math/src/error.rs b/crates/math/src/error.rs index adfc1f747..bb53b885b 100644 --- a/crates/math/src/error.rs +++ b/crates/math/src/error.rs @@ -10,6 +10,8 @@ pub enum Error { MatrixNotSquare, #[error("the matrix is singular")] MatrixIsSingular, + #[error("domain size needs to be at least one")] + DomainSizeAtLeastOne, #[error("domain size is larger than the field")] DomainSizeTooLarge, #[error("the inputted packed values slice had an unexpected length")] diff --git a/crates/math/src/fold.rs b/crates/math/src/fold.rs index a7002a50e..a35271a30 100644 --- a/crates/math/src/fold.rs +++ b/crates/math/src/fold.rs @@ -4,22 +4,18 @@ use core::slice; use std::{any::TypeId, cmp::min, mem::MaybeUninit}; use binius_field::{ - arch::{byte_sliced::ByteSlicedAES32x128b, ArchOptimal, OptimalUnderlier}, - packed::{get_packed_slice, set_packed_slice_unchecked}, + arch::{ArchOptimal, OptimalUnderlier}, + byte_iteration::{ + can_iterate_bytes, create_partial_sums_lookup_tables, is_sequential_bytes, iterate_bytes, + ByteIteratorCallback, PackedSlice, + }, + packed::{get_packed_slice, get_packed_slice_unchecked, set_packed_slice_unchecked}, underlier::{UnderlierWithBitOps, WithUnderlier}, - AESTowerField128b, BinaryField128b, BinaryField128bPolyval, BinaryField1b, ByteSlicedAES32x16b, - ByteSlicedAES32x32b, ByteSlicedAES32x64b, ByteSlicedAES32x8b, ExtensionField, Field, - PackedBinaryField128x1b, PackedBinaryField16x1b, PackedBinaryField256x1b, - PackedBinaryField32x1b, PackedBinaryField512x1b, PackedBinaryField64x1b, PackedBinaryField8x1b, - PackedField, -}; -use binius_maybe_rayon::{ - iter::{IndexedParallelIterator, ParallelIterator}, - slice::ParallelSliceMut, + AESTowerField128b, BinaryField128b, BinaryField128bPolyval, BinaryField1b, ExtensionField, + Field, PackedField, }; use binius_utils::bail; -use bytemuck::{fill_zeroes, Pod}; -use itertools::max; +use bytemuck::fill_zeroes; use lazy_static::lazy_static; use stackalloc::helpers::slice_assume_init_mut; @@ -29,6 +25,9 @@ use crate::Error; /// /// Every consequent `1 << log_query_size` scalar values are dot-producted with the corresponding /// query elements. The result is stored in the `output` slice of packed values. +/// +/// Please note that this method is single threaded. Currently we always have some +/// parallelism above this level, so it's not a problem. pub fn fold_right( evals: &[P], log_evals_size: usize, @@ -49,7 +48,16 @@ where return Ok(()); } - fold_right_fallback(evals, log_evals_size, query, log_query_size, out); + // Use linear interpolation for single variable multilinear queries. + let is_lerp = log_query_size == 1 + && get_packed_slice(query, 0) + get_packed_slice(query, 1) == PE::Scalar::ONE; + + if is_lerp { + let lerp_query = get_packed_slice(query, 1); + fold_right_lerp(evals, log_evals_size, lerp_query, out); + } else { + fold_right_fallback(evals, log_evals_size, query, log_query_size, out); + } Ok(()) } @@ -60,7 +68,7 @@ where /// with the corresponding query element. The results is written to the `output` slice of packed values. /// If the function returns `Ok(())`, then `out` can be safely interpreted as initialized. /// -/// Please note that unlike `fold_right`, this method is single threaded. Currently we always have some +/// Please note that this method is single threaded. Currently we always have some /// parallelism above this level, so it's not a problem. Having no parallelism inside allows us to /// use more efficient optimizations for special cases. If we ever need a parallel version of this /// function, we can implement it separately. @@ -129,115 +137,6 @@ where Ok(()) } -/// A marker trait that the slice of packed values can be iterated as a sequence of bytes. -/// The order of the iteration by BinaryField1b subfield elements and bits within iterated bytes must -/// be the same. -/// -/// # Safety -/// The implementor must ensure that the cast of the slice of packed values to the slice of bytes -/// is safe and preserves the order of the 1-bit elements. -#[allow(unused)] -unsafe trait SequentialBytes: Pod {} - -unsafe impl SequentialBytes for PackedBinaryField8x1b {} -unsafe impl SequentialBytes for PackedBinaryField16x1b {} -unsafe impl SequentialBytes for PackedBinaryField32x1b {} -unsafe impl SequentialBytes for PackedBinaryField64x1b {} -unsafe impl SequentialBytes for PackedBinaryField128x1b {} -unsafe impl SequentialBytes for PackedBinaryField256x1b {} -unsafe impl SequentialBytes for PackedBinaryField512x1b {} - -/// Returns true if T implements `SequentialBytes` trait. -/// Use a hack that exploits that array copying is optimized for the `Copy` types. -/// Unfortunately there is no more proper way to perform this check this in Rust at runtime. -#[allow(clippy::redundant_clone)] -fn is_sequential_bytes() -> bool { - struct X(bool, std::marker::PhantomData); - - impl Clone for X { - fn clone(&self) -> Self { - Self(false, std::marker::PhantomData) - } - } - - impl Copy for X {} - - let value = [X::(true, std::marker::PhantomData)]; - let cloned = value.clone(); - - cloned[0].0 -} - -/// Returns if we can iterate over bytes, each representing 8 1-bit values. -fn can_iterate_bytes() -> bool { - // Packed fields with sequential byte order - if is_sequential_bytes::

() { - return true; - } - - // Byte-sliced fields - // Note: add more byte sliced types here as soon as they are added - match TypeId::of::

() { - x if x == TypeId::of::() => true, - x if x == TypeId::of::() => true, - x if x == TypeId::of::() => true, - x if x == TypeId::of::() => true, - x if x == TypeId::of::() => true, - _ => false, - } -} - -/// Helper macro to generate the iteration over bytes for byte-sliced types. -macro_rules! iterate_byte_sliced { - ($packed_type:ty, $data:ident, $f:ident) => { - assert_eq!(TypeId::of::<$packed_type>(), TypeId::of::

()); - - // Safety: the cast is safe because the type is checked by arm statement - let data = - unsafe { slice::from_raw_parts($data.as_ptr() as *const $packed_type, $data.len()) }; - for value in data.iter() { - for i in 0..<$packed_type>::BYTES { - // Safety: j is less than `ByteSlicedAES32x128b::BYTES` - $f(unsafe { value.get_byte_unchecked(i) }); - } - } - }; -} - -/// Iterate over bytes of a slice of the packed values. -fn iterate_bytes(data: &[P], mut f: impl FnMut(u8)) { - if is_sequential_bytes::

() { - // Safety: `P` implements `SequentialBytes` trait, so the following cast is safe - // and preserves the order. - let bytes = unsafe { - std::slice::from_raw_parts(data.as_ptr() as *const u8, std::mem::size_of_val(data)) - }; - for byte in bytes { - f(*byte); - } - } else { - // Note: add more byte sliced types here as soon as they are added - match TypeId::of::

() { - x if x == TypeId::of::() => { - iterate_byte_sliced!(ByteSlicedAES32x128b, data, f); - } - x if x == TypeId::of::() => { - iterate_byte_sliced!(ByteSlicedAES32x64b, data, f); - } - x if x == TypeId::of::() => { - iterate_byte_sliced!(ByteSlicedAES32x32b, data, f); - } - x if x == TypeId::of::() => { - iterate_byte_sliced!(ByteSlicedAES32x16b, data, f); - } - x if x == TypeId::of::() => { - iterate_byte_sliced!(ByteSlicedAES32x8b, data, f); - } - _ => unreachable!("packed field doesn't support byte iteration"), - } - } -} - /// Optimized version for 1-bit values with query size 0-2 fn fold_right_1bit_evals_small_query( evals: &[P], @@ -251,18 +150,8 @@ where if LOG_QUERY_SIZE >= 3 { return false; } - let chunk_size = 1 - << max(&[ - 10, - (P::LOG_WIDTH + LOG_QUERY_SIZE).saturating_sub(PE::LOG_WIDTH), - PE::LOG_WIDTH, - ]) - .unwrap(); - if out.len() % chunk_size != 0 { - return false; - } - if P::WIDTH << LOG_QUERY_SIZE > chunk_size << PE::LOG_WIDTH { + if P::LOG_WIDTH + LOG_QUERY_SIZE > PE::LOG_WIDTH { return false; } @@ -279,33 +168,43 @@ where }) .collect::>(); - out.par_chunks_mut(chunk_size) - .enumerate() - .for_each(|(index, chunk)| { - let input_offset = - ((index * chunk_size) << (LOG_QUERY_SIZE + PE::LOG_WIDTH)) / P::WIDTH; - let input_end = - (((index + 1) * chunk_size) << (LOG_QUERY_SIZE + PE::LOG_WIDTH)) / P::WIDTH; + struct Callback<'a, PE: PackedField, const LOG_QUERY_SIZE: usize> { + out: &'a mut [PE], + cached_table: &'a [PE::Scalar], + } + impl ByteIteratorCallback + for Callback<'_, PE, LOG_QUERY_SIZE> + { + #[inline(always)] + fn call(&mut self, iterator: impl Iterator) { + let mask = (1 << (1 << LOG_QUERY_SIZE)) - 1; + let values_in_byte = 1 << (3 - LOG_QUERY_SIZE); let mut current_index = 0; - iterate_bytes(&evals[input_offset..input_end], |byte| { - let mask = (1 << (1 << LOG_QUERY_SIZE)) - 1; - let values_in_byte = 1 << (3 - LOG_QUERY_SIZE); + for byte in iterator { for k in 0..values_in_byte { let index = (byte >> (k * (1 << LOG_QUERY_SIZE))) & mask; // Safety: `i` is less than `chunk_size` unsafe { set_packed_slice_unchecked( - chunk, + self.out, current_index + k, - cached_table[index as usize], + self.cached_table[index as usize], ); } } current_index += values_in_byte; - }); - }); + } + } + } + + let mut callback = Callback::<'_, PE, LOG_QUERY_SIZE> { + out, + cached_table: &cached_table, + }; + + iterate_bytes(evals, &mut callback); true } @@ -323,61 +222,52 @@ where if LOG_QUERY_SIZE < 3 { return false; } - let chunk_size = 1 - << max(&[ - 10, - (P::LOG_WIDTH + LOG_QUERY_SIZE).saturating_sub(PE::LOG_WIDTH), - PE::LOG_WIDTH, - ]) - .unwrap(); - if out.len() % chunk_size != 0 { + + if P::LOG_WIDTH + LOG_QUERY_SIZE > PE::LOG_WIDTH { return false; } - let log_tables_count = LOG_QUERY_SIZE - 3; - let tables_count = 1 << log_tables_count; - let cached_tables = (0..tables_count) - .map(|i| { - (0..256) - .map(|j| { - let mut result = PE::Scalar::ZERO; - for k in 0..8 { - if j >> k & 1 == 1 { - result += get_packed_slice(query, (i << 3) | k); - } - } - result - }) - .collect::>() - }) - .collect::>(); + let cached_tables = + create_partial_sums_lookup_tables(PackedSlice::new(query, 1 << LOG_QUERY_SIZE)); - out.par_chunks_mut(chunk_size) - .enumerate() - .for_each(|(index, chunk)| { - let input_offset = - ((index * chunk_size) << (LOG_QUERY_SIZE + PE::LOG_WIDTH)) / P::WIDTH; - let input_end = - (((index + 1) * chunk_size) << (LOG_QUERY_SIZE + PE::LOG_WIDTH)) / P::WIDTH; + struct Callback<'a, PE: PackedField, const LOG_QUERY_SIZE: usize> { + out: &'a mut [PE], + cached_tables: &'a [PE::Scalar], + } - let mut current_value = PE::Scalar::ZERO; - let mut current_table = 0; + impl ByteIteratorCallback + for Callback<'_, PE, LOG_QUERY_SIZE> + { + #[inline(always)] + fn call(&mut self, iterator: impl Iterator) { + let log_tables_count = LOG_QUERY_SIZE - 3; + let tables_count = 1 << log_tables_count; let mut current_index = 0; - iterate_bytes(&evals[input_offset..input_end], |byte| { - current_value += cached_tables[current_table][byte as usize]; + let mut current_table = 0; + let mut current_value = PE::Scalar::ZERO; + for byte in iterator { + current_value += self.cached_tables[(current_table << 8) + byte as usize]; current_table += 1; if current_table == tables_count { // Safety: `i` is less than `chunk_size` unsafe { - set_packed_slice_unchecked(chunk, current_index, current_value); + set_packed_slice_unchecked(self.out, current_index, current_value); } - current_table = 0; current_index += 1; + current_table = 0; current_value = PE::Scalar::ZERO; } - }); - }); + } + } + } + + let mut callback = Callback::<'_, _, LOG_QUERY_SIZE> { + out, + cached_tables: &cached_tables, + }; + + iterate_bytes(evals, &mut callback); true } @@ -419,6 +309,46 @@ where } } +/// Specialized implementation for a single parameter right fold using linear interpolation +/// instead of tensor expansion resulting in a single multiplication instead of two: +/// f(r||w) = r * (f(1||w) - f(0||w)) + f(0||w). +/// +/// The same approach may be generalized to higher variable counts, with diminishing returns. +fn fold_right_lerp( + evals: &[P], + log_evals_size: usize, + lerp_query: PE::Scalar, + out: &mut [PE], +) where + P: PackedField, + PE: PackedField>, +{ + assert_eq!(1 << log_evals_size.saturating_sub(PE::LOG_WIDTH + 1), out.len()); + + out.iter_mut() + .enumerate() + .for_each(|(i, packed_result_eval)| { + for j in 0..min(PE::WIDTH, 1 << (log_evals_size - 1)) { + let index = (i << PE::LOG_WIDTH) | j; + + let (eval0, eval1) = unsafe { + ( + get_packed_slice_unchecked(evals, index << 1), + get_packed_slice_unchecked(evals, (index << 1) | 1), + ) + }; + + let result_eval = + PE::Scalar::from(eval1 - eval0) * lerp_query + PE::Scalar::from(eval0); + + // Safety: `j` < `PE::WIDTH` + unsafe { + packed_result_eval.set_unchecked(j, result_eval); + } + } + }) +} + /// Fallback implementation for fold that can be executed for any field types and sizes. fn fold_right_fallback( evals: &[P], @@ -430,34 +360,26 @@ fn fold_right_fallback( P: PackedField, PE: PackedField>, { - const CHUNK_SIZE: usize = 1 << 10; - let packed_result_evals = out; - packed_result_evals - .par_chunks_mut(CHUNK_SIZE) - .enumerate() - .for_each(|(i, packed_result_evals)| { - for (k, packed_result_eval) in packed_result_evals.iter_mut().enumerate() { - let offset = i * CHUNK_SIZE; - for j in 0..min(PE::WIDTH, 1 << (log_evals_size - log_query_size)) { - let index = ((offset + k) << PE::LOG_WIDTH) | j; - - let offset = index << log_query_size; - - let mut result_eval = PE::Scalar::ZERO; - for (t, query_expansion) in PackedField::iter_slice(query) - .take(1 << log_query_size) - .enumerate() - { - result_eval += query_expansion * get_packed_slice(evals, t + offset); - } + for (k, packed_result_eval) in out.iter_mut().enumerate() { + for j in 0..min(PE::WIDTH, 1 << (log_evals_size - log_query_size)) { + let index = (k << PE::LOG_WIDTH) | j; + + let offset = index << log_query_size; + + let mut result_eval = PE::Scalar::ZERO; + for (t, query_expansion) in PackedField::iter_slice(query) + .take(1 << log_query_size) + .enumerate() + { + result_eval += query_expansion * get_packed_slice(evals, t + offset); + } - // Safety: `j` < `PE::WIDTH` - unsafe { - packed_result_eval.set_unchecked(j, result_eval); - } - } + // Safety: `j` < `PE::WIDTH` + unsafe { + packed_result_eval.set_unchecked(j, result_eval); } - }); + } + } } type ArchOptimaType = ::OptimalThroughputPacked; @@ -685,27 +607,13 @@ mod tests { use std::iter::repeat_with; use binius_field::{ - packed::set_packed_slice, PackedBinaryField16x32b, PackedBinaryField16x8b, - PackedBinaryField4x1b, PackedBinaryField512x1b, + packed::set_packed_slice, PackedBinaryField128x1b, PackedBinaryField16x32b, + PackedBinaryField16x8b, PackedBinaryField512x1b, PackedBinaryField64x8b, }; use rand::{rngs::StdRng, SeedableRng}; use super::*; - #[test] - fn test_sequential_bits() { - assert!(is_sequential_bytes::()); - assert!(is_sequential_bytes::()); - assert!(is_sequential_bytes::()); - assert!(is_sequential_bytes::()); - assert!(is_sequential_bytes::()); - assert!(is_sequential_bytes::()); - assert!(is_sequential_bytes::()); - - assert!(!is_sequential_bytes::()); - assert!(!is_sequential_bytes::()); - } - fn fold_right_reference( evals: &[P], log_evals_size: usize, @@ -782,7 +690,9 @@ mod tests { let evals = repeat_with(|| PackedBinaryField128x1b::random(&mut rng)) .take(1 << LOG_EVALS_SIZE) .collect::>(); - let query = vec![PackedBinaryField512x1b::random(&mut rng)]; + let query = repeat_with(|| PackedBinaryField64x8b::random(&mut rng)) + .take(8) + .collect::>(); for log_query_size in 0..10 { check_fold_right( diff --git a/crates/math/src/matrix.rs b/crates/math/src/matrix.rs index 674087d97..408b79863 100644 --- a/crates/math/src/matrix.rs +++ b/crates/math/src/matrix.rs @@ -186,8 +186,6 @@ impl Matrix { } fn scale_row(&mut self, i: usize, scalar: F) { - assert!(i < self.m); - for x in self.row_mut(i) { *x *= scalar; } diff --git a/crates/math/src/mle_adapters.rs b/crates/math/src/mle_adapters.rs index 079fc4ddd..47c227cd1 100644 --- a/crates/math/src/mle_adapters.rs +++ b/crates/math/src/mle_adapters.rs @@ -8,7 +8,6 @@ use binius_field::{ }, ExtensionField, Field, PackedField, RepackedExtension, }; -use binius_maybe_rayon::prelude::*; use binius_utils::bail; use super::{Error, MultilinearExtension, MultilinearPoly, MultilinearQueryRef}; @@ -274,11 +273,9 @@ where P: PackedField, Data: Deref + Send + Sync + Debug + 'a, { - pub fn specialize_arc_dyn(self) -> Arc + Send + Sync + 'a> - where - PE: PackedField + RepackedExtension

, - PE::Scalar: ExtensionField, - { + pub fn specialize_arc_dyn>( + self, + ) -> Arc + Send + Sync + 'a> { self.specialize().upcast_arc_dyn() } } @@ -299,44 +296,6 @@ where pub fn upcast_arc_dyn(self) -> Arc + Send + Sync + 'a> { Arc::new(self) } - - /// Given a ($mu$-variate) multilinear function $f$ and an element $r$, - /// return the multilinear function $f(r, X_1, ..., X_{\mu - 1})$. - pub fn evaluate_zeroth_variable(&self, r: P::Scalar) -> Result, Error> { - let multilin = &self.0; - let mu = multilin.n_vars(); - if mu == 0 { - bail!(Error::ConstantFold); - } - let packed_length = 1 << mu.saturating_sub(P::LOG_WIDTH + 1); - // in general, the formula is: f(r||w) = r * (f(1||w) - f(0||w)) + f(0||w). - let result = (0..packed_length) - .into_par_iter() - .map(|i| { - let eval0_minus_eval1 = P::from_fn(|j| { - let index = (i << P::LOG_WIDTH) | j; - // necessary if `mu_minus_one` < `P::LOG_WIDTH` - if index >= 1 << (mu - 1) { - return P::Scalar::ZERO; - } - let eval0 = get_packed_slice(multilin.evals(), index << 1); - let eval1 = get_packed_slice(multilin.evals(), (index << 1) | 1); - eval0 - eval1 - }); - let eval0 = P::from_fn(|j| { - let index = (i << P::LOG_WIDTH) | j; - // necessary if `mu_minus_one` < `P::LOG_WIDTH` - if index >= 1 << (mu - 1) { - return P::Scalar::ZERO; - } - get_packed_slice(multilin.evals(), index << 1) - }); - eval0_minus_eval1 * r + eval0 - }) - .collect::>(); - - MultilinearExtension::new(mu - 1, result) - } } impl From> for MLEDirectAdapter @@ -699,23 +658,4 @@ mod tests { .unwrap(); assert_eq!(evals_out, poly.packed_evals().unwrap()); } - - #[test] - fn test_evaluate_zeroth_evaluate_partial_low_consistent() { - let mut rng = StdRng::seed_from_u64(0); - let values: Vec<_> = repeat_with(|| PackedBinaryField4x32b::random(&mut rng)) - .take(1 << 8) - .collect(); - - let me = MultilinearExtension::from_values(values).unwrap(); - let mled = MLEDirectAdapter::from(me); - let r = ::random(&mut rng); - - let eval_1: MultilinearExtension = - mled.evaluate_zeroth_variable(r).unwrap(); - let eval_2 = mled - .evaluate_partial_low(multilinear_query(&[r]).to_ref()) - .unwrap(); - assert_eq!(eval_1, eval_2); - } } diff --git a/crates/math/src/multilinear_extension.rs b/crates/math/src/multilinear_extension.rs index 91c379c83..8403476ee 100644 --- a/crates/math/src/multilinear_extension.rs +++ b/crates/math/src/multilinear_extension.rs @@ -245,6 +245,7 @@ where PE::Scalar: ExtensionField, { let query = query.into(); + if self.mu < query.n_vars() { bail!(Error::IncorrectQuerySize { expected: self.mu }); } @@ -275,15 +276,6 @@ where PE: PackedField, PE::Scalar: ExtensionField, { - if self.mu < query.n_vars() { - bail!(Error::IncorrectQuerySize { expected: self.mu }); - } - if out.len() != 1 << ((self.mu - query.n_vars()).saturating_sub(PE::LOG_WIDTH)) { - bail!(Error::IncorrectOutputPolynomialSize { - expected: self.mu - query.n_vars(), - }); - } - // This operation is a matrix-vector product of the matrix of multilinear coefficients with // the vector of tensor product-expanded query coefficients. fold_right(&self.evals, self.mu, query.expansion(), query.n_vars(), out) @@ -559,6 +551,28 @@ mod tests { ); } + #[test] + fn test_evaluate_partial_low_single_and_multiple_var_consistent() { + let mut rng = StdRng::seed_from_u64(0); + let values: Vec<_> = repeat_with(|| PackedBinaryField4x32b::random(&mut rng)) + .take(1 << 8) + .collect(); + + let mle = MultilinearExtension::from_values(values).unwrap(); + let r1 = ::random(&mut rng); + let r2 = ::random(&mut rng); + + let eval_1: MultilinearExtension = mle + .evaluate_partial_low::(multilinear_query(&[r1]).to_ref()) + .unwrap() + .evaluate_partial_low(multilinear_query(&[r2]).to_ref()) + .unwrap(); + let eval_2 = mle + .evaluate_partial_low(multilinear_query(&[r1, r2]).to_ref()) + .unwrap(); + assert_eq!(eval_1, eval_2); + } + #[test] fn test_new_mle_with_tiny_nvars() { MultilinearExtension::new( diff --git a/crates/math/src/multilinear_query.rs b/crates/math/src/multilinear_query.rs index fc27a689a..9593367ac 100644 --- a/crates/math/src/multilinear_query.rs +++ b/crates/math/src/multilinear_query.rs @@ -36,7 +36,7 @@ impl<'a, P: PackedField, Data: DerefMut> From<&'a MultilinearQuery for MultilinearQueryRef<'a, P> { fn from(query: &'a MultilinearQuery) -> Self { - MultilinearQueryRef::new(query) + Self::new(query) } } @@ -73,7 +73,7 @@ impl MultilinearQuery> { } pub fn expand(query: &[P::Scalar]) -> Self { - let expanded_query = eq_ind_partial_eval::

(query); + let expanded_query = eq_ind_partial_eval(query); Self { expanded_query, n_vars: query.len(), @@ -148,12 +148,16 @@ impl> MultilinearQuery { #[cfg(test)] mod tests { - use binius_field::{Field, PackedField}; + use binius_field::{Field, PackedBinaryField4x32b, PackedField}; use binius_utils::felts; + use itertools::Itertools; use super::*; use crate::tensor_prod_eq_ind; + type P = PackedBinaryField4x32b; + type F =

::Scalar; + fn tensor_prod(p: &[P::Scalar]) -> Vec

{ let mut result = vec![P::default(); 1 << p.len().saturating_sub(P::LOG_WIDTH)]; result[0] = P::set_single(P::Scalar::ONE); @@ -252,4 +256,98 @@ mod tests { felts!(BinaryField16b[3, 2, 2, 1, 2, 1, 1, 3, 2, 1, 1, 3, 1, 3, 3, 2]) ); } + + #[test] + fn test_update_single_var() { + let query = MultilinearQuery::

::with_capacity(2); + let r0 = F::new(2); + let extra_query = [r0]; + + let updated_query = query.update(&extra_query).unwrap(); + + assert_eq!(updated_query.n_vars(), 1); + + let expansion = updated_query.into_expansion(); + let expansion = PackedField::iter_slice(&expansion).collect_vec(); + + assert_eq!(expansion, vec![(F::ONE - r0), r0, F::ZERO, F::ZERO]); + } + + #[test] + fn test_update_two_vars() { + let query = MultilinearQuery::

::with_capacity(3); + let r0 = F::new(2); + let r1 = F::new(3); + let extra_query = [r0, r1]; + + let updated_query = query.update(&extra_query).unwrap(); + assert_eq!(updated_query.n_vars(), 2); + + let expansion = updated_query.expansion(); + let expansion = PackedField::iter_slice(expansion).collect_vec(); + + assert_eq!( + expansion, + vec![ + (F::ONE - r0) * (F::ONE - r1), + r0 * (F::ONE - r1), + (F::ONE - r0) * r1, + r0 * r1, + ] + ); + } + + #[test] + fn test_update_three_vars() { + let query = MultilinearQuery::

::with_capacity(4); + let r0 = F::new(2); + let r1 = F::new(3); + let r2 = F::new(5); + let extra_query = [r0, r1, r2]; + + let updated_query = query.update(&extra_query).unwrap(); + assert_eq!(updated_query.n_vars(), 3); + + let expansion = updated_query.expansion(); + let expansion = PackedField::iter_slice(expansion).collect_vec(); + + assert_eq!( + expansion, + vec![ + (F::ONE - r0) * (F::ONE - r1) * (F::ONE - r2), + r0 * (F::ONE - r1) * (F::ONE - r2), + (F::ONE - r0) * r1 * (F::ONE - r2), + r0 * r1 * (F::ONE - r2), + (F::ONE - r0) * (F::ONE - r1) * r2, + r0 * (F::ONE - r1) * r2, + (F::ONE - r0) * r1 * r2, + r0 * r1 * r2, + ] + ); + } + + #[test] + fn test_update_exceeds_capacity() { + let query = MultilinearQuery::

::with_capacity(2); + // More than allowed capacity + let extra_query = [F::new(2), F::new(3), F::new(5)]; + + let result = query.update(&extra_query); + // Expecting an error due to exceeding max_query_vars + assert!(result.is_err()); + } + + #[test] + fn test_update_empty() { + let query = MultilinearQuery::

::with_capacity(2); + // Updating with no new coordinates should be fine + let updated_query = query.update(&[]).unwrap(); + + assert_eq!(updated_query.n_vars(), 0); + + let expansion = updated_query.expansion(); + let expansion = PackedField::iter_slice(expansion).collect_vec(); + + assert_eq!(expansion, vec![F::ONE, F::ZERO, F::ZERO, F::ZERO]); + } } diff --git a/crates/math/src/tensor_prod_eq_ind.rs b/crates/math/src/tensor_prod_eq_ind.rs index 8bcba36fe..ed48bb1de 100644 --- a/crates/math/src/tensor_prod_eq_ind.rs +++ b/crates/math/src/tensor_prod_eq_ind.rs @@ -62,7 +62,7 @@ pub fn tensor_prod_eq_ind( xs.par_iter_mut() .zip(ys.par_iter_mut()) .with_min_len(64) - .for_each(|(x, y): (&mut P, &mut P)| { + .for_each(|(x, y)| { // x = x * (1 - packed_r_i) = x - x * packed_r_i // y = x * packed_r_i // Notice that we can reuse the multiplication: (x * packed_r_i) @@ -95,8 +95,7 @@ pub fn eq_ind_partial_eval(point: &[P::Scalar]) -> Vec

{ let len = 1 << n.saturating_sub(P::LOG_WIDTH); let mut buffer = zeroed_vec::

(len); buffer[0].set(0, P::Scalar::ONE); - tensor_prod_eq_ind(0, &mut buffer[..], point) - .expect("buffer is allocated with the correct length"); + tensor_prod_eq_ind(0, &mut buffer, point).expect("buffer is allocated with the correct length"); buffer } @@ -107,10 +106,11 @@ mod tests { use super::*; + type P = PackedBinaryField4x32b; + type F =

::Scalar; + #[test] fn test_tensor_prod_eq_ind() { - type P = PackedBinaryField4x32b; - type F =

::Scalar; let v0 = F::new(1); let v1 = F::new(2); let query = vec![v0, v1]; @@ -128,4 +128,59 @@ mod tests { ] ); } + + #[test] + fn test_eq_ind_partial_eval_empty() { + let result = eq_ind_partial_eval::

(&[]); + let expected = vec![P::set_single(F::ONE)]; + assert_eq!(result, expected); + } + + #[test] + fn test_eq_ind_partial_eval_single_var() { + // Only one query coordinate + let r0 = F::new(2); + let result = eq_ind_partial_eval::

(&[r0]); + let expected = vec![(F::ONE - r0), r0, F::ZERO, F::ZERO]; + let result = PackedField::iter_slice(&result).collect_vec(); + assert_eq!(result, expected); + } + + #[test] + fn test_eq_ind_partial_eval_two_vars() { + // Two query coordinates + let r0 = F::new(2); + let r1 = F::new(3); + let result = eq_ind_partial_eval::

(&[r0, r1]); + let result = PackedField::iter_slice(&result).collect_vec(); + let expected = vec![ + (F::ONE - r0) * (F::ONE - r1), + r0 * (F::ONE - r1), + (F::ONE - r0) * r1, + r0 * r1, + ]; + assert_eq!(result, expected); + } + + #[test] + fn test_eq_ind_partial_eval_three_vars() { + // Case with three query coordinates + let r0 = F::new(2); + let r1 = F::new(3); + let r2 = F::new(5); + let result = eq_ind_partial_eval::

(&[r0, r1, r2]); + let result = PackedField::iter_slice(&result).collect_vec(); + + let expected = vec![ + (F::ONE - r0) * (F::ONE - r1) * (F::ONE - r2), + r0 * (F::ONE - r1) * (F::ONE - r2), + (F::ONE - r0) * r1 * (F::ONE - r2), + r0 * r1 * (F::ONE - r2), + (F::ONE - r0) * (F::ONE - r1) * r2, + r0 * (F::ONE - r1) * r2, + (F::ONE - r0) * r1 * r2, + r0 * r1 * r2, + ]; + assert_eq!(result, expected); + } } diff --git a/crates/math/src/univariate.rs b/crates/math/src/univariate.rs index b99d26067..d2339d213 100644 --- a/crates/math/src/univariate.rs +++ b/crates/math/src/univariate.rs @@ -7,6 +7,7 @@ use binius_field::{ PackedField, }; use binius_utils::bail; +use itertools::{izip, Either}; use super::{binary_subspace::BinarySubspace, error::Error}; use crate::Matrix; @@ -17,8 +18,9 @@ use crate::Matrix; /// to reconstruct a degree <= d. This struct supports Barycentric extrapolation. #[derive(Debug, Clone)] pub struct EvaluationDomain { - points: Vec, + finite_points: Vec, weights: Vec, + with_infinity: bool, } /// An extended version of `EvaluationDomain` that supports interpolation to monomial form. Takes @@ -32,9 +34,20 @@ pub struct InterpolationDomain { /// Wraps type information to enable instantiating EvaluationDomains. #[auto_impl(&)] pub trait EvaluationDomainFactory: Clone + Sync { - /// Instantiates an EvaluationDomain from a set of points isomorphic to direct - /// lexicographic successors of zero in Fan-Paar tower - fn create(&self, size: usize) -> Result, Error>; + /// Instantiates an EvaluationDomain from `size` lexicographically first values from the + /// binary subspace. + fn create(&self, size: usize) -> Result, Error> { + self.create_with_infinity(size, false) + } + + /// Instantiates an EvaluationDomain from `size` values in total: lexicographically first values + /// from the binary subspace and potentially Karatsuba "infinity" point (which is the coefficient of + /// the highest power in the interpolated polynomial). + fn create_with_infinity( + &self, + size: usize, + with_infinity: bool, + ) -> Result, Error>; } #[derive(Default, Clone)] @@ -48,8 +61,18 @@ pub struct IsomorphicEvaluationDomainFactory { } impl EvaluationDomainFactory for DefaultEvaluationDomainFactory { - fn create(&self, size: usize) -> Result, Error> { - EvaluationDomain::from_points(make_evaluation_points(&self.subspace, size)?) + fn create_with_infinity( + &self, + size: usize, + with_infinity: bool, + ) -> Result, Error> { + if size == 0 && with_infinity { + bail!(Error::DomainSizeAtLeastOne); + } + EvaluationDomain::from_points( + make_evaluation_points(&self.subspace, size - if with_infinity { 1 } else { 0 })?, + with_infinity, + ) } } @@ -58,9 +81,17 @@ where FSrc: BinaryField, FTgt: Field + From + BinaryField, { - fn create(&self, size: usize) -> Result, Error> { - let points = make_evaluation_points(&self.subspace, size)?; - EvaluationDomain::from_points(points.into_iter().map(Into::into).collect()) + fn create_with_infinity( + &self, + size: usize, + with_infinity: bool, + ) -> Result, Error> { + if size == 0 && with_infinity { + bail!(Error::DomainSizeAtLeastOne); + } + let points = + make_evaluation_points(&self.subspace, size - if with_infinity { 1 } else { 0 })?; + EvaluationDomain::from_points(points.into_iter().map(Into::into).collect(), false) } } @@ -78,7 +109,8 @@ fn make_evaluation_points( impl From> for InterpolationDomain { fn from(evaluation_domain: EvaluationDomain) -> Self { let n = evaluation_domain.size(); - let evaluation_matrix = vandermonde(evaluation_domain.points()); + let evaluation_matrix = + vandermonde(evaluation_domain.finite_points(), evaluation_domain.with_infinity()); let mut interpolation_matrix = Matrix::zeros(n, n); evaluation_matrix .inverse_into(&mut interpolation_matrix) @@ -97,17 +129,25 @@ impl From> for InterpolationDomain { } impl EvaluationDomain { - pub fn from_points(points: Vec) -> Result { - let weights = compute_barycentric_weights(&points)?; - Ok(Self { points, weights }) + pub fn from_points(finite_points: Vec, with_infinity: bool) -> Result { + let weights = compute_barycentric_weights(&finite_points)?; + Ok(Self { + finite_points, + weights, + with_infinity, + }) } pub fn size(&self) -> usize { - self.points.len() + self.finite_points.len() + if self.with_infinity { 1 } else { 0 } } - pub fn points(&self) -> &[F] { - self.points.as_slice() + pub fn finite_points(&self) -> &[F] { + self.finite_points.as_slice() + } + + pub const fn with_infinity(&self) -> bool { + self.with_infinity } /// Compute a vector of Lagrange polynomial evaluations in $O(N)$ at a given point `x`. @@ -116,19 +156,23 @@ impl EvaluationDomain { /// are defined by /// $$L_i(x) = \sum_{j \neq i}\frac{x - \pi_j}{\pi_i - \pi_j}$$ pub fn lagrange_evals>(&self, x: FE) -> Vec { - let num_evals = self.size(); + let num_evals = self.finite_points().len(); let mut result: Vec = vec![FE::ONE; num_evals]; // Multiply the product suffixes for i in (1..num_evals).rev() { - result[i - 1] = result[i] * (x - self.points[i]); + result[i - 1] = result[i] * (x - self.finite_points[i]); } let mut prefix = FE::ONE; // Multiply the product prefixes and weights - for ((r, &point), &weight) in result.iter_mut().zip(&self.points).zip(&self.weights) { + for ((r, &point), &weight) in result + .iter_mut() + .zip(&self.finite_points) + .zip(&self.weights) + { *r *= prefix * weight; prefix *= x - point; } @@ -141,18 +185,26 @@ impl EvaluationDomain { where PE: PackedField>, { - let lagrange_eval_results = self.lagrange_evals(x); - - let n = self.size(); - if values.len() != n { + if values.len() != self.size() { bail!(Error::ExtrapolateNumberOfEvaluations); } - let result = lagrange_eval_results - .into_iter() - .zip(values) - .map(|(evaluation, &value)| value * evaluation) - .sum::(); + let (values_iter, infinity_term) = if self.with_infinity { + let (&value_at_infinity, finite_values) = + values.split_last().expect("values length checked above"); + let highest_degree = finite_values.len() as u64; + let iter = izip!(&self.finite_points, finite_values).map(move |(&point, &value)| { + value - value_at_infinity * PE::Scalar::from(point).pow(highest_degree) + }); + (Either::Left(iter), value_at_infinity * x.pow(highest_degree)) + } else { + (Either::Right(values.iter().copied()), PE::zero()) + }; + + let result = izip!(self.lagrange_evals(x), values_iter) + .map(|(lagrange_at_x, value)| value * lagrange_at_x) + .sum::() + + infinity_term; Ok(result) } @@ -163,20 +215,24 @@ impl InterpolationDomain { self.evaluation_domain.size() } - pub fn points(&self) -> &[F] { - self.evaluation_domain.points() + pub fn finite_points(&self) -> &[F] { + self.evaluation_domain.finite_points() } - pub fn extrapolate(&self, values: &[PE], x: PE::Scalar) -> Result - where - PE: PackedExtension>, - { + pub const fn with_infinity(&self) -> bool { + self.evaluation_domain.with_infinity() + } + + pub fn extrapolate>( + &self, + values: &[PE], + x: PE::Scalar, + ) -> Result { self.evaluation_domain.extrapolate(values, x) } pub fn interpolate>(&self, values: &[FE]) -> Result, Error> { - let n = self.evaluation_domain.size(); - if values.len() != n { + if values.len() != self.evaluation_domain.size() { bail!(Error::ExtrapolateNumberOfEvaluations); } @@ -188,11 +244,7 @@ impl InterpolationDomain { /// Extrapolates lines through a pair of packed fields at a single point from a subfield. #[inline] -pub fn extrapolate_line(x0: P, x1: P, z: FS) -> P -where - P: PackedExtension>, - FS: Field, -{ +pub fn extrapolate_line, FS: Field>(x0: P, x1: P, z: FS) -> P { x0 + mul_by_subfield_scalar(x1 - x0, z) } @@ -236,8 +288,8 @@ fn compute_barycentric_weights(points: &[F]) -> Result, Error> .collect() } -fn vandermonde(xs: &[F]) -> Matrix { - let n = xs.len(); +fn vandermonde(xs: &[F], with_infinity: bool) -> Matrix { + let n = xs.len() + if with_infinity { 1 } else { 0 }; let mut mat = Matrix::zeros(n, n); for (i, x_i) in xs.iter().copied().enumerate() { @@ -249,6 +301,11 @@ fn vandermonde(xs: &[F]) -> Matrix { mat[(i, j)] = acc; } } + + if with_infinity { + mat[(n - 1, n - 1)] = F::ONE; + } + mat } @@ -277,7 +334,7 @@ mod tests { fn test_new_domain() { let domain_factory = DefaultEvaluationDomainFactory::::default(); assert_eq!( - domain_factory.create(3).unwrap().points, + domain_factory.create(3).unwrap().finite_points, &[ BinaryField8b::new(0), BinaryField8b::new(1), @@ -292,7 +349,7 @@ mod tests { let iso_domain_factory = IsomorphicEvaluationDomainFactory::::default(); let domain_1: EvaluationDomain = default_domain_factory.create(10).unwrap(); let domain_2: EvaluationDomain = iso_domain_factory.create(10).unwrap(); - assert_eq!(domain_1.points, domain_2.points); + assert_eq!(domain_1.finite_points, domain_2.finite_points); } #[test] @@ -303,11 +360,11 @@ mod tests { let domain_2: EvaluationDomain = iso_domain_factory.create(10).unwrap(); assert_eq!( domain_1 - .points + .finite_points .into_iter() .map(AESTowerField32b::from) .collect::>(), - domain_2.points + domain_2.finite_points ); } @@ -343,6 +400,7 @@ mod tests { repeat_with(|| ::random(&mut rng)) .take(degree + 1) .collect(), + false, ) .unwrap(); @@ -351,7 +409,7 @@ mod tests { .collect::>(); let values = domain - .points() + .finite_points() .iter() .map(|&x| evaluate_univariate(&coeffs, x)) .collect::>(); @@ -370,6 +428,7 @@ mod tests { repeat_with(|| ::random(&mut rng)) .take(degree + 1) .collect(), + false, ) .unwrap(); @@ -378,10 +437,44 @@ mod tests { .collect::>(); let values = domain - .points() + .finite_points() + .iter() + .map(|&x| evaluate_univariate(&coeffs, x)) + .collect::>(); + + let interpolated = InterpolationDomain::from(domain) + .interpolate(&values) + .unwrap(); + assert_eq!(interpolated, coeffs); + } + + #[test] + fn test_infinity() { + let mut rng = StdRng::seed_from_u64(0); + let degree = 6; + + let domain = EvaluationDomain::from_points( + repeat_with(|| ::random(&mut rng)) + .take(degree) + .collect(), + true, + ) + .unwrap(); + + let coeffs = repeat_with(|| ::random(&mut rng)) + .take(degree + 1) + .collect::>(); + + let mut values = domain + .finite_points() .iter() .map(|&x| evaluate_univariate(&coeffs, x)) .collect::>(); + values.push(coeffs.last().copied().unwrap()); + + let x = ::random(&mut rng); + let expected_y = evaluate_univariate(&coeffs, x); + assert_eq!(domain.extrapolate(&values, x).unwrap(), expected_y); let interpolated = InterpolationDomain::from(domain) .interpolate(&values) diff --git a/crates/ntt/src/additive_ntt.rs b/crates/ntt/src/additive_ntt.rs index a48d7bb9d..7b5e0bcc4 100644 --- a/crates/ntt/src/additive_ntt.rs +++ b/crates/ntt/src/additive_ntt.rs @@ -1,7 +1,6 @@ // Copyright 2024-2025 Irreducible Inc. use binius_field::{ExtensionField, PackedField, RepackedExtension}; -use binius_utils::checked_arithmetics::log2_strict_usize; use super::error::Error; @@ -46,29 +45,19 @@ pub trait AdditiveNTT { log_batch_size: usize, ) -> Result<(), Error>; - fn forward_transform_ext(&self, data: &mut [PE], coset: u32) -> Result<(), Error> - where - PE: RepackedExtension

, - PE::Scalar: ExtensionField, - { - if !PE::Scalar::DEGREE.is_power_of_two() { - return Err(Error::PowerOfTwoExtensionDegreeRequired); - } - - let log_batch_size = log2_strict_usize(PE::Scalar::DEGREE); - self.forward_transform(PE::cast_bases_mut(data), coset, log_batch_size) + fn forward_transform_ext>( + &self, + data: &mut [PE], + coset: u32, + ) -> Result<(), Error> { + self.forward_transform(PE::cast_bases_mut(data), coset, PE::Scalar::LOG_DEGREE) } - fn inverse_transform_ext(&self, data: &mut [PE], coset: u32) -> Result<(), Error> - where - PE: RepackedExtension

, - PE::Scalar: ExtensionField, - { - if !PE::Scalar::DEGREE.is_power_of_two() { - return Err(Error::PowerOfTwoExtensionDegreeRequired); - } - - let log_batch_size = log2_strict_usize(PE::Scalar::DEGREE); - self.inverse_transform(PE::cast_bases_mut(data), coset, log_batch_size) + fn inverse_transform_ext>( + &self, + data: &mut [PE], + coset: u32, + ) -> Result<(), Error> { + self.inverse_transform(PE::cast_bases_mut(data), coset, PE::Scalar::LOG_DEGREE) } } diff --git a/crates/ntt/src/dynamic_dispatch.rs b/crates/ntt/src/dynamic_dispatch.rs index f7eb830f6..bd1d0f8f8 100644 --- a/crates/ntt/src/dynamic_dispatch.rs +++ b/crates/ntt/src/dynamic_dispatch.rs @@ -54,7 +54,7 @@ pub enum DynamicDispatchNTT { impl DynamicDispatchNTT { /// Create a new AdditiveNTT based on the given settings. - pub fn new(log_domain_size: usize, options: NTTOptions) -> Result { + pub fn new(log_domain_size: usize, options: &NTTOptions) -> Result { let log_threads = options.thread_settings.log_threads_count(); let result = match (options.precompute_twiddles, log_threads) { (false, 0) => Self::SingleThreaded(SingleThreadedNTT::new(log_domain_size)?), @@ -144,24 +144,24 @@ mod tests { #[test] fn test_creation() { - fn make_ntt(options: NTTOptions) -> DynamicDispatchNTT { + fn make_ntt(options: &NTTOptions) -> DynamicDispatchNTT { DynamicDispatchNTT::::new(6, options).unwrap() } - let ntt = make_ntt(NTTOptions { + let ntt = make_ntt(&NTTOptions { precompute_twiddles: false, thread_settings: ThreadingSettings::SingleThreaded, }); assert!(matches!(ntt, DynamicDispatchNTT::SingleThreaded(_))); - let ntt = make_ntt(NTTOptions { + let ntt = make_ntt(&NTTOptions { precompute_twiddles: true, thread_settings: ThreadingSettings::SingleThreaded, }); assert!(matches!(ntt, DynamicDispatchNTT::SingleThreadedPrecompute(_))); let multithreaded = get_log_max_threads() > 0; - let ntt = make_ntt(NTTOptions { + let ntt = make_ntt(&NTTOptions { precompute_twiddles: false, thread_settings: ThreadingSettings::MultithreadedDefault, }); @@ -171,7 +171,7 @@ mod tests { assert!(matches!(ntt, DynamicDispatchNTT::SingleThreaded(_))); } - let ntt = make_ntt(NTTOptions { + let ntt = make_ntt(&NTTOptions { precompute_twiddles: true, thread_settings: ThreadingSettings::MultithreadedDefault, }); @@ -181,19 +181,19 @@ mod tests { assert!(matches!(ntt, DynamicDispatchNTT::SingleThreadedPrecompute(_))); } - let ntt = make_ntt(NTTOptions { + let ntt = make_ntt(&NTTOptions { precompute_twiddles: false, thread_settings: ThreadingSettings::ExplicitThreadsCount { log_threads: 2 }, }); assert!(matches!(ntt, DynamicDispatchNTT::MultiThreaded(_))); - let ntt = make_ntt(NTTOptions { + let ntt = make_ntt(&NTTOptions { precompute_twiddles: true, thread_settings: ThreadingSettings::ExplicitThreadsCount { log_threads: 0 }, }); assert!(matches!(ntt, DynamicDispatchNTT::SingleThreadedPrecompute(_))); - let ntt = make_ntt(NTTOptions { + let ntt = make_ntt(&NTTOptions { precompute_twiddles: false, thread_settings: ThreadingSettings::ExplicitThreadsCount { log_threads: 0 }, }); diff --git a/crates/ntt/src/single_threaded.rs b/crates/ntt/src/single_threaded.rs index a5723a62e..2448894c6 100644 --- a/crates/ntt/src/single_threaded.rs +++ b/crates/ntt/src/single_threaded.rs @@ -187,7 +187,7 @@ pub fn forward_transform>( // packed twiddles for all packed butterfly units. let log_block_len = i + log_b; let block_twiddle = calculate_twiddle::

( - s_evals[i].coset(log_domain_size - 1 - cutoff, 0), + &s_evals[i].coset(log_domain_size - 1 - cutoff, 0), log_block_len, ); @@ -263,7 +263,7 @@ pub fn inverse_transform>( // packed twiddles for all packed butterfly units. let log_block_len = i + log_b; let block_twiddle = calculate_twiddle::

( - s_evals[i].coset(log_domain_size - 1 - cutoff, 0), + &s_evals[i].coset(log_domain_size - 1 - cutoff, 0), log_block_len, ); @@ -357,7 +357,7 @@ pub const fn check_batch_transform_inputs_and_params( } #[inline] -fn calculate_twiddle

(s_evals: impl TwiddleAccess, log_block_len: usize) -> P +fn calculate_twiddle

(s_evals: &impl TwiddleAccess, log_block_len: usize) -> P where P: PackedField, { diff --git a/crates/ntt/src/tests/ntt_tests.rs b/crates/ntt/src/tests/ntt_tests.rs index 7ab2d2e39..ccdf8b7d8 100644 --- a/crates/ntt/src/tests/ntt_tests.rs +++ b/crates/ntt/src/tests/ntt_tests.rs @@ -10,8 +10,8 @@ use binius_field::{ packed_8::PackedBinaryField1x8b, }, underlier::{NumCast, WithUnderlier}, - AESTowerField8b, BinaryField, BinaryField8b, ExtensionField, PackedBinaryField16x32b, - PackedBinaryField8x32b, PackedField, RepackedExtension, + AESTowerField8b, BinaryField, BinaryField8b, PackedBinaryField16x32b, PackedBinaryField8x32b, + PackedField, RepackedExtension, }; use rand::{rngs::StdRng, SeedableRng}; @@ -144,16 +144,12 @@ fn tests_field_512_bits() { check_roundtrip_all_ntts::(12, 6, 4, 0); } -fn check_packed_extension_roundtrip_with_reference( +fn check_packed_extension_roundtrip_with_reference>( reference_ntt: &impl AdditiveNTT

, ntt: &impl AdditiveNTT

, data: &mut [PE], cosets: Range, -) where - P: PackedField, - PE: RepackedExtension

, - PE::Scalar: ExtensionField, -{ +) { let data_copy = data.to_vec(); let mut data_copy_2 = data.to_vec(); @@ -182,7 +178,6 @@ fn check_packed_extension_roundtrip_all_ntts( ) where P: PackedField, PE: RepackedExtension

+ WithUnderlier>, - PE::Scalar: ExtensionField, { let simple_ntt = SingleThreadedNTT::::new(log_domain_size) .unwrap() diff --git a/crates/utils/Cargo.toml b/crates/utils/Cargo.toml index 0d04d2917..c0e418404 100644 --- a/crates/utils/Cargo.toml +++ b/crates/utils/Cargo.toml @@ -8,9 +8,10 @@ authors.workspace = true workspace = true [dependencies] +auto_impl.workspace = true binius_maybe_rayon = { path = "../maybe_rayon", default-features = false } -bytes.workspace = true bytemuck = { workspace = true, features = ["extern_crate_alloc"] } +bytes.workspace = true cfg-if.workspace = true generic-array.workspace = true itertools.workspace = true diff --git a/crates/utils/src/lib.rs b/crates/utils/src/lib.rs index 493fe6e20..3ac565f0b 100644 --- a/crates/utils/src/lib.rs +++ b/crates/utils/src/lib.rs @@ -17,3 +17,6 @@ pub mod serialization; pub mod sorting; pub mod sparse_index; pub mod thread_local_mut; + +pub use bytes; +pub use serialization::{DeserializeBytes, SerializationError, SerializationMode, SerializeBytes}; diff --git a/crates/utils/src/serialization.rs b/crates/utils/src/serialization.rs index 434befccb..a6782b367 100644 --- a/crates/utils/src/serialization.rs +++ b/crates/utils/src/serialization.rs @@ -1,50 +1,417 @@ // Copyright 2024-2025 Irreducible Inc. +use auto_impl::auto_impl; use bytes::{Buf, BufMut}; -use generic_array::{ArrayLength, GenericArray}; +use thiserror::Error; + +/// Serialize data according to Mode param +#[auto_impl(Box, &)] +pub trait SerializeBytes { + fn serialize( + &self, + write_buf: impl BufMut, + mode: SerializationMode, + ) -> Result<(), SerializationError>; +} + +/// Deserialize data according to Mode param +pub trait DeserializeBytes { + fn deserialize(read_buf: impl Buf, mode: SerializationMode) -> Result + where + Self: Sized; +} + +/// Specifies serialization/deserialization behavior +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SerializationMode { + /// This mode is faster, and serializes to the underlying bytes + Native, + /// Will first convert any tower fields into the Fan-Paar field equivalent + CanonicalTower, +} -#[derive(Clone, thiserror::Error, Debug)] -pub enum Error { +#[derive(Error, Debug, Clone)] +pub enum SerializationError { #[error("Write buffer is full")] WriteBufferFull, #[error("Not enough data in read buffer to deserialize")] NotEnoughBytes, + #[error("Unknown enum variant index {name}::{index}")] + UnknownEnumVariant { name: &'static str, index: u8 }, + #[error("Serialization has not been implemented")] + SerializationNotImplemented, + #[error("Deserializer has not been implemented")] + DeserializerNotImplented, + #[error("Multiple deserializers with the same name {name} has been registered")] + DeserializerNameConflict { name: String }, + #[error("FromUtf8Error: {0}")] + FromUtf8Error(#[from] std::string::FromUtf8Error), } -/// Represents type that can be serialized to a byte buffer. -pub trait SerializeBytes { - fn serialize(&self, write_buf: impl BufMut) -> Result<(), Error>; +// Copyright 2025 Irreducible Inc. + +use generic_array::{ArrayLength, GenericArray}; + +impl DeserializeBytes for Box { + fn deserialize(read_buf: impl Buf, mode: SerializationMode) -> Result + where + Self: Sized, + { + Ok(Self::new(T::deserialize(read_buf, mode)?)) + } } -/// Represents type that can be deserialized from a byte buffer. -pub trait DeserializeBytes { - fn deserialize(read_buf: impl Buf) -> Result +impl SerializeBytes for usize { + fn serialize( + &self, + mut write_buf: impl BufMut, + mode: SerializationMode, + ) -> Result<(), SerializationError> { + SerializeBytes::serialize(&(*self as u64), &mut write_buf, mode) + } +} + +impl DeserializeBytes for usize { + fn deserialize( + mut read_buf: impl Buf, + mode: SerializationMode, + ) -> Result where - Self: Sized; + Self: Sized, + { + let value: u64 = DeserializeBytes::deserialize(&mut read_buf, mode)?; + Ok(value as Self) + } } -impl> SerializeBytes for GenericArray { - fn serialize(&self, mut write_buf: impl BufMut) -> Result<(), Error> { - if write_buf.remaining_mut() < N::USIZE { - return Err(Error::WriteBufferFull); +impl SerializeBytes for u128 { + fn serialize( + &self, + mut write_buf: impl BufMut, + _mode: SerializationMode, + ) -> Result<(), SerializationError> { + assert_enough_space_for(&write_buf, std::mem::size_of::())?; + write_buf.put_u128_le(*self); + Ok(()) + } +} + +impl DeserializeBytes for u128 { + fn deserialize( + mut read_buf: impl Buf, + _mode: SerializationMode, + ) -> Result + where + Self: Sized, + { + assert_enough_data_for(&read_buf, std::mem::size_of::())?; + Ok(read_buf.get_u128_le()) + } +} + +impl SerializeBytes for u64 { + fn serialize( + &self, + mut write_buf: impl BufMut, + _mode: SerializationMode, + ) -> Result<(), SerializationError> { + assert_enough_space_for(&write_buf, std::mem::size_of::())?; + write_buf.put_u64_le(*self); + Ok(()) + } +} + +impl DeserializeBytes for u64 { + fn deserialize( + mut read_buf: impl Buf, + _mode: SerializationMode, + ) -> Result + where + Self: Sized, + { + assert_enough_data_for(&read_buf, std::mem::size_of::())?; + Ok(read_buf.get_u64_le()) + } +} + +impl SerializeBytes for u32 { + fn serialize( + &self, + mut write_buf: impl BufMut, + _mode: SerializationMode, + ) -> Result<(), SerializationError> { + assert_enough_space_for(&write_buf, std::mem::size_of::())?; + write_buf.put_u32_le(*self); + Ok(()) + } +} + +impl DeserializeBytes for u32 { + fn deserialize( + mut read_buf: impl Buf, + _mode: SerializationMode, + ) -> Result + where + Self: Sized, + { + assert_enough_data_for(&read_buf, std::mem::size_of::())?; + Ok(read_buf.get_u32_le()) + } +} + +impl SerializeBytes for u16 { + fn serialize( + &self, + mut write_buf: impl BufMut, + _mode: SerializationMode, + ) -> Result<(), SerializationError> { + assert_enough_space_for(&write_buf, std::mem::size_of::())?; + write_buf.put_u16_le(*self); + Ok(()) + } +} + +impl DeserializeBytes for u16 { + fn deserialize( + mut read_buf: impl Buf, + _mode: SerializationMode, + ) -> Result + where + Self: Sized, + { + assert_enough_data_for(&read_buf, std::mem::size_of::())?; + Ok(read_buf.get_u16_le()) + } +} + +impl SerializeBytes for u8 { + fn serialize( + &self, + mut write_buf: impl BufMut, + _mode: SerializationMode, + ) -> Result<(), SerializationError> { + assert_enough_space_for(&write_buf, std::mem::size_of::())?; + write_buf.put_u8(*self); + Ok(()) + } +} + +impl DeserializeBytes for u8 { + fn deserialize( + mut read_buf: impl Buf, + _mode: SerializationMode, + ) -> Result + where + Self: Sized, + { + assert_enough_data_for(&read_buf, std::mem::size_of::())?; + Ok(read_buf.get_u8()) + } +} + +impl SerializeBytes for bool { + fn serialize( + &self, + write_buf: impl BufMut, + mode: SerializationMode, + ) -> Result<(), SerializationError> { + u8::serialize(&(*self as u8), write_buf, mode) + } +} + +impl DeserializeBytes for bool { + fn deserialize(read_buf: impl Buf, mode: SerializationMode) -> Result + where + Self: Sized, + { + Ok(u8::deserialize(read_buf, mode)? != 0) + } +} + +impl SerializeBytes for std::marker::PhantomData { + fn serialize( + &self, + _write_buf: impl BufMut, + _mode: SerializationMode, + ) -> Result<(), SerializationError> { + Ok(()) + } +} + +impl DeserializeBytes for std::marker::PhantomData { + fn deserialize( + _read_buf: impl Buf, + _mode: SerializationMode, + ) -> Result + where + Self: Sized, + { + Ok(Self) + } +} + +impl SerializeBytes for &str { + fn serialize( + &self, + mut write_buf: impl BufMut, + mode: SerializationMode, + ) -> Result<(), SerializationError> { + let bytes = self.as_bytes(); + SerializeBytes::serialize(&bytes.len(), &mut write_buf, mode)?; + assert_enough_space_for(&write_buf, bytes.len())?; + write_buf.put_slice(bytes); + Ok(()) + } +} + +impl SerializeBytes for String { + fn serialize( + &self, + mut write_buf: impl BufMut, + mode: SerializationMode, + ) -> Result<(), SerializationError> { + SerializeBytes::serialize(&self.as_str(), &mut write_buf, mode) + } +} + +impl DeserializeBytes for String { + fn deserialize( + mut read_buf: impl Buf, + mode: SerializationMode, + ) -> Result + where + Self: Sized, + { + let len = DeserializeBytes::deserialize(&mut read_buf, mode)?; + assert_enough_data_for(&read_buf, len)?; + Ok(Self::from_utf8(read_buf.copy_to_bytes(len).to_vec())?) + } +} + +impl SerializeBytes for Vec { + fn serialize( + &self, + mut write_buf: impl BufMut, + mode: SerializationMode, + ) -> Result<(), SerializationError> { + SerializeBytes::serialize(&self.len(), &mut write_buf, mode)?; + self.iter() + .try_for_each(|item| SerializeBytes::serialize(item, &mut write_buf, mode)) + } +} + +impl DeserializeBytes for Vec { + fn deserialize( + mut read_buf: impl Buf, + mode: SerializationMode, + ) -> Result + where + Self: Sized, + { + let len: usize = DeserializeBytes::deserialize(&mut read_buf, mode)?; + (0..len) + .map(|_| DeserializeBytes::deserialize(&mut read_buf, mode)) + .collect() + } +} + +impl SerializeBytes for Option { + fn serialize( + &self, + mut write_buf: impl BufMut, + mode: SerializationMode, + ) -> Result<(), SerializationError> { + match self { + Some(value) => { + SerializeBytes::serialize(&true, &mut write_buf, mode)?; + SerializeBytes::serialize(value, &mut write_buf, mode)?; + } + None => { + SerializeBytes::serialize(&false, write_buf, mode)?; + } } + Ok(()) + } +} + +impl DeserializeBytes for Option { + fn deserialize( + mut read_buf: impl Buf, + mode: SerializationMode, + ) -> Result + where + Self: Sized, + { + Ok(match bool::deserialize(&mut read_buf, mode)? { + true => Some(T::deserialize(&mut read_buf, mode)?), + false => None, + }) + } +} + +impl SerializeBytes for (U, V) { + fn serialize( + &self, + mut write_buf: impl BufMut, + mode: SerializationMode, + ) -> Result<(), SerializationError> { + U::serialize(&self.0, &mut write_buf, mode)?; + V::serialize(&self.1, write_buf, mode) + } +} + +impl DeserializeBytes for (U, V) { + fn deserialize( + mut read_buf: impl Buf, + mode: SerializationMode, + ) -> Result + where + Self: Sized, + { + Ok((U::deserialize(&mut read_buf, mode)?, V::deserialize(read_buf, mode)?)) + } +} + +impl> SerializeBytes for GenericArray { + fn serialize( + &self, + mut write_buf: impl BufMut, + _mode: SerializationMode, + ) -> Result<(), SerializationError> { + assert_enough_space_for(&write_buf, N::USIZE)?; write_buf.put_slice(self); Ok(()) } } impl> DeserializeBytes for GenericArray { - fn deserialize(mut read_buf: impl Buf) -> Result { - if read_buf.remaining() < N::USIZE { - return Err(Error::NotEnoughBytes); - } - + fn deserialize( + mut read_buf: impl Buf, + _mode: SerializationMode, + ) -> Result { + assert_enough_data_for(&read_buf, N::USIZE)?; let mut ret = Self::default(); read_buf.copy_to_slice(&mut ret); Ok(ret) } } +#[inline] +fn assert_enough_space_for(write_buf: &impl BufMut, size: usize) -> Result<(), SerializationError> { + if write_buf.remaining_mut() < size { + return Err(SerializationError::WriteBufferFull); + } + Ok(()) +} + +#[inline] +fn assert_enough_data_for(read_buf: &impl Buf, size: usize) -> Result<(), SerializationError> { + if read_buf.remaining() < size { + return Err(SerializationError::NotEnoughBytes); + } + Ok(()) +} + #[cfg(test)] mod tests { use generic_array::typenum::U32; @@ -60,9 +427,11 @@ mod tests { rng.fill_bytes(&mut data); let mut buf = Vec::new(); - data.serialize(&mut buf).unwrap(); + data.serialize(&mut buf, SerializationMode::Native).unwrap(); - let data_deserialized = GenericArray::::deserialize(&mut buf.as_slice()).unwrap(); + let data_deserialized = + GenericArray::::deserialize(&mut buf.as_slice(), SerializationMode::Native) + .unwrap(); assert_eq!(data_deserialized, data); } } diff --git a/crates/utils/src/thread_local_mut.rs b/crates/utils/src/thread_local_mut.rs index 0f196c5b2..9bf80ad32 100644 --- a/crates/utils/src/thread_local_mut.rs +++ b/crates/utils/src/thread_local_mut.rs @@ -6,7 +6,7 @@ use thread_local::ThreadLocal; /// Creates a "scratch space" within each thread with mutable access. /// -/// This is mainly meant to be used as an optimization to avoid unneccesary allocs/frees within rayon code. +/// This is mainly meant to be used as an optimization to avoid unnecessary allocs/frees within rayon code. /// You only pay for allocation of this scratch space once per thread. /// /// Since the space is local to each thread you also don't have to worry about atomicity. diff --git a/examples/Cargo.toml b/examples/Cargo.toml index ece6509b8..0de51d976 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -25,10 +25,6 @@ rand.workspace = true tracing-profile.workspace = true tracing.workspace = true -[[example]] -name = "groestl_circuit" -path = "groestl_circuit.rs" - [[example]] name = "keccakf_circuit" path = "keccakf_circuit.rs" @@ -81,9 +77,9 @@ path = "b32_mul.rs" name = "acc-linear-combination" path = "acc-linear-combination.rs" -[[example]] -name = "acc-linear-combination-with-offset" -path = "acc-linear-combination-with-offset.rs" +#[[example]] +#name = "acc-linear-combination-with-offset" +#path = "acc-linear-combination-with-offset.rs" [[example]] name = "acc-shifted" @@ -145,6 +141,10 @@ path = "acc-step-up.rs" name = "acc-tower-basis" path = "acc-tower-basis.rs" +[[example]] +name = "acc-permutation-channels" +path = "acc-permutation-channels.rs" + [lints.clippy] needless_range_loop = "allow" @@ -154,4 +154,3 @@ aes-tower = [] bail_panic = ["binius_utils/bail_panic"] fp-tower = [] rayon = ["binius_utils/rayon"] - diff --git a/examples/acc-constants.rs b/examples/acc-constants.rs index 0947ed107..9536f172d 100644 --- a/examples/acc-constants.rs +++ b/examples/acc-constants.rs @@ -3,10 +3,7 @@ use binius_core::{ constraint_system::validate::validate_witness, oracle::OracleId, transparent::constant::Constant, }; -use binius_field::{arch::OptimalUnderlier, BinaryField128b, BinaryField1b, BinaryField32b}; - -type U = OptimalUnderlier; -type F128 = BinaryField128b; +use binius_field::{BinaryField1b, BinaryField32b}; type F32 = BinaryField32b; type F1 = BinaryField1b; @@ -17,7 +14,7 @@ const LOG_SIZE: usize = 4; fn constants_gadget( name: impl ToString, log_size: usize, - builder: &mut ConstraintSystemBuilder, + builder: &mut ConstraintSystemBuilder, constant_value: u32, ) -> OracleId { builder.push_namespace(name); @@ -45,7 +42,8 @@ fn constants_gadget( // Transparent column. fn main() { let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); pub const SHA256_INIT: [u32; 8] = [ 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, diff --git a/examples/acc-disjoint-product.rs b/examples/acc-disjoint-product.rs index 05f4b59d9..2f6ed5812 100644 --- a/examples/acc-disjoint-product.rs +++ b/examples/acc-disjoint-product.rs @@ -3,12 +3,8 @@ use binius_core::{ constraint_system::validate::validate_witness, transparent::{constant::Constant, disjoint_product::DisjointProduct, powers::Powers}, }; -use binius_field::{ - arch::OptimalUnderlier, BinaryField, BinaryField128b, BinaryField8b, PackedField, -}; +use binius_field::{BinaryField, BinaryField8b, PackedField}; -type U = OptimalUnderlier; -type F128 = BinaryField128b; type F8 = BinaryField8b; const LOG_SIZE: usize = 4; @@ -29,7 +25,7 @@ const LOG_SIZE: usize = 4; // of heights (n_vars) of Powers and Constant, so actual data could be repeated multiple times fn main() { let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let generator = F8::MULTIPLICATIVE_GENERATOR; let powers = Powers::new(LOG_SIZE, generator.into()); diff --git a/examples/acc-eq-ind-partial-eval.rs b/examples/acc-eq-ind-partial-eval.rs index 7fcdc4b54..6d76fb45f 100644 --- a/examples/acc-eq-ind-partial-eval.rs +++ b/examples/acc-eq-ind-partial-eval.rs @@ -2,9 +2,8 @@ use binius_circuits::builder::ConstraintSystemBuilder; use binius_core::{ constraint_system::validate::validate_witness, transparent::eq_ind::EqIndPartialEval, }; -use binius_field::{arch::OptimalUnderlier, BinaryField128b, PackedField}; +use binius_field::{BinaryField128b, PackedField}; -type U = OptimalUnderlier; type F128 = BinaryField128b; const LOG_SIZE: usize = 3; @@ -21,7 +20,8 @@ const LOG_SIZE: usize = 3; // fn main() { let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); // A truth table [000, 001, 010, 011 ... 111] where each row is in reversed order let rev_basis = [ diff --git a/examples/acc-linear-combination-with-offset.rs b/examples/acc-linear-combination-with-offset.rs.disabled similarity index 100% rename from examples/acc-linear-combination-with-offset.rs rename to examples/acc-linear-combination-with-offset.rs.disabled diff --git a/examples/acc-linear-combination.rs b/examples/acc-linear-combination.rs index 5cb1eab54..b069b5196 100644 --- a/examples/acc-linear-combination.rs +++ b/examples/acc-linear-combination.rs @@ -1,20 +1,17 @@ use binius_circuits::{builder::ConstraintSystemBuilder, unconstrained::unconstrained}; use binius_core::{constraint_system::validate::validate_witness, oracle::OracleId}; use binius_field::{ - arch::OptimalUnderlier, packed::set_packed_slice, BinaryField128b, BinaryField1b, - BinaryField8b, ExtensionField, TowerField, + packed::set_packed_slice, BinaryField1b, BinaryField8b, ExtensionField, TowerField, }; use binius_macros::arith_expr; -type U = OptimalUnderlier; -type F128 = BinaryField128b; type F8 = BinaryField8b; type F1 = BinaryField1b; // FIXME: Following gadgets are unconstrained. Only for demonstrative purpose, don't use in production fn bytes_decomposition_gadget( - builder: &mut ConstraintSystemBuilder, + builder: &mut ConstraintSystemBuilder, name: impl ToString, log_size: usize, input: OracleId, @@ -146,7 +143,7 @@ fn bytes_decomposition_gadget( } fn elder_4bits_masking_gadget( - builder: &mut ConstraintSystemBuilder, + builder: &mut ConstraintSystemBuilder, name: impl ToString, log_size: usize, input: OracleId, @@ -241,12 +238,12 @@ fn elder_4bits_masking_gadget( fn main() { let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let log_size = 1usize; // Define set of bytes that we want to decompose - let p_in = unconstrained::(&mut builder, "p_in".to_string(), log_size).unwrap(); + let p_in = unconstrained::(&mut builder, "p_in".to_string(), log_size).unwrap(); let _ = bytes_decomposition_gadget(&mut builder, "bytes decomposition", log_size, p_in).unwrap(); diff --git a/examples/acc-multilinear-extension-transparent.rs b/examples/acc-multilinear-extension-transparent.rs index f9dd45706..033184edd 100644 --- a/examples/acc-multilinear-extension-transparent.rs +++ b/examples/acc-multilinear-extension-transparent.rs @@ -14,7 +14,7 @@ type F128 = BinaryField128b; type F1 = BinaryField1b; // From a perspective of circuits creation, MultilinearExtensionTransparent can be used naturally for decomposing integers to bits -fn decompose_transparent_u64(builder: &mut ConstraintSystemBuilder, x: u64) { +fn decompose_transparent_u64(builder: &mut ConstraintSystemBuilder, x: u64) { builder.push_namespace("decompose_transparent_u64"); let log_bits = log2_ceil_usize(64); @@ -42,7 +42,7 @@ fn decompose_transparent_u64(builder: &mut ConstraintSystemBuilder, x: builder.pop_namespace(); } -fn decompose_transparent_u32(builder: &mut ConstraintSystemBuilder, x: u32) { +fn decompose_transparent_u32(builder: &mut ConstraintSystemBuilder, x: u32) { builder.push_namespace("decompose_transparent_u32"); let log_bits = log2_ceil_usize(32); @@ -72,7 +72,7 @@ fn decompose_transparent_u32(builder: &mut ConstraintSystemBuilder, x: fn main() { let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); decompose_transparent_u64(&mut builder, 0xff00ff00ff00ff00); decompose_transparent_u32(&mut builder, 0x00ff00ff); diff --git a/examples/acc-packed.rs b/examples/acc-packed.rs index 6963d8780..434a9d11c 100644 --- a/examples/acc-packed.rs +++ b/examples/acc-packed.rs @@ -1,12 +1,7 @@ use binius_circuits::{builder::ConstraintSystemBuilder, unconstrained::unconstrained}; use binius_core::constraint_system::validate::validate_witness; -use binius_field::{ - arch::OptimalUnderlier, BinaryField128b, BinaryField16b, BinaryField1b, BinaryField32b, - BinaryField8b, TowerField, -}; +use binius_field::{BinaryField16b, BinaryField1b, BinaryField32b, BinaryField8b, TowerField}; -type U = OptimalUnderlier; -type F128 = BinaryField128b; type F32 = BinaryField32b; type F16 = BinaryField16b; type F8 = BinaryField8b; @@ -14,10 +9,10 @@ type F1 = BinaryField1b; // FIXME: Following gadgets are unconstrained. Only for demonstrative purpose, don't use in production -fn packing_32_bits_to_u32(builder: &mut ConstraintSystemBuilder) { +fn packing_32_bits_to_u32(builder: &mut ConstraintSystemBuilder) { builder.push_namespace("packing_32_bits_to_u32"); - let bits = unconstrained::(builder, "bits", F32::TOWER_LEVEL).unwrap(); + let bits = unconstrained::(builder, "bits", F32::TOWER_LEVEL).unwrap(); let packed = builder .add_packed("packed", bits, F32::TOWER_LEVEL) .unwrap(); @@ -49,10 +44,10 @@ fn packing_32_bits_to_u32(builder: &mut ConstraintSystemBuilder) { builder.pop_namespace(); } -fn packing_4_bytes_to_u32(builder: &mut ConstraintSystemBuilder) { +fn packing_4_bytes_to_u32(builder: &mut ConstraintSystemBuilder) { builder.push_namespace("packing_4_bytes_to_u32"); - let bytes = unconstrained::(builder, "bytes", F16::TOWER_LEVEL).unwrap(); + let bytes = unconstrained::(builder, "bytes", F16::TOWER_LEVEL).unwrap(); let packed = builder .add_packed("packed", bytes, F16::TOWER_LEVEL) .unwrap(); @@ -76,10 +71,10 @@ fn packing_4_bytes_to_u32(builder: &mut ConstraintSystemBuilder) { builder.pop_namespace(); } -fn packing_8_bits_to_u8(builder: &mut ConstraintSystemBuilder) { +fn packing_8_bits_to_u8(builder: &mut ConstraintSystemBuilder) { builder.push_namespace("packing_8_bits_to_u8"); - let bits = unconstrained::(builder, "bits", F8::TOWER_LEVEL).unwrap(); + let bits = unconstrained::(builder, "bits", F8::TOWER_LEVEL).unwrap(); let packed = builder.add_packed("packed", bits, F8::TOWER_LEVEL).unwrap(); if let Some(witness) = builder.witness() { @@ -94,7 +89,7 @@ fn packing_8_bits_to_u8(builder: &mut ConstraintSystemBuilder) { fn main() { let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); packing_32_bits_to_u32(&mut builder); packing_4_bytes_to_u32(&mut builder); diff --git a/examples/acc-permutation-channels.rs b/examples/acc-permutation-channels.rs new file mode 100644 index 000000000..a628bf350 --- /dev/null +++ b/examples/acc-permutation-channels.rs @@ -0,0 +1,111 @@ +use binius_circuits::{builder::ConstraintSystemBuilder, unconstrained::fixed_u32}; +use binius_core::constraint_system::{ + channel::{Boundary, FlushDirection}, + validate::validate_witness, +}; +use binius_field::{BinaryField128b, BinaryField32b}; +use bumpalo::Bump; + +type F128 = BinaryField128b; +type F32 = BinaryField32b; + +const MSG_PERMUTATION: [usize; 16] = [2, 6, 3, 10, 7, 0, 4, 13, 1, 11, 12, 5, 9, 14, 15, 8]; + +// Permutation is a classic construction in a traditional cryptography. It has well-defined security properties +// and high performance due to implementation via lookups. One can possible to implement gadget for permutations using +// channels API from Binius. The following examples shows how to enforce Blake3 permutation - verifier pulls pairs of +// input/output of the permutation (encoded as a BinaryField128b elements, to reduce number of flushes), +// while prover is expected to push similar IO to make channel balanced. +fn permute(m: &mut [u32; 16]) { + let mut permuted = [0; 16]; + for i in 0..16 { + permuted[i] = m[MSG_PERMUTATION[i]]; + } + *m = permuted; +} + +fn main() { + let log_size = 4usize; + + let allocator = Bump::new(); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); + + let m = [ + 0xfffffff0, 0xfffffff1, 0xfffffff2, 0xfffffff3, 0xfffffff4, 0xfffffff5, 0xfffffff6, + 0xfffffff7, 0xfffffff8, 0xfffffff9, 0xfffffffa, 0xfffffffb, 0xfffffffc, 0xfffffffd, + 0xfffffffe, 0xffffffff, + ]; + + let mut m_clone = m; + permute(&mut m_clone); + + let expected = [ + 0xfffffff2, 0xfffffff6, 0xfffffff3, 0xfffffffa, 0xfffffff7, 0xfffffff0, 0xfffffff4, + 0xfffffffd, 0xfffffff1, 0xfffffffb, 0xfffffffc, 0xfffffff5, 0xfffffff9, 0xfffffffe, + 0xffffffff, 0xfffffff8, + ]; + assert_eq!(m_clone, expected); + + let u32_in = fixed_u32::(&mut builder, "in", log_size, m.to_vec()).unwrap(); + let u32_out = fixed_u32::(&mut builder, "out", log_size, expected.to_vec()).unwrap(); + + // we pack 4-u32 (F32) tuples of permutation IO into F128 columns and use them for flushing + let u128_in = builder.add_packed("in_packed", u32_in, 2).unwrap(); + let u128_out = builder.add_packed("out_packed", u32_out, 2).unwrap(); + + // populate memory layout (witness) + if let Some(witness) = builder.witness() { + let in_f32 = witness.get::(u32_in).unwrap(); + let out_f32 = witness.get::(u32_out).unwrap(); + witness.new_column::(u128_in); + witness.new_column::(u128_out); + + witness.set(u128_in, in_f32.repacked::()).unwrap(); + witness.set(u128_out, out_f32.repacked::()).unwrap(); + } + + let channel = builder.add_channel(); + // count defines how many values ( 0 .. count ) from a given columns to send (pushing to a channel) + builder.send(channel, 4, [u128_in, u128_out]).unwrap(); + + let witness = builder.take_witness().unwrap(); + let cs = builder.build().unwrap(); + + // consider our 4-u32 values from a given tupple as 4 limbs of u128 + let f = |limb0: u32, limb1: u32, limb2: u32, limb3: u32| { + let mut x = 0u128; + + x ^= (limb3 as u128) << 96; + x ^= (limb2 as u128) << 64; + x ^= (limb1 as u128) << 32; + x ^= limb0 as u128; + + F128::new(x) + }; + + // Boundaries define actual data (encoded in a set of Flushes) that verifier can push or pull from a given channel + // in order to check if prover is able to balance that channel + let mut offset = 0usize; + let boundaries = (0..4) + .map(|_| { + let boundary = Boundary { + values: vec![ + f(m[offset], m[offset + 1], m[offset + 2], m[offset + 3]), + f( + expected[offset], + expected[offset + 1], + expected[offset + 2], + expected[offset + 3], + ), + ], + channel_id: channel, + direction: FlushDirection::Pull, + multiplicity: 1, + }; + offset += 4; + boundary + }) + .collect::>>(); + + validate_witness(&cs, &boundaries, &witness).unwrap(); +} diff --git a/examples/acc-powers.rs b/examples/acc-powers.rs index 159aa1136..c266191e8 100644 --- a/examples/acc-powers.rs +++ b/examples/acc-powers.rs @@ -1,12 +1,7 @@ use binius_circuits::builder::ConstraintSystemBuilder; use binius_core::constraint_system::validate::validate_witness; -use binius_field::{ - arch::OptimalUnderlier, BinaryField, BinaryField128b, BinaryField16b, BinaryField32b, - PackedField, -}; +use binius_field::{BinaryField, BinaryField16b, BinaryField32b, PackedField}; -type U = OptimalUnderlier; -type F128 = BinaryField128b; type F32 = BinaryField32b; type F16 = BinaryField16b; @@ -22,7 +17,7 @@ const LOG_SIZE: usize = 3; // where 'x' is a multiplicative generator - a public value that exists for every BinaryField // -fn powers_gadget_f32(builder: &mut ConstraintSystemBuilder, name: impl ToString) { +fn powers_gadget_f32(builder: &mut ConstraintSystemBuilder, name: impl ToString) { builder.push_namespace(name); let generator = F32::MULTIPLICATIVE_GENERATOR; @@ -43,7 +38,7 @@ fn powers_gadget_f32(builder: &mut ConstraintSystemBuilder, name: impl } // Only Field is being changed -fn powers_gadget_f16(builder: &mut ConstraintSystemBuilder, name: impl ToString) { +fn powers_gadget_f16(builder: &mut ConstraintSystemBuilder, name: impl ToString) { builder.push_namespace(name); let generator = F16::MULTIPLICATIVE_GENERATOR; @@ -65,7 +60,7 @@ fn powers_gadget_f16(builder: &mut ConstraintSystemBuilder, name: impl fn main() { let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); powers_gadget_f16(&mut builder, "f16"); powers_gadget_f32(&mut builder, "f32"); diff --git a/examples/acc-projected.rs b/examples/acc-projected.rs index 8d4de121d..c7ac8150d 100644 --- a/examples/acc-projected.rs +++ b/examples/acc-projected.rs @@ -1,8 +1,7 @@ use binius_circuits::{builder::ConstraintSystemBuilder, unconstrained::unconstrained}; use binius_core::{constraint_system::validate::validate_witness, oracle::ProjectionVariant}; -use binius_field::{arch::OptimalUnderlier, BinaryField128b, BinaryField8b}; +use binius_field::{BinaryField128b, BinaryField8b}; -type U = OptimalUnderlier; type F128 = BinaryField128b; type F8 = BinaryField8b; @@ -20,14 +19,13 @@ struct U8U128ProjectionInfo { // has significant impact on input data processing. // In the following example we have input column with bytes (u8) projected to the output column with u128 values. fn projection( - builder: &mut ConstraintSystemBuilder, + builder: &mut ConstraintSystemBuilder, projection_info: U8U128ProjectionInfo, namespace: &str, ) { builder.push_namespace(format!("projection {}", namespace)); - let input = - unconstrained::(builder, "in", projection_info.clone().log_size).unwrap(); + let input = unconstrained::(builder, "in", projection_info.clone().log_size).unwrap(); let projected = builder .add_projected( @@ -112,7 +110,7 @@ impl U8U128ProjectionInfo { fn main() { let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let projection_data = U8U128ProjectionInfo::new( 4usize, diff --git a/examples/acc-repeated.rs b/examples/acc-repeated.rs index 749bb6fee..ef4bf6431 100644 --- a/examples/acc-repeated.rs +++ b/examples/acc-repeated.rs @@ -1,12 +1,9 @@ use binius_circuits::{builder::ConstraintSystemBuilder, unconstrained::unconstrained}; use binius_core::constraint_system::validate::validate_witness; use binius_field::{ - arch::OptimalUnderlier, packed::set_packed_slice, BinaryField128b, BinaryField1b, - BinaryField8b, PackedBinaryField128x1b, + packed::set_packed_slice, BinaryField1b, BinaryField8b, PackedBinaryField128x1b, }; -type U = OptimalUnderlier; -type F128 = BinaryField128b; type F8 = BinaryField8b; type F1 = BinaryField1b; @@ -18,10 +15,10 @@ const LOG_SIZE: usize = 8; // so new column is X times bigger than original one. The following gadget operates over bytes, e.g. // it creates column with some input bytes written and then creates one more 'Repeated' column // where the same bytes are copied multiple times. -fn bytes_repeat_gadget(builder: &mut ConstraintSystemBuilder) { +fn bytes_repeat_gadget(builder: &mut ConstraintSystemBuilder) { builder.push_namespace("bytes_repeat_gadget"); - let bytes = unconstrained::(builder, "input", LOG_SIZE).unwrap(); + let bytes = unconstrained::(builder, "input", LOG_SIZE).unwrap(); let repeat_times_log = 4usize; let repeating = builder @@ -57,10 +54,10 @@ fn bytes_repeat_gadget(builder: &mut ConstraintSystemBuilder) { // repetitions Binius creates column with 8 PackedBinaryField128x1b elements totally. // Proper writing bits requires separate iterating over PackedBinaryField128x1b elements and input bytes // with extracting particular bit values from the input and setting appropriate bit in a given PackedBinaryField128x1b. -fn bits_repeat_gadget(builder: &mut ConstraintSystemBuilder) { +fn bits_repeat_gadget(builder: &mut ConstraintSystemBuilder) { builder.push_namespace("bits_repeat_gadget"); - let bits = unconstrained::(builder, "input", LOG_SIZE).unwrap(); + let bits = unconstrained::(builder, "input", LOG_SIZE).unwrap(); let repeat_times_log = 2usize; // Binius will create column with appropriate height for us @@ -112,7 +109,7 @@ fn bits_repeat_gadget(builder: &mut ConstraintSystemBuilder) { fn main() { let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); bytes_repeat_gadget(&mut builder); bits_repeat_gadget(&mut builder); diff --git a/examples/acc-select-row.rs b/examples/acc-select-row.rs index 4337540bd..7f45bee02 100644 --- a/examples/acc-select-row.rs +++ b/examples/acc-select-row.rs @@ -2,10 +2,8 @@ use binius_circuits::builder::ConstraintSystemBuilder; use binius_core::{ constraint_system::validate::validate_witness, transparent::select_row::SelectRow, }; -use binius_field::{arch::OptimalUnderlier, BinaryField128b, BinaryField8b}; +use binius_field::BinaryField8b; -type U = OptimalUnderlier; -type F128 = BinaryField128b; type F8 = BinaryField8b; const LOG_SIZE: usize = 8; @@ -13,7 +11,7 @@ const LOG_SIZE: usize = 8; // SelectRow expects exactly one witness value at particular index to be set. fn main() { let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let index = 58; assert!(index < 1 << LOG_SIZE); diff --git a/examples/acc-shift-ind-partial-eq.rs b/examples/acc-shift-ind-partial-eq.rs index 74fa2bae7..385a07b85 100644 --- a/examples/acc-shift-ind-partial-eq.rs +++ b/examples/acc-shift-ind-partial-eq.rs @@ -3,16 +3,15 @@ use binius_core::{ constraint_system::validate::validate_witness, oracle::ShiftVariant, transparent::shift_ind::ShiftIndPartialEval, }; -use binius_field::{arch::OptimalUnderlier, util::eq, BinaryField128b, Field}; +use binius_field::{util::eq, BinaryField128b, Field}; -type U = OptimalUnderlier; type F128 = BinaryField128b; // ShiftIndPartialEval is a more elaborated version of EqIndPartialEval. Same idea with challenges, but a bit more // elaborated evaluation algorithm is used fn main() { let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let block_size = 3; let shift_offset = 4; diff --git a/examples/acc-shifted.rs b/examples/acc-shifted.rs index 924fb8234..a81a6610d 100644 --- a/examples/acc-shifted.rs +++ b/examples/acc-shifted.rs @@ -1,21 +1,19 @@ use binius_circuits::{builder::ConstraintSystemBuilder, unconstrained::unconstrained}; use binius_core::{constraint_system::validate::validate_witness, oracle::ShiftVariant}; -use binius_field::{arch::OptimalUnderlier, BinaryField128b, BinaryField1b}; +use binius_field::BinaryField1b; -type U = OptimalUnderlier; -type F128 = BinaryField128b; type F1 = BinaryField1b; // FIXME: Following gadgets are unconstrained. Only for demonstrative purpose, don't use in production -fn shift_right_gadget_u32(builder: &mut ConstraintSystemBuilder) { +fn shift_right_gadget_u32(builder: &mut ConstraintSystemBuilder) { builder.push_namespace("u32_right_shift"); // defined empirically and it is the same as 'block_bits' defined below let log_size = 5usize; // create column and write arbitrary bytes to it - let input = unconstrained::(builder, "input", log_size).unwrap(); + let input = unconstrained::(builder, "input", log_size).unwrap(); // we want to shift our u32 variable on 1 bit let shift_offset = 1; @@ -43,11 +41,11 @@ fn shift_right_gadget_u32(builder: &mut ConstraintSystemBuilder) { builder.pop_namespace(); } -fn shift_left_gadget_u8(builder: &mut ConstraintSystemBuilder) { +fn shift_left_gadget_u8(builder: &mut ConstraintSystemBuilder) { builder.push_namespace("u8_left_shift"); let log_size = 3usize; - let input = unconstrained::(builder, "input", log_size).unwrap(); + let input = unconstrained::(builder, "input", log_size).unwrap(); let shift_offset = 4; let shift_type = ShiftVariant::LogicalLeft; let block_bits = 3; @@ -67,11 +65,11 @@ fn shift_left_gadget_u8(builder: &mut ConstraintSystemBuilder) { builder.pop_namespace(); } -fn rotate_left_gadget_u16(builder: &mut ConstraintSystemBuilder) { +fn rotate_left_gadget_u16(builder: &mut ConstraintSystemBuilder) { builder.push_namespace("u16_rotate_right"); let log_size = 4usize; - let input = unconstrained::(builder, "input", log_size).unwrap(); + let input = unconstrained::(builder, "input", log_size).unwrap(); let rotation_offset = 5; let rotation_type = ShiftVariant::CircularLeft; let block_bits = 4usize; @@ -92,11 +90,11 @@ fn rotate_left_gadget_u16(builder: &mut ConstraintSystemBuilder) { builder.pop_namespace(); } -fn rotate_right_gadget_u64(builder: &mut ConstraintSystemBuilder) { +fn rotate_right_gadget_u64(builder: &mut ConstraintSystemBuilder) { builder.push_namespace("u64_rotate_right"); let log_size = 6usize; - let input = unconstrained::(builder, "input", log_size).unwrap(); + let input = unconstrained::(builder, "input", log_size).unwrap(); // Right rotation to X bits is achieved using 'ShiftVariant::CircularLeft' with the offset, // computed as size in bits of the variable type - X (e.g. if we want to right-rotate u64 to 8 bits, @@ -124,7 +122,7 @@ fn rotate_right_gadget_u64(builder: &mut ConstraintSystemBuilder) { fn main() { let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); shift_right_gadget_u32(&mut builder); shift_left_gadget_u8(&mut builder); diff --git a/examples/acc-step-down.rs b/examples/acc-step-down.rs index cbe3cd6e6..a2ecf3ba2 100644 --- a/examples/acc-step-down.rs +++ b/examples/acc-step-down.rs @@ -2,18 +2,16 @@ use binius_circuits::builder::ConstraintSystemBuilder; use binius_core::{ constraint_system::validate::validate_witness, transparent::step_down::StepDown, }; -use binius_field::{arch::OptimalUnderlier, BinaryField128b, BinaryField8b}; +use binius_field::BinaryField8b; const LOG_SIZE: usize = 8; -type U = OptimalUnderlier; -type F128 = BinaryField128b; type F8 = BinaryField8b; // StepDown expects all bytes to be set before particular index specified as input fn main() { let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let index = 10; diff --git a/examples/acc-step-up.rs b/examples/acc-step-up.rs index e659352e8..0d65820e5 100644 --- a/examples/acc-step-up.rs +++ b/examples/acc-step-up.rs @@ -1,9 +1,7 @@ use binius_circuits::builder::ConstraintSystemBuilder; use binius_core::{constraint_system::validate::validate_witness, transparent::step_up::StepUp}; -use binius_field::{arch::OptimalUnderlier, BinaryField128b, BinaryField8b}; +use binius_field::BinaryField8b; -type U = OptimalUnderlier; -type F128 = BinaryField128b; type F8 = BinaryField8b; const LOG_SIZE: usize = 8; @@ -11,7 +9,7 @@ const LOG_SIZE: usize = 8; // StepUp expects all bytes to be unset before particular index specified as input (opposite to StepDown) fn main() { let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let index = 10; diff --git a/examples/acc-tower-basis.rs b/examples/acc-tower-basis.rs index e99e6790b..fcc9a3837 100644 --- a/examples/acc-tower-basis.rs +++ b/examples/acc-tower-basis.rs @@ -2,16 +2,15 @@ use binius_circuits::builder::ConstraintSystemBuilder; use binius_core::{ constraint_system::validate::validate_witness, transparent::tower_basis::TowerBasis, }; -use binius_field::{arch::OptimalUnderlier, BinaryField128b, Field, TowerField}; +use binius_field::{BinaryField128b, Field, TowerField}; -type U = OptimalUnderlier; type F128 = BinaryField128b; // TowerBasis expects actually basis vectors written to the witness. // The form of basis could vary depending on 'iota' and 'k' parameters fn main() { let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let k = 3usize; let iota = 4usize; diff --git a/examples/acc-zeropadded.rs b/examples/acc-zeropadded.rs index 2cbb954ca..b306bd63a 100644 --- a/examples/acc-zeropadded.rs +++ b/examples/acc-zeropadded.rs @@ -1,9 +1,7 @@ use binius_circuits::{builder::ConstraintSystemBuilder, unconstrained::unconstrained}; use binius_core::constraint_system::validate::validate_witness; -use binius_field::{arch::OptimalUnderlier, BinaryField128b, BinaryField8b}; +use binius_field::BinaryField8b; -type U = OptimalUnderlier; -type F128 = BinaryField128b; type F8 = BinaryField8b; const LOG_SIZE: usize = 4; @@ -11,9 +9,9 @@ const LOG_SIZE: usize = 4; fn main() { let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); - let bytes = unconstrained::(&mut builder, "bytes", LOG_SIZE).unwrap(); + let bytes = unconstrained::(&mut builder, "bytes", LOG_SIZE).unwrap(); // Height of ZeroPadded column can't be smaller than input one. // If n_vars equals to LOG_SIZE, then no padding is required, diff --git a/examples/b32_mul.rs b/examples/b32_mul.rs index b0168f43a..5cdbbea53 100644 --- a/examples/b32_mul.rs +++ b/examples/b32_mul.rs @@ -1,9 +1,9 @@ // Copyright 2024-2025 Irreducible Inc. use anyhow::Result; -use binius_circuits::builder::ConstraintSystemBuilder; +use binius_circuits::builder::{types::U, ConstraintSystemBuilder}; use binius_core::{constraint_system, fiat_shamir::HasherChallenger, tower::CanonicalTowerFamily}; -use binius_field::{arch::OptimalUnderlier, BinaryField128b, BinaryField32b, TowerField}; +use binius_field::{BinaryField32b, TowerField}; use binius_hal::make_portable_backend; use binius_hash::compress::Groestl256ByteCompression; use binius_macros::arith_expr; @@ -26,7 +26,6 @@ struct Args { } fn main() -> Result<()> { - type U = OptimalUnderlier; const SECURITY_BITS: usize = 100; adjust_thread_pool() @@ -42,18 +41,18 @@ fn main() -> Result<()> { let log_n_muls = log2_ceil_usize(args.n_ops as usize); let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let trace_gen_scope = tracing::info_span!("generating trace").entered(); - let in_a = binius_circuits::unconstrained::unconstrained::<_, _, BinaryField32b>( + let in_a = binius_circuits::unconstrained::unconstrained::( &mut builder, "in_a", log_n_muls, ) .unwrap(); - let in_b = binius_circuits::unconstrained::unconstrained::<_, _, BinaryField32b>( + let in_b = binius_circuits::unconstrained::unconstrained::( &mut builder, "in_b", log_n_muls, diff --git a/examples/bitwise_ops.rs b/examples/bitwise_ops.rs index 3eaec2bb5..6b46e29ad 100644 --- a/examples/bitwise_ops.rs +++ b/examples/bitwise_ops.rs @@ -3,11 +3,9 @@ use std::{fmt::Display, str::FromStr}; use anyhow::Result; -use binius_circuits::builder::ConstraintSystemBuilder; +use binius_circuits::builder::{types::U, ConstraintSystemBuilder}; use binius_core::{constraint_system, fiat_shamir::HasherChallenger, tower::CanonicalTowerFamily}; -use binius_field::{ - arch::OptimalUnderlier, BinaryField128b, BinaryField1b, BinaryField32b, TowerField, -}; +use binius_field::{BinaryField1b, BinaryField32b, TowerField}; use binius_hal::make_portable_backend; use binius_hash::compress::Groestl256ByteCompression; use binius_macros::arith_expr; @@ -63,7 +61,6 @@ struct Args { } fn main() -> Result<()> { - type U = OptimalUnderlier; const SECURITY_BITS: usize = 100; adjust_thread_pool() @@ -80,16 +77,16 @@ fn main() -> Result<()> { log2_ceil_usize(args.n_u32_ops as usize) + BinaryField32b::TOWER_LEVEL; let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let trace_gen_scope = tracing::info_span!("generating trace").entered(); // Assuming our 32bit values have been committed as bits - let in_a = binius_circuits::unconstrained::unconstrained::<_, _, BinaryField1b>( + let in_a = binius_circuits::unconstrained::unconstrained::( &mut builder, "in_a", log_n_1b_operations, )?; - let in_b = binius_circuits::unconstrained::unconstrained::<_, _, BinaryField1b>( + let in_b = binius_circuits::unconstrained::unconstrained::( &mut builder, "in_b", log_n_1b_operations, diff --git a/examples/collatz.rs b/examples/collatz.rs index c30305ca5..f51ce7a06 100644 --- a/examples/collatz.rs +++ b/examples/collatz.rs @@ -2,7 +2,7 @@ use anyhow::Result; use binius_circuits::{ - builder::ConstraintSystemBuilder, + builder::{types::U, ConstraintSystemBuilder}, collatz::{Advice, Collatz}, }; use binius_core::{ @@ -10,7 +10,6 @@ use binius_core::{ fiat_shamir::HasherChallenger, tower::CanonicalTowerFamily, }; -use binius_field::{arch::OptimalUnderlier, BinaryField128b}; use binius_hal::make_portable_backend; use binius_hash::compress::Groestl256ByteCompression; use binius_math::DefaultEvaluationDomainFactory; @@ -29,9 +28,6 @@ struct Args { log_inv_rate: u32, } -type U = OptimalUnderlier; -type F = BinaryField128b; - const SECURITY_BITS: usize = 100; fn main() -> Result<()> { @@ -59,7 +55,7 @@ fn prove(x0: u32, log_inv_rate: usize) -> Result<(Advice, Proof), anyhow::Error> let advice = collatz.init_prover(); let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let boundaries = collatz.build(&mut builder, advice)?; @@ -95,7 +91,7 @@ fn prove(x0: u32, log_inv_rate: usize) -> Result<(Advice, Proof), anyhow::Error> fn verify(x0: u32, advice: Advice, proof: Proof, log_inv_rate: usize) -> Result<(), anyhow::Error> { let collatz = Collatz::new(x0); - let mut builder = ConstraintSystemBuilder::::new(); + let mut builder = ConstraintSystemBuilder::new(); let boundaries = collatz.build(&mut builder, advice)?; diff --git a/examples/groestl_circuit.rs b/examples/groestl_circuit.rs.disabled similarity index 100% rename from examples/groestl_circuit.rs rename to examples/groestl_circuit.rs.disabled diff --git a/examples/keccakf_circuit.rs b/examples/keccakf_circuit.rs index ed87a920a..68f12812d 100644 --- a/examples/keccakf_circuit.rs +++ b/examples/keccakf_circuit.rs @@ -5,9 +5,8 @@ use std::vec; use anyhow::Result; -use binius_circuits::builder::ConstraintSystemBuilder; +use binius_circuits::builder::{types::U, ConstraintSystemBuilder}; use binius_core::{constraint_system, fiat_shamir::HasherChallenger, tower::CanonicalTowerFamily}; -use binius_field::{arch::OptimalUnderlier, BinaryField128b}; use binius_hal::make_portable_backend; use binius_hash::compress::Groestl256ByteCompression; use binius_math::DefaultEvaluationDomainFactory; @@ -28,7 +27,6 @@ struct Args { } fn main() -> Result<()> { - type U = OptimalUnderlier; const SECURITY_BITS: usize = 100; adjust_thread_pool() @@ -45,14 +43,14 @@ fn main() -> Result<()> { let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let log_size = log_n_permutations; let trace_gen_scope = tracing::info_span!("generating trace").entered(); let input_witness = vec![]; let _state_out = - binius_circuits::keccakf::keccakf(&mut builder, Some(input_witness), log_size)?; + binius_circuits::keccakf::keccakf(&mut builder, &Some(input_witness), log_size)?; drop(trace_gen_scope); let witness = builder diff --git a/examples/modular_mul.rs b/examples/modular_mul.rs index b5dfb9e77..1649fa556 100644 --- a/examples/modular_mul.rs +++ b/examples/modular_mul.rs @@ -5,15 +5,14 @@ use std::array; use alloy_primitives::U512; use anyhow::Result; use binius_circuits::{ - builder::ConstraintSystemBuilder, + builder::{types::U, ConstraintSystemBuilder}, lasso::big_integer_ops::{byte_sliced_modular_mul, byte_sliced_test_utils::random_u512}, transparent, }; use binius_core::{constraint_system, fiat_shamir::HasherChallenger, tower::CanonicalTowerFamily}; use binius_field::{ - arch::OptimalUnderlier128b, tower_levels::{TowerLevel4, TowerLevel8}, - BinaryField128b, BinaryField1b, BinaryField8b, Field, TowerField, + BinaryField1b, BinaryField8b, Field, TowerField, }; use binius_hal::make_portable_backend; use binius_hash::compress::Groestl256ByteCompression; @@ -36,8 +35,6 @@ struct Args { } fn main() -> Result<()> { - type U = OptimalUnderlier128b; - type F = BinaryField128b; type B8 = BinaryField8b; const SECURITY_BITS: usize = 100; const WIDTH: usize = 4; @@ -53,7 +50,7 @@ fn main() -> Result<()> { println!("Verifying {} u32 modular multiplications", args.n_multiplications); let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let log_size = log2_ceil_usize(args.n_multiplications as usize); let mut rng = thread_rng(); @@ -98,7 +95,7 @@ fn main() -> Result<()> { let zero_oracle_carry = transparent::constant(&mut builder, "zero carry", log_size, BinaryField1b::ZERO).unwrap(); - let _modded_product = byte_sliced_modular_mul::<_, _, TowerLevel4, TowerLevel8>( + let _modded_product = byte_sliced_modular_mul::( &mut builder, "lasso_bytesliced_mul", &mult_a, diff --git a/examples/sha256_circuit.rs b/examples/sha256_circuit.rs index 2c5d162ac..3835c79fc 100644 --- a/examples/sha256_circuit.rs +++ b/examples/sha256_circuit.rs @@ -5,11 +5,14 @@ use std::array; use anyhow::Result; -use binius_circuits::{builder::ConstraintSystemBuilder, unconstrained::unconstrained}; +use binius_circuits::{ + builder::{types::U, ConstraintSystemBuilder}, + unconstrained::unconstrained, +}; use binius_core::{ constraint_system, fiat_shamir::HasherChallenger, oracle::OracleId, tower::CanonicalTowerFamily, }; -use binius_field::{arch::OptimalUnderlier, BinaryField128b, BinaryField1b}; +use binius_field::BinaryField1b; use binius_hal::make_portable_backend; use binius_hash::compress::Groestl256ByteCompression; use binius_math::DefaultEvaluationDomainFactory; @@ -32,7 +35,6 @@ struct Args { const COMPRESSION_LOG_LEN: usize = 5; fn main() -> Result<()> { - type U = OptimalUnderlier; const SECURITY_BITS: usize = 100; adjust_thread_pool() @@ -48,15 +50,11 @@ fn main() -> Result<()> { let log_n_compressions = log2_ceil_usize(args.n_compressions as usize); let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let trace_gen_scope = tracing::info_span!("generating trace").entered(); let input: [OracleId; 16] = array::try_from_fn(|i| { - unconstrained::<_, _, BinaryField1b>( - &mut builder, - i, - log_n_compressions + COMPRESSION_LOG_LEN, - ) + unconstrained::(&mut builder, i, log_n_compressions + COMPRESSION_LOG_LEN) })?; let _state_out = binius_circuits::sha256::sha256( diff --git a/examples/sha256_circuit_with_lookup.rs b/examples/sha256_circuit_with_lookup.rs index 6b7cd791f..f7f83435b 100644 --- a/examples/sha256_circuit_with_lookup.rs +++ b/examples/sha256_circuit_with_lookup.rs @@ -5,13 +5,14 @@ use std::array; use anyhow::Result; -use binius_circuits::{builder::ConstraintSystemBuilder, unconstrained::unconstrained}; +use binius_circuits::{ + builder::{types::U, ConstraintSystemBuilder}, + unconstrained::unconstrained, +}; use binius_core::{ constraint_system, fiat_shamir::HasherChallenger, oracle::OracleId, tower::CanonicalTowerFamily, }; -use binius_field::{ - arch::OptimalUnderlier, as_packed_field::PackedType, BinaryField128b, BinaryField1b, -}; +use binius_field::{arch::OptimalUnderlier, as_packed_field::PackedType, BinaryField1b}; use binius_hal::make_portable_backend; use binius_hash::compress::Groestl256ByteCompression; use binius_math::DefaultEvaluationDomainFactory; @@ -38,7 +39,6 @@ struct Args { const COMPRESSION_LOG_LEN: usize = 5; fn main() -> Result<()> { - type U = OptimalUnderlier; const SECURITY_BITS: usize = 100; adjust_thread_pool() @@ -54,15 +54,11 @@ fn main() -> Result<()> { let log_n_compressions = log2_ceil_usize(args.n_compressions as usize); let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let trace_gen_scope = tracing::info_span!("generating witness").entered(); let input: [OracleId; 16] = array::try_from_fn(|i| { - unconstrained::<_, _, BinaryField1b>( - &mut builder, - i, - log_n_compressions + COMPRESSION_LOG_LEN, - ) + unconstrained::(&mut builder, i, log_n_compressions + COMPRESSION_LOG_LEN) })?; let _state_out = binius_circuits::lasso::sha256( diff --git a/examples/u32_add.rs b/examples/u32_add.rs index 2feed9ef3..22c446ce9 100644 --- a/examples/u32_add.rs +++ b/examples/u32_add.rs @@ -1,9 +1,12 @@ // Copyright 2024-2025 Irreducible Inc. use anyhow::Result; -use binius_circuits::{arithmetic::Flags, builder::ConstraintSystemBuilder}; +use binius_circuits::{ + arithmetic::Flags, + builder::{types::U, ConstraintSystemBuilder}, +}; use binius_core::{constraint_system, fiat_shamir::HasherChallenger, tower::CanonicalTowerFamily}; -use binius_field::{arch::OptimalUnderlier, BinaryField128b, BinaryField1b}; +use binius_field::BinaryField1b; use binius_hal::make_portable_backend; use binius_hash::compress::Groestl256ByteCompression; use binius_math::DefaultEvaluationDomainFactory; @@ -24,7 +27,6 @@ struct Args { } fn main() -> Result<()> { - type U = OptimalUnderlier; const SECURITY_BITS: usize = 100; adjust_thread_pool() @@ -40,15 +42,15 @@ fn main() -> Result<()> { let log_n_additions = log2_ceil_usize(args.n_additions as usize); let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let trace_gen_scope = tracing::info_span!("generating trace").entered(); - let in_a = binius_circuits::unconstrained::unconstrained::<_, _, BinaryField1b>( + let in_a = binius_circuits::unconstrained::unconstrained::( &mut builder, "in_a", log_n_additions + 5, )?; - let in_b = binius_circuits::unconstrained::unconstrained::<_, _, BinaryField1b>( + let in_b = binius_circuits::unconstrained::unconstrained::( &mut builder, "in_b", log_n_additions + 5, diff --git a/examples/u32_mul.rs b/examples/u32_mul.rs index d7ecb1cff..8a8bfafc4 100644 --- a/examples/u32_mul.rs +++ b/examples/u32_mul.rs @@ -4,7 +4,7 @@ use std::array; use anyhow::Result; use binius_circuits::{ - builder::ConstraintSystemBuilder, + builder::{types::U, ConstraintSystemBuilder}, lasso::{ batch::LookupBatch, big_integer_ops::byte_sliced_mul, @@ -14,9 +14,8 @@ use binius_circuits::{ }; use binius_core::{constraint_system, fiat_shamir::HasherChallenger, tower::CanonicalTowerFamily}; use binius_field::{ - arch::OptimalUnderlier, tower_levels::{TowerLevel4, TowerLevel8}, - BinaryField128b, BinaryField1b, BinaryField32b, BinaryField8b, Field, + BinaryField1b, BinaryField32b, BinaryField8b, Field, }; use binius_hal::make_portable_backend; use binius_hash::compress::Groestl256ByteCompression; @@ -38,7 +37,6 @@ struct Args { } fn main() -> Result<()> { - type U = OptimalUnderlier; const SECURITY_BITS: usize = 100; adjust_thread_pool() @@ -54,12 +52,12 @@ fn main() -> Result<()> { let log_n_muls = log2_ceil_usize(args.n_muls as usize); let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let trace_gen_scope = tracing::info_span!("generating trace").entered(); // Assuming our input data is already transposed, i.e a length 4 array of B8's let in_a = array::from_fn(|i| { - binius_circuits::unconstrained::unconstrained::<_, _, BinaryField8b>( + binius_circuits::unconstrained::unconstrained::( &mut builder, format!("in_a_{}", i), log_n_muls, @@ -67,7 +65,7 @@ fn main() -> Result<()> { .unwrap() }); let in_b = array::from_fn(|i| { - binius_circuits::unconstrained::unconstrained::<_, _, BinaryField8b>( + binius_circuits::unconstrained::unconstrained::( &mut builder, format!("in_b_{}", i), log_n_muls, @@ -84,7 +82,7 @@ fn main() -> Result<()> { let mut lookup_batch_mul = LookupBatch::new([lookup_t_mul]); let mut lookup_batch_add = LookupBatch::new([lookup_t_add]); let mut lookup_batch_dci = LookupBatch::new([lookup_t_dci]); - let _mul_and_cout = byte_sliced_mul::<_, _, TowerLevel4, TowerLevel8>( + let _mul_and_cout = byte_sliced_mul::( &mut builder, "lasso_bytesliced_mul", &in_a, @@ -95,9 +93,9 @@ fn main() -> Result<()> { &mut lookup_batch_add, &mut lookup_batch_dci, )?; - lookup_batch_mul.execute::(&mut builder)?; - lookup_batch_add.execute::(&mut builder)?; - lookup_batch_dci.execute::(&mut builder)?; + lookup_batch_mul.execute::(&mut builder)?; + lookup_batch_add.execute::(&mut builder)?; + lookup_batch_dci.execute::(&mut builder)?; drop(trace_gen_scope); diff --git a/examples/u32add_with_lookup.rs b/examples/u32add_with_lookup.rs index 1eba4ab79..9bc06cfbd 100644 --- a/examples/u32add_with_lookup.rs +++ b/examples/u32add_with_lookup.rs @@ -1,11 +1,10 @@ // Copyright 2024-2025 Irreducible Inc. use anyhow::Result; -use binius_circuits::builder::ConstraintSystemBuilder; +use binius_circuits::builder::{types::U, ConstraintSystemBuilder}; use binius_core::{constraint_system, fiat_shamir::HasherChallenger, tower::CanonicalTowerFamily}; use binius_field::{ - arch::OptimalUnderlier, as_packed_field::PackedType, BinaryField128b, BinaryField1b, - BinaryField8b, + arch::OptimalUnderlier, as_packed_field::PackedType, BinaryField1b, BinaryField8b, }; use binius_hal::make_portable_backend; use binius_hash::compress::Groestl256ByteCompression; @@ -29,7 +28,6 @@ struct Args { } fn main() -> Result<()> { - type U = OptimalUnderlier; const SECURITY_BITS: usize = 100; adjust_thread_pool() @@ -45,20 +43,20 @@ fn main() -> Result<()> { let log_n_additions = log2_ceil_usize(args.n_additions as usize); let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let trace_gen_scope = tracing::info_span!("generating trace").entered(); - let in_a = binius_circuits::unconstrained::unconstrained::<_, _, BinaryField8b>( + let in_a = binius_circuits::unconstrained::unconstrained::( &mut builder, "in_a", log_n_additions + 2, )?; - let in_b = binius_circuits::unconstrained::unconstrained::<_, _, BinaryField8b>( + let in_b = binius_circuits::unconstrained::unconstrained::( &mut builder, "in_b", log_n_additions + 2, )?; - let _product = binius_circuits::lasso::u32add::<_, _, BinaryField8b, BinaryField8b>( + let _product = binius_circuits::lasso::u32add::( &mut builder, "out_c", in_a, diff --git a/examples/u8mul.rs b/examples/u8mul.rs index 90d706dc5..cc6e0be9f 100644 --- a/examples/u8mul.rs +++ b/examples/u8mul.rs @@ -2,11 +2,11 @@ use anyhow::Result; use binius_circuits::{ - builder::ConstraintSystemBuilder, + builder::{types::U, ConstraintSystemBuilder}, lasso::{batch::LookupBatch, lookups}, }; use binius_core::{constraint_system, fiat_shamir::HasherChallenger, tower::CanonicalTowerFamily}; -use binius_field::{arch::OptimalUnderlier, BinaryField128b, BinaryField32b, BinaryField8b}; +use binius_field::{BinaryField32b, BinaryField8b}; use binius_hal::make_portable_backend; use binius_hash::compress::Groestl256ByteCompression; use binius_math::DefaultEvaluationDomainFactory; @@ -27,7 +27,6 @@ struct Args { } fn main() -> Result<()> { - type U = OptimalUnderlier; const SECURITY_BITS: usize = 100; adjust_thread_pool() @@ -43,15 +42,15 @@ fn main() -> Result<()> { let log_n_multiplications = log2_ceil_usize(args.n_multiplications as usize); let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let trace_gen_scope = tracing::info_span!("generating trace").entered(); - let in_a = binius_circuits::unconstrained::unconstrained::<_, _, BinaryField8b>( + let in_a = binius_circuits::unconstrained::unconstrained::( &mut builder, "in_a", log_n_multiplications, )?; - let in_b = binius_circuits::unconstrained::unconstrained::<_, _, BinaryField8b>( + let in_b = binius_circuits::unconstrained::unconstrained::( &mut builder, "in_b", log_n_multiplications, @@ -70,7 +69,7 @@ fn main() -> Result<()> { args.n_multiplications as usize, )?; - lookup_batch.execute::<_, _, BinaryField32b>(&mut builder)?; + lookup_batch.execute::(&mut builder)?; drop(trace_gen_scope); let witness = builder diff --git a/examples/vision32b_circuit.rs b/examples/vision32b_circuit.rs index 61c74d762..b5afa6f58 100644 --- a/examples/vision32b_circuit.rs +++ b/examples/vision32b_circuit.rs @@ -10,11 +10,11 @@ use std::array; use anyhow::Result; -use binius_circuits::builder::ConstraintSystemBuilder; +use binius_circuits::builder::{types::U, ConstraintSystemBuilder}; use binius_core::{ constraint_system, fiat_shamir::HasherChallenger, oracle::OracleId, tower::CanonicalTowerFamily, }; -use binius_field::{arch::OptimalUnderlier, BinaryField128b, BinaryField32b, BinaryField8b}; +use binius_field::{BinaryField32b, BinaryField8b}; use binius_hal::make_portable_backend; use binius_hash::compress::Groestl256ByteCompression; use binius_math::IsomorphicEvaluationDomainFactory; @@ -35,7 +35,6 @@ struct Args { } fn main() -> Result<()> { - type U = OptimalUnderlier; const SECURITY_BITS: usize = 100; adjust_thread_pool() @@ -51,11 +50,11 @@ fn main() -> Result<()> { let log_n_permutations = log2_ceil_usize(args.n_permutations as usize); let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let trace_gen_scope = tracing::info_span!("generating trace").entered(); let state_in: [OracleId; 24] = array::from_fn(|i| { - binius_circuits::unconstrained::unconstrained::<_, _, BinaryField32b>( + binius_circuits::unconstrained::unconstrained::( &mut builder, format!("p_in_{i}"), log_n_permutations, diff --git a/scripts/nightly_benchmarks.py b/scripts/nightly_benchmarks.py new file mode 100755 index 000000000..1c2f9ad03 --- /dev/null +++ b/scripts/nightly_benchmarks.py @@ -0,0 +1,265 @@ +#!/usr/bin/python3 + +import argparse +import csv +import json +import os +import re +import subprocess +from typing import Union + +ENV_VARS = { + "RUSTFLAGS": "-C target-cpu=native", +} + +SAMPLE_SIZE = 5 + +KECCAKF_PERMS = 1 << 13 +VISION32B_PERMS = 1 << 14 +SHA256_PERMS = 1 << 14 +NUM_BINARY_OPS = 1 << 22 +NUM_MULS = 1 << 20 + +HASHER_TO_RUN = { + r"keccakf": { + "type": "hasher", + "display": r"Keccak-f", + "export": "keccakf-report.csv", + "args": ["keccakf_circuit", "--", "--n-permutations"], + "n_ops": KECCAKF_PERMS, + }, + "vision32b": { + "type": "hasher", + "display": r"Vision Mark-32", + "export": "vision32b-report.csv", + "args": ["vision32b_circuit", "--", "--n-permutations"], + "n_ops": VISION32B_PERMS, + }, + "sha256": { + "type": "hasher", + "display": "SHA-256", + "export": "sha256-report.csv", + "args": ["sha256_circuit", "--", "--n-compressions"], + "n_ops": SHA256_PERMS, + }, + "b32_mul": { + "type": "binary_ops", + "display": "BinaryField32b mul", + "export": "b32-mul-report.csv", + "args": ["b32_mul", "--", "--n-ops"], + "n_ops": NUM_MULS, + }, + "u32_add": { + "type": "binary_ops", + "display": "u32 add", + "export": "u32-add-report.csv", + "args": ["u32_add", "--", "--n-additions"], + "n_ops": NUM_BINARY_OPS, + }, + "u32_mul": { + "type": "binary_ops", + "display": "u32 mul", + "export": "u32-mul-report.csv", + "args": ["u32_mul", "--", "--n-muls"], + "n_ops": NUM_MULS, + }, + "xor": { + "type": "binary_ops", + "display": "Xor", + "export": "xor-report.csv", + "args": ["bitwise_ops", "--", "--op", "xor", "--n-u32-ops"], + "n_ops": NUM_BINARY_OPS, + }, + "and": { + "type": "binary_ops", + "display": "And", + "export": "and-report.csv", + "args": ["bitwise_ops", "--", "--op", "and", "--n-u32-ops"], + "n_ops": NUM_BINARY_OPS, + }, + "or": { + "type": "binary_ops", + "display": "Or", + "export": "or-report.csv", + "args": ["bitwise_ops", "--", "--op", "or", "--n-u32-ops"], + "n_ops": NUM_BINARY_OPS, + }, +} + +HASHER_BENCHMARKS = {} +BINARY_OPS_BENCHMARKS = {} + + +def run_benchmark(benchmark_args) -> tuple[bytes, bytes]: + command = ( + ["cargo", "run", "--release", "--example"] + + benchmark_args["args"] + + [f"{benchmark_args['n_ops']}"] + ) + env_vars_to_run = { + **os.environ, + **ENV_VARS, + "PROFILE_CSV_FILE": benchmark_args["export"], + } + process = subprocess.run( + command, env=env_vars_to_run, capture_output=True, check=True + ) + return process.stdout, process.stderr + + +def parse_csv_file(file_name) -> dict: + data = {} + with open(file_name) as file: + reader = csv.reader(file) + for row in reader: + if row[0] == "generating trace": + data.update({"trace_gen_time": int(row[2])}) + elif row[0] == "constraint_system::prove": + data.update({"proving_time": int(row[2])}) + elif row[0] == "constraint_system::verify": + data.update({"verification_time": int(row[2])}) + return data + + +KIB_TO_BYTES = 1024.0 +MIB_TO_BYTES = KIB_TO_BYTES * 1024.0 +GIB_TO_BYTES = MIB_TO_BYTES * 1024.0 +KB_TO_BYTES = 1000.0 +MB_TO_BYTES = KB_TO_BYTES * 1000.0 +GB_TO_BYTES = MB_TO_BYTES * 1000.0 + +SIZE_CONVERSIONS = { + "KiB": KIB_TO_BYTES, + "MiB": MIB_TO_BYTES, + "GiB": GIB_TO_BYTES, + " B": 1, + "KB": KB_TO_BYTES, + "MB": MB_TO_BYTES, + "GB": GB_TO_BYTES, +} + + +def parse_proof_size(proof_size: bytes) -> int: + proof_size = proof_size.decode("utf-8").strip() + for unit, factor in SIZE_CONVERSIONS.items(): + if proof_size.endswith(unit): + byte_len = float(proof_size[: -len(unit)]) * factor + break + else: + raise ValueError(f"Unknown proof size format: {proof_size}") + + # Convert to KiB + return int(byte_len / KIB_TO_BYTES) + + +def nano_to_milli(nano) -> float: + return float(nano) / 1000000.0 + + +def nano_to_seconds(nano) -> float: + return float(nano) / 1000000000.0 + + +def run_and_parse_benchmark(benchmark, benchmark_args) -> tuple[dict, int]: + data = {} + stdout = None + print(f"Running benchmark: {benchmark} with {SAMPLE_SIZE} samples") + for _ in range(SAMPLE_SIZE): + stdout, _stderr = run_benchmark(benchmark_args) + result = parse_csv_file(benchmark_args["export"]) + # Parse the csv file + if len(result.keys()) != 3: + print(f"Failed to parse csv file for benchmark: {benchmark}") + exit(1) + + # Append the results to the data + for key, value in result.items(): + if data.get(key) is None: + data[key] = [] + data[key].append(value) + # Get proof sizes + found = re.search(rb"Proof size: (.*)", stdout) + if found: + return data, parse_proof_size(found.group(1)) + else: + print(f"Failed to get proof size for benchmark: {benchmark}") + exit(1) + + +def run_benchmark_group(benchmarks) -> dict: + benchmark_results = {} + for benchmark, benchmark_args in benchmarks.items(): + try: + data, proof_size = run_and_parse_benchmark(benchmark, benchmark_args) + benchmark_results[benchmark] = {"proof_size_kib": proof_size} + data["n_ops"] = benchmark_args["n_ops"] + data["display"] = benchmark_args["display"] + data["type"] = benchmark_args["type"] + benchmark_results[benchmark].update(data) + + except Exception as e: + print(f"Failed to run benchmark: {benchmark} with error {e} \nExiting...") + exit(1) + return benchmark_results + + +def value_to_bencher(value: Union[list[float], int], throughput: bool = False) -> dict: + if isinstance(value, list): + avg_value = sum(value) / len(value) + max_value = max(value) + min_value = min(value) + else: + avg_value = max_value = min_value = value + + metric_type = "throughput" if throughput else "latency" + return { + metric_type: { + "value": avg_value, + "upper_value": max_value, + "lower_value": min_value, + } + } + + +def dict_to_bencher(data: dict) -> dict: + bencher_data = {} + for benchmark, value in data.items(): + # Name is of the following format: ::::(trace_gen_time | proving_time | verification_time | proof_size_kib | n_ops) + common_name = f"{value['type']}::{value['display']}" + for key in [ + "trace_gen_time", + "proving_time", + "verification_time", + "proof_size_kib", + "n_ops", + ]: + bencher_data[f"{common_name}::{key}"] = value_to_bencher(value[key]) + return bencher_data + + +def main(): + parser = argparse.ArgumentParser( + description="Run nightly benchmarks and export results" + ) + parser.add_argument( + "--export-file", + required=False, + type=str, + help="Export benchmarks results to file (defaults to stdout)", + ) + + args = parser.parse_args() + + benchmarks = run_benchmark_group(HASHER_TO_RUN) + + bencher_data = dict_to_bencher(benchmarks) + if args.export_file is None: + print("Couldn't find export file for hashers writing to stdout instead") + print(json.dumps(bencher_data)) + else: + with open(args.export_file, "w") as file: + json.dump(bencher_data, file) + + +if __name__ == "__main__": + main() diff --git a/scripts/run_tests_and_examples.sh b/scripts/run_tests_and_examples.sh index 85b0ffec2..70da6e396 100755 --- a/scripts/run_tests_and_examples.sh +++ b/scripts/run_tests_and_examples.sh @@ -19,7 +19,7 @@ if [ -z "$CARGO_STABLE" ]; then # Execute examples. # Unfortunately there cargo doesn't support executing all examples with a single command. - # Cargo plugins such as "cargo-examples" do suport it but without a possibility to specify "release" profile. + # Cargo plugins such as "cargo-examples" do support it but without a possibility to specify "release" profile. for example in examples/*.rs do cargo run --profile $CARGO_PROFILE --example "$(basename "${example%.rs}")" $CARGO_EXTRA_FLAGS