From 0e8a314e865abad4c8a7ebc8b4a023b3f26c868e Mon Sep 17 00:00:00 2001 From: Nikita Lesnikov Date: Fri, 31 Jan 2025 19:08:50 +0000 Subject: [PATCH 01/50] Support for Karatsuba "infinity" point in evaluation & interpolation domains --- .../protocols/gkr_gpa/gpa_sumcheck/prove.rs | 4 +- .../sumcheck/prove/regular_sumcheck.rs | 4 +- .../protocols/sumcheck/prove/univariate.rs | 2 +- .../src/protocols/sumcheck/prove/zerocheck.rs | 4 +- crates/math/src/error.rs | 2 + crates/math/src/univariate.rs | 176 ++++++++++++++---- 6 files changed, 145 insertions(+), 47 deletions(-) diff --git a/crates/core/src/protocols/gkr_gpa/gpa_sumcheck/prove.rs b/crates/core/src/protocols/gkr_gpa/gpa_sumcheck/prove.rs index 800ff7f7b..362c6f2c8 100644 --- a/crates/core/src/protocols/gkr_gpa/gpa_sumcheck/prove.rs +++ b/crates/core/src/protocols/gkr_gpa/gpa_sumcheck/prove.rs @@ -104,8 +104,8 @@ where let evaluation_points = domains .iter() - .max_by_key(|domain| domain.points().len()) - .map_or_else(|| Vec::new(), |domain| domain.points().to_vec()); + .max_by_key(|domain| domain.size()) + .map_or_else(|| Vec::new(), |domain| domain.finite_points().to_vec()); let state = ProverState::new( multilinears, diff --git a/crates/core/src/protocols/sumcheck/prove/regular_sumcheck.rs b/crates/core/src/protocols/sumcheck/prove/regular_sumcheck.rs index a3695046a..757b6d906 100644 --- a/crates/core/src/protocols/sumcheck/prove/regular_sumcheck.rs +++ b/crates/core/src/protocols/sumcheck/prove/regular_sumcheck.rs @@ -142,8 +142,8 @@ where let evaluation_points = domains .iter() - .max_by_key(|domain| domain.points().len()) - .map_or_else(|| Vec::new(), |domain| domain.points().to_vec()); + .max_by_key(|domain| domain.size()) + .map_or_else(|| Vec::new(), |domain| domain.finite_points().to_vec()); let state = ProverState::new( multilinears, diff --git a/crates/core/src/protocols/sumcheck/prove/univariate.rs b/crates/core/src/protocols/sumcheck/prove/univariate.rs index 4d77666b9..7c80eff51 100644 --- a/crates/core/src/protocols/sumcheck/prove/univariate.rs +++ b/crates/core/src/protocols/sumcheck/prove/univariate.rs @@ -823,7 +823,7 @@ mod tests { .map(|i| interleaved_scalars[(i << log_batch) + batch_idx]) .collect::>(); - for (i, &point) in max_domain.points()[1 << skip_rounds..] + for (i, &point) in max_domain.finite_points()[1 << skip_rounds..] [..extrapolated_scalars_cnt] .iter() .enumerate() diff --git a/crates/core/src/protocols/sumcheck/prove/zerocheck.rs b/crates/core/src/protocols/sumcheck/prove/zerocheck.rs index c3325604d..912dd7a84 100644 --- a/crates/core/src/protocols/sumcheck/prove/zerocheck.rs +++ b/crates/core/src/protocols/sumcheck/prove/zerocheck.rs @@ -480,8 +480,8 @@ where ) -> Result { let evaluation_points = domains .iter() - .max_by_key(|domain| domain.points().len()) - .map_or_else(|| Vec::new(), |domain| domain.points().to_vec()); + .max_by_key(|domain| domain.size()) + .map_or_else(|| Vec::new(), |domain| domain.finite_points().to_vec()); if claimed_prime_sums.len() != compositions.len() { bail!(Error::IncorrectClaimedPrimeSumsLength); diff --git a/crates/math/src/error.rs b/crates/math/src/error.rs index adfc1f747..bb53b885b 100644 --- a/crates/math/src/error.rs +++ b/crates/math/src/error.rs @@ -10,6 +10,8 @@ pub enum Error { MatrixNotSquare, #[error("the matrix is singular")] MatrixIsSingular, + #[error("domain size needs to be at least one")] + DomainSizeAtLeastOne, #[error("domain size is larger than the field")] DomainSizeTooLarge, #[error("the inputted packed values slice had an unexpected length")] diff --git a/crates/math/src/univariate.rs b/crates/math/src/univariate.rs index b99d26067..7e12f4d19 100644 --- a/crates/math/src/univariate.rs +++ b/crates/math/src/univariate.rs @@ -7,6 +7,7 @@ use binius_field::{ PackedField, }; use binius_utils::bail; +use itertools::{izip, Either}; use super::{binary_subspace::BinarySubspace, error::Error}; use crate::Matrix; @@ -17,8 +18,9 @@ use crate::Matrix; /// to reconstruct a degree <= d. This struct supports Barycentric extrapolation. #[derive(Debug, Clone)] pub struct EvaluationDomain { - points: Vec, + finite_points: Vec, weights: Vec, + with_infinity: bool, } /// An extended version of `EvaluationDomain` that supports interpolation to monomial form. Takes @@ -32,9 +34,20 @@ pub struct InterpolationDomain { /// Wraps type information to enable instantiating EvaluationDomains. #[auto_impl(&)] pub trait EvaluationDomainFactory: Clone + Sync { - /// Instantiates an EvaluationDomain from a set of points isomorphic to direct - /// lexicographic successors of zero in Fan-Paar tower - fn create(&self, size: usize) -> Result, Error>; + /// Instantiates an EvaluationDomain from `size` lexicographically first values from the + /// binary subspace. + fn create(&self, size: usize) -> Result, Error> { + self.create_with_infinity(size, false) + } + + /// Instantiates an EvaluationDomain from `size` values in total: lexicographically first values + /// from the binary subspace and potentially Karatsuba "infinity" point (which is the coefficient of + /// the highest power in the interpolated polynomial). + fn create_with_infinity( + &self, + size: usize, + with_infinity: bool, + ) -> Result, Error>; } #[derive(Default, Clone)] @@ -48,8 +61,18 @@ pub struct IsomorphicEvaluationDomainFactory { } impl EvaluationDomainFactory for DefaultEvaluationDomainFactory { - fn create(&self, size: usize) -> Result, Error> { - EvaluationDomain::from_points(make_evaluation_points(&self.subspace, size)?) + fn create_with_infinity( + &self, + size: usize, + with_infinity: bool, + ) -> Result, Error> { + if size == 0 && with_infinity { + bail!(Error::DomainSizeAtLeastOne); + } + EvaluationDomain::from_points( + make_evaluation_points(&self.subspace, size - if with_infinity { 1 } else { 0 })?, + with_infinity, + ) } } @@ -58,9 +81,17 @@ where FSrc: BinaryField, FTgt: Field + From + BinaryField, { - fn create(&self, size: usize) -> Result, Error> { - let points = make_evaluation_points(&self.subspace, size)?; - EvaluationDomain::from_points(points.into_iter().map(Into::into).collect()) + fn create_with_infinity( + &self, + size: usize, + with_infinity: bool, + ) -> Result, Error> { + if size == 0 && with_infinity { + bail!(Error::DomainSizeAtLeastOne); + } + let points = + make_evaluation_points(&self.subspace, size - if with_infinity { 1 } else { 0 })?; + EvaluationDomain::from_points(points.into_iter().map(Into::into).collect(), false) } } @@ -78,7 +109,8 @@ fn make_evaluation_points( impl From> for InterpolationDomain { fn from(evaluation_domain: EvaluationDomain) -> Self { let n = evaluation_domain.size(); - let evaluation_matrix = vandermonde(evaluation_domain.points()); + let evaluation_matrix = + vandermonde(evaluation_domain.finite_points(), evaluation_domain.with_infinity()); let mut interpolation_matrix = Matrix::zeros(n, n); evaluation_matrix .inverse_into(&mut interpolation_matrix) @@ -97,17 +129,25 @@ impl From> for InterpolationDomain { } impl EvaluationDomain { - pub fn from_points(points: Vec) -> Result { - let weights = compute_barycentric_weights(&points)?; - Ok(Self { points, weights }) + pub fn from_points(finite_points: Vec, with_infinity: bool) -> Result { + let weights = compute_barycentric_weights(&finite_points)?; + Ok(Self { + finite_points, + weights, + with_infinity, + }) } pub fn size(&self) -> usize { - self.points.len() + self.finite_points.len() + if self.with_infinity { 1 } else { 0 } } - pub fn points(&self) -> &[F] { - self.points.as_slice() + pub fn finite_points(&self) -> &[F] { + self.finite_points.as_slice() + } + + pub const fn with_infinity(&self) -> bool { + self.with_infinity } /// Compute a vector of Lagrange polynomial evaluations in $O(N)$ at a given point `x`. @@ -116,19 +156,23 @@ impl EvaluationDomain { /// are defined by /// $$L_i(x) = \sum_{j \neq i}\frac{x - \pi_j}{\pi_i - \pi_j}$$ pub fn lagrange_evals>(&self, x: FE) -> Vec { - let num_evals = self.size(); + let num_evals = self.finite_points().len(); let mut result: Vec = vec![FE::ONE; num_evals]; // Multiply the product suffixes for i in (1..num_evals).rev() { - result[i - 1] = result[i] * (x - self.points[i]); + result[i - 1] = result[i] * (x - self.finite_points[i]); } let mut prefix = FE::ONE; // Multiply the product prefixes and weights - for ((r, &point), &weight) in result.iter_mut().zip(&self.points).zip(&self.weights) { + for ((r, &point), &weight) in result + .iter_mut() + .zip(&self.finite_points) + .zip(&self.weights) + { *r *= prefix * weight; prefix *= x - point; } @@ -141,18 +185,26 @@ impl EvaluationDomain { where PE: PackedField>, { - let lagrange_eval_results = self.lagrange_evals(x); - - let n = self.size(); - if values.len() != n { + if values.len() != self.size() { bail!(Error::ExtrapolateNumberOfEvaluations); } - let result = lagrange_eval_results - .into_iter() - .zip(values) - .map(|(evaluation, &value)| value * evaluation) - .sum::(); + let (values_iter, infinity_term) = if self.with_infinity { + let (&value_at_infinity, finite_values) = + values.split_last().expect("values length checked above"); + let highest_degree = finite_values.len() as u64; + let iter = izip!(&self.finite_points, finite_values).map(move |(&point, &value)| { + value - value_at_infinity * PE::Scalar::from(point).pow(highest_degree) + }); + (Either::Left(iter), value_at_infinity * x.pow(highest_degree)) + } else { + (Either::Right(values.iter().copied()), PE::zero()) + }; + + let result = izip!(self.lagrange_evals(x), values_iter) + .map(|(lagrange_at_x, value)| value * lagrange_at_x) + .sum::() + + infinity_term; Ok(result) } @@ -163,8 +215,12 @@ impl InterpolationDomain { self.evaluation_domain.size() } - pub fn points(&self) -> &[F] { - self.evaluation_domain.points() + pub fn finite_points(&self) -> &[F] { + self.evaluation_domain.finite_points() + } + + pub const fn with_infinity(&self) -> bool { + self.evaluation_domain.with_infinity() } pub fn extrapolate(&self, values: &[PE], x: PE::Scalar) -> Result @@ -175,8 +231,7 @@ impl InterpolationDomain { } pub fn interpolate>(&self, values: &[FE]) -> Result, Error> { - let n = self.evaluation_domain.size(); - if values.len() != n { + if values.len() != self.evaluation_domain.size() { bail!(Error::ExtrapolateNumberOfEvaluations); } @@ -236,8 +291,8 @@ fn compute_barycentric_weights(points: &[F]) -> Result, Error> .collect() } -fn vandermonde(xs: &[F]) -> Matrix { - let n = xs.len(); +fn vandermonde(xs: &[F], with_infinity: bool) -> Matrix { + let n = xs.len() + if with_infinity { 1 } else { 0 }; let mut mat = Matrix::zeros(n, n); for (i, x_i) in xs.iter().copied().enumerate() { @@ -249,6 +304,11 @@ fn vandermonde(xs: &[F]) -> Matrix { mat[(i, j)] = acc; } } + + if with_infinity { + mat[(n - 1, n - 1)] = F::ONE; + } + mat } @@ -277,7 +337,7 @@ mod tests { fn test_new_domain() { let domain_factory = DefaultEvaluationDomainFactory::::default(); assert_eq!( - domain_factory.create(3).unwrap().points, + domain_factory.create(3).unwrap().finite_points, &[ BinaryField8b::new(0), BinaryField8b::new(1), @@ -292,7 +352,7 @@ mod tests { let iso_domain_factory = IsomorphicEvaluationDomainFactory::::default(); let domain_1: EvaluationDomain = default_domain_factory.create(10).unwrap(); let domain_2: EvaluationDomain = iso_domain_factory.create(10).unwrap(); - assert_eq!(domain_1.points, domain_2.points); + assert_eq!(domain_1.finite_points, domain_2.finite_points); } #[test] @@ -303,11 +363,11 @@ mod tests { let domain_2: EvaluationDomain = iso_domain_factory.create(10).unwrap(); assert_eq!( domain_1 - .points + .finite_points .into_iter() .map(AESTowerField32b::from) .collect::>(), - domain_2.points + domain_2.finite_points ); } @@ -343,6 +403,7 @@ mod tests { repeat_with(|| ::random(&mut rng)) .take(degree + 1) .collect(), + false, ) .unwrap(); @@ -351,7 +412,7 @@ mod tests { .collect::>(); let values = domain - .points() + .finite_points() .iter() .map(|&x| evaluate_univariate(&coeffs, x)) .collect::>(); @@ -370,6 +431,7 @@ mod tests { repeat_with(|| ::random(&mut rng)) .take(degree + 1) .collect(), + false, ) .unwrap(); @@ -378,10 +440,44 @@ mod tests { .collect::>(); let values = domain - .points() + .finite_points() + .iter() + .map(|&x| evaluate_univariate(&coeffs, x)) + .collect::>(); + + let interpolated = InterpolationDomain::from(domain) + .interpolate(&values) + .unwrap(); + assert_eq!(interpolated, coeffs); + } + + #[test] + fn test_infinity() { + let mut rng = StdRng::seed_from_u64(0); + let degree = 6; + + let domain = EvaluationDomain::from_points( + repeat_with(|| ::random(&mut rng)) + .take(degree) + .collect(), + true, + ) + .unwrap(); + + let coeffs = repeat_with(|| ::random(&mut rng)) + .take(degree + 1) + .collect::>(); + + let mut values = domain + .finite_points() .iter() .map(|&x| evaluate_univariate(&coeffs, x)) .collect::>(); + values.push(coeffs.last().copied().unwrap()); + + let x = ::random(&mut rng); + let expected_y = evaluate_univariate(&coeffs, x); + assert_eq!(domain.extrapolate(&values, x).unwrap(), expected_y); let interpolated = InterpolationDomain::from(domain) .interpolate(&values) From 6d4e9457e9cb7545d2906caab97ca6a5d5164e43 Mon Sep 17 00:00:00 2001 From: Nikita Lesnikov Date: Fri, 31 Jan 2025 19:54:20 +0000 Subject: [PATCH 02/50] [sumcheck] Small field zerocheck and its HAL support removed --- .../protocols/gkr_gpa/gpa_sumcheck/prove.rs | 4 +- .../sumcheck/prove/concrete_prover.rs | 65 -------- .../core/src/protocols/sumcheck/prove/mod.rs | 2 - .../protocols/sumcheck/prove/prover_state.rs | 32 +--- .../sumcheck/prove/regular_sumcheck.rs | 4 +- .../src/protocols/sumcheck/prove/zerocheck.rs | 140 +++++++---------- crates/hal/src/backend.rs | 68 +-------- crates/hal/src/cpu.rs | 45 +----- crates/hal/src/sumcheck_evaluator.rs | 10 +- crates/hal/src/sumcheck_round_calculator.rs | 141 ++++-------------- 10 files changed, 109 insertions(+), 402 deletions(-) delete mode 100644 crates/core/src/protocols/sumcheck/prove/concrete_prover.rs diff --git a/crates/core/src/protocols/gkr_gpa/gpa_sumcheck/prove.rs b/crates/core/src/protocols/gkr_gpa/gpa_sumcheck/prove.rs index 362c6f2c8..51736bf52 100644 --- a/crates/core/src/protocols/gkr_gpa/gpa_sumcheck/prove.rs +++ b/crates/core/src/protocols/gkr_gpa/gpa_sumcheck/prove.rs @@ -229,7 +229,7 @@ where }) .collect::>(); - let evals = self.state.calculate_later_round_evals(&evaluators)?; + let evals = self.state.calculate_round_evals(&evaluators)?; let coeffs = self.state .calculate_round_coeffs_from_evals(&evaluators, batch_coeff, evals)?; @@ -287,7 +287,7 @@ where gpa_round_challenge: P::Scalar, } -impl SumcheckEvaluator +impl SumcheckEvaluator for GPAEvaluator<'_, P, FDomain, Composition> where F: Field + ExtensionField, diff --git a/crates/core/src/protocols/sumcheck/prove/concrete_prover.rs b/crates/core/src/protocols/sumcheck/prove/concrete_prover.rs deleted file mode 100644 index b67e99c2f..000000000 --- a/crates/core/src/protocols/sumcheck/prove/concrete_prover.rs +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright 2024-2025 Irreducible Inc. - -use binius_field::{ExtensionField, Field, PackedExtension, PackedField, PackedFieldIndexable}; -use binius_hal::ComputationBackend; -use binius_math::{CompositionPolyOS, MultilinearPoly}; - -use super::{batch_prove::SumcheckProver, RegularSumcheckProver, ZerocheckProver}; -use crate::protocols::sumcheck::{common::RoundCoeffs, error::Error}; - -/// A sum type that is used to put both regular sumchecks and zerochecks into the same `batch_prove` call. -pub enum ConcreteProver<'a, FDomain, PBase, P, CompositionBase, Composition, M, Backend> -where - FDomain: Field, - PBase: PackedField, - P: PackedField, - M: MultilinearPoly

+ Send + Sync, - Backend: ComputationBackend, -{ - Sumcheck(RegularSumcheckProver<'a, FDomain, P, Composition, M, Backend>), - Zerocheck(ZerocheckProver<'a, FDomain, PBase, P, CompositionBase, Composition, M, Backend>), -} - -impl SumcheckProver - for ConcreteProver<'_, FDomain, FBase, P, CompositionBase, Composition, M, Backend> -where - F: Field + ExtensionField + ExtensionField, - FDomain: Field, - FBase: ExtensionField, - P: PackedFieldIndexable - + PackedExtension - + PackedExtension - + PackedExtension, - CompositionBase: CompositionPolyOS<

>::PackedSubfield>, - Composition: CompositionPolyOS

, - M: MultilinearPoly

+ Send + Sync, - Backend: ComputationBackend, -{ - fn n_vars(&self) -> usize { - match self { - ConcreteProver::Sumcheck(prover) => prover.n_vars(), - ConcreteProver::Zerocheck(prover) => prover.n_vars(), - } - } - - fn execute(&mut self, batch_coeff: F) -> Result, Error> { - match self { - ConcreteProver::Sumcheck(prover) => prover.execute(batch_coeff), - ConcreteProver::Zerocheck(prover) => prover.execute(batch_coeff), - } - } - - fn fold(&mut self, challenge: F) -> Result<(), Error> { - match self { - ConcreteProver::Sumcheck(prover) => prover.fold(challenge), - ConcreteProver::Zerocheck(prover) => prover.fold(challenge), - } - } - - fn finish(self: Box) -> Result, Error> { - match *self { - ConcreteProver::Sumcheck(prover) => Box::new(prover).finish(), - ConcreteProver::Zerocheck(prover) => Box::new(prover).finish(), - } - } -} diff --git a/crates/core/src/protocols/sumcheck/prove/mod.rs b/crates/core/src/protocols/sumcheck/prove/mod.rs index 94029d716..44c1f587c 100644 --- a/crates/core/src/protocols/sumcheck/prove/mod.rs +++ b/crates/core/src/protocols/sumcheck/prove/mod.rs @@ -3,7 +3,6 @@ mod batch_prove; mod batch_prove_univariate_zerocheck; pub(crate) mod common; -mod concrete_prover; pub mod front_loaded; pub mod oracles; pub mod prover_state; @@ -15,7 +14,6 @@ pub use batch_prove::{batch_prove, batch_prove_with_start, SumcheckProver}; pub use batch_prove_univariate_zerocheck::{ batch_prove_zerocheck_univariate_round, UnivariateZerocheckProver, }; -pub use concrete_prover::ConcreteProver; pub use oracles::{ constraint_set_sumcheck_prover, constraint_set_zerocheck_prover, split_constraint_set, }; diff --git a/crates/core/src/protocols/sumcheck/prove/prover_state.rs b/crates/core/src/protocols/sumcheck/prove/prover_state.rs index d05e72523..606152664 100644 --- a/crates/core/src/protocols/sumcheck/prove/prover_state.rs +++ b/crates/core/src/protocols/sumcheck/prove/prover_state.rs @@ -71,7 +71,7 @@ impl<'a, FDomain, F, P, M, Backend> ProverState<'a, FDomain, P, M, Backend> where FDomain: Field, F: Field + ExtensionField, - P: PackedField + PackedExtension + PackedExtension, + P: PackedField + PackedExtension, M: MultilinearPoly

+ Send + Sync, Backend: ComputationBackend, { @@ -242,41 +242,17 @@ where .collect() } - /// Calculate the accumulated evaluations for the first sumcheck round. - #[instrument(skip_all, level = "debug")] - pub fn calculate_first_round_evals( - &self, - evaluators: &[Evaluator], - ) -> Result>, Error> - where - FBase: ExtensionField, - F: ExtensionField, - P: PackedExtension, - Evaluator: SumcheckEvaluator + Sync, - Composition: CompositionPolyOS

, - { - Ok(self.backend.sumcheck_compute_first_round_evals( - self.n_vars, - &self.multilinears, - evaluators, - &self.evaluation_points, - )?) - } - /// Calculate the accumulated evaluations for an arbitrary sumcheck round. - /// - /// See [`Self::calculate_first_round_evals`] for an optimized version of this method that - /// operates over small fields in the first round. #[instrument(skip_all, level = "debug")] - pub fn calculate_later_round_evals( + pub fn calculate_round_evals( &self, evaluators: &[Evaluator], ) -> Result>, Error> where - Evaluator: SumcheckEvaluator + Sync, + Evaluator: SumcheckEvaluator + Sync, Composition: CompositionPolyOS

, { - Ok(self.backend.sumcheck_compute_later_round_evals( + Ok(self.backend.sumcheck_compute_round_evals( self.n_vars, self.tensor_query.as_ref().map(Into::into), &self.multilinears, diff --git a/crates/core/src/protocols/sumcheck/prove/regular_sumcheck.rs b/crates/core/src/protocols/sumcheck/prove/regular_sumcheck.rs index 757b6d906..8cae1a0a1 100644 --- a/crates/core/src/protocols/sumcheck/prove/regular_sumcheck.rs +++ b/crates/core/src/protocols/sumcheck/prove/regular_sumcheck.rs @@ -193,7 +193,7 @@ where }) .collect::>(); - let evals = self.state.calculate_later_round_evals(&evaluators)?; + let evals = self.state.calculate_round_evals(&evaluators)?; self.state .calculate_round_coeffs_from_evals(&evaluators, batch_coeff, evals) } @@ -213,7 +213,7 @@ where _marker: PhantomData

, } -impl SumcheckEvaluator +impl SumcheckEvaluator for RegularSumcheckEvaluator<'_, P, FDomain, Composition> where F: Field + ExtensionField, diff --git a/crates/core/src/protocols/sumcheck/prove/zerocheck.rs b/crates/core/src/protocols/sumcheck/prove/zerocheck.rs index 912dd7a84..23d4a573e 100644 --- a/crates/core/src/protocols/sumcheck/prove/zerocheck.rs +++ b/crates/core/src/protocols/sumcheck/prove/zerocheck.rs @@ -23,7 +23,7 @@ use tracing::instrument; use crate::{ polynomial::{Error as PolynomialError, MultilinearComposite}, protocols::sumcheck::{ - common::{determine_switchovers, equal_n_vars_check, small_field_embedding_degree_check}, + common::{determine_switchovers, equal_n_vars_check}, prove::{ common::fold_partial_eq_ind, univariate::{ @@ -154,8 +154,6 @@ where validate_witness(&multilinears, &compositions)?; } - small_field_embedding_degree_check::<_, FBase, P, _>(&multilinears)?; - let switchover_rounds = determine_switchovers(&multilinears, switchover_fn); let zerocheck_challenges = zerocheck_challenges.to_vec(); @@ -188,16 +186,7 @@ where pub fn into_regular_zerocheck( self, ) -> Result< - ZerocheckProver< - 'a, - FDomain, - FBase, - P, - CompositionBase, - Composition, - MultilinearWitness<'m, P>, - Backend, - >, + ZerocheckProver<'a, FDomain, P, Composition, MultilinearWitness<'m, P>, Backend>, Error, > { if self.univariate_evals_output.is_some() { @@ -224,26 +213,29 @@ where validate_witness(&multilinears, &compositions)?; } + let compositions = self + .compositions + .into_iter() + .map(|(_, _, composition)| composition) + .collect::>(); + // Evaluate zerocheck partial indicator in variables 1..n_vars let start = self.n_vars.min(1); let partial_eq_ind_evals = self .backend .tensor_product_full_query(&self.zerocheck_challenges[start..])?; - let claimed_sums = vec![F::ZERO; self.compositions.len()]; + let claimed_sums = vec![F::ZERO; compositions.len()]; // This is a regular multilinear zerocheck constructor, split over two creation stages. ZerocheckProver::new( multilinears, self.switchover_rounds, - self.compositions - .into_iter() - .map(|(_, a, b)| (a, b)) - .collect(), + compositions, partial_eq_ind_evals, self.zerocheck_challenges, claimed_sums, self.domains, - RegularFirstRound::BaseField, + RegularFirstRound::SkipCube, self.backend, ) } @@ -387,23 +379,26 @@ where let zerocheck_challenges = self.zerocheck_challenges.clone(); + let compositions = self + .compositions + .into_iter() + .map(|(_, _, composition)| composition) + .collect(); + // This is also regular multilinear zerocheck constructor, but "jump started" in round // `skip_rounds` while using witness with a projected univariate round. // NB: first round evaluator has to be overriden due to issues proving // `P: RepackedExtension

` relation in the generic context, as well as the need // to use later round evaluator (as this _is_ a "later" round, albeit numbered at zero) - let regular_prover = ZerocheckProver::<_, FBase, _, _, _, _, _>::new( + let regular_prover = ZerocheckProver::new( partial_low_multilinears, switchover_rounds, - self.compositions - .into_iter() - .map(|(_, a, b)| (a, b)) - .collect(), + compositions, partial_eq_ind_evals, zerocheck_challenges, claimed_prime_sums, self.domains, - RegularFirstRound::LargeField, + RegularFirstRound::LaterRound, self.backend, )?; @@ -413,8 +408,8 @@ where #[derive(Debug, Clone, Copy)] enum RegularFirstRound { - BaseField, - LargeField, + SkipCube, + LaterRound, } /// A "regular" multilinear zerocheck prover. @@ -432,10 +427,9 @@ enum RegularFirstRound { /// /// [Gruen24]: #[derive(Debug)] -pub struct ZerocheckProver<'a, FDomain, FBase, P, CompositionBase, Composition, M, Backend> +pub struct ZerocheckProver<'a, FDomain, P, Composition, M, Backend> where FDomain: Field, - FBase: PackedField, P: PackedField, M: MultilinearPoly

+ Send + Sync, Backend: ComputationBackend, @@ -445,23 +439,17 @@ where eq_ind_eval: P::Scalar, partial_eq_ind_evals: Backend::Vec

, zerocheck_challenges: Vec, - compositions: Vec<(CompositionBase, Composition)>, + compositions: Vec, domains: Vec>, first_round: RegularFirstRound, - _f_base_marker: PhantomData, } -impl<'a, F, FDomain, FBase, P, CompositionBase, Composition, M, Backend> - ZerocheckProver<'a, FDomain, FBase, P, CompositionBase, Composition, M, Backend> +impl<'a, F, FDomain, P, Composition, M, Backend> + ZerocheckProver<'a, FDomain, P, Composition, M, Backend> where - F: Field + ExtensionField + ExtensionField, + F: Field + ExtensionField, FDomain: Field, - FBase: ExtensionField, - P: PackedFieldIndexable - + PackedExtension - + PackedExtension - + PackedExtension, - CompositionBase: CompositionPolyOS>, + P: PackedFieldIndexable + PackedExtension, Composition: CompositionPolyOS

, M: MultilinearPoly

+ Send + Sync, Backend: ComputationBackend, @@ -470,7 +458,7 @@ where fn new( multilinears: Vec, switchover_rounds: Vec, - compositions: Vec<(CompositionBase, Composition)>, + compositions: Vec, partial_eq_ind_evals: Backend::Vec

, zerocheck_challenges: Vec, claimed_prime_sums: Vec, @@ -517,7 +505,6 @@ where compositions, domains, first_round, - _f_base_marker: PhantomData, }) } @@ -547,17 +534,12 @@ where } } -impl SumcheckProver - for ZerocheckProver<'_, FDomain, FBase, P, CompositionBase, Composition, M, Backend> +impl SumcheckProver + for ZerocheckProver<'_, FDomain, P, Composition, M, Backend> where - F: Field + ExtensionField + ExtensionField, + F: Field + ExtensionField, FDomain: Field, - FBase: ExtensionField, - P: PackedFieldIndexable - + PackedExtension - + PackedExtension - + PackedExtension, - CompositionBase: CompositionPolyOS<

>::PackedSubfield>, + P: PackedFieldIndexable + PackedExtension, Composition: CompositionPolyOS

, M: MultilinearPoly

+ Send + Sync, Backend: ComputationBackend, @@ -580,33 +562,29 @@ where #[instrument(skip_all, name = "ZerocheckProver::execute", level = "debug")] fn execute(&mut self, batch_coeff: F) -> Result, Error> { let round = self.round(); - let base_field_first_round = - round == 0 && matches!(self.first_round, RegularFirstRound::BaseField); - let coeffs = if base_field_first_round { + let skip_cube_first_round = + round == 0 && matches!(self.first_round, RegularFirstRound::SkipCube); + let coeffs = if skip_cube_first_round { let evaluators = izip!(&self.compositions, &self.domains) - .map(|((composition_base, composition), interpolation_domain)| { - ZerocheckFirstRoundEvaluator { - composition_base, - composition, - interpolation_domain, - partial_eq_ind_evals: &self.partial_eq_ind_evals, - _f_base_marker: PhantomData::, - } + .map(|(composition, interpolation_domain)| ZerocheckFirstRoundEvaluator { + composition, + interpolation_domain, + partial_eq_ind_evals: &self.partial_eq_ind_evals, }) .collect::>(); - let evals = self.state.calculate_first_round_evals(&evaluators)?; + let evals = self.state.calculate_round_evals(&evaluators)?; self.state .calculate_round_coeffs_from_evals(&evaluators, batch_coeff, evals)? } else { let evaluators = izip!(&self.compositions, &self.domains) - .map(|((_, composition), interpolation_domain)| ZerocheckLaterRoundEvaluator { + .map(|(composition, interpolation_domain)| ZerocheckLaterRoundEvaluator { composition, interpolation_domain, partial_eq_ind_evals: &self.partial_eq_ind_evals, round_zerocheck_challenge: self.zerocheck_challenges[round], }) .collect::>(); - let evals = self.state.calculate_later_round_evals(&evaluators)?; + let evals = self.state.calculate_round_evals(&evaluators)?; self.state .calculate_round_coeffs_from_evals(&evaluators, batch_coeff, evals)? }; @@ -636,27 +614,21 @@ where } } -struct ZerocheckFirstRoundEvaluator<'a, P, FBase, FDomain, CompositionBase, Composition> +struct ZerocheckFirstRoundEvaluator<'a, P, FDomain, Composition> where P: PackedField, - FBase: Field, FDomain: Field, { - composition_base: &'a CompositionBase, composition: &'a Composition, interpolation_domain: &'a InterpolationDomain, partial_eq_ind_evals: &'a [P], - _f_base_marker: PhantomData, } -impl SumcheckEvaluator - for ZerocheckFirstRoundEvaluator<'_, P, FBase, FDomain, CompositionBase, Composition> +impl SumcheckEvaluator + for ZerocheckFirstRoundEvaluator<'_, P, FDomain, Composition> where - F: Field + ExtensionField + ExtensionField, - FBase: Field, - P: PackedField + PackedExtension, + P: PackedField>, FDomain: Field, - CompositionBase: CompositionPolyOS>, Composition: CompositionPolyOS

, { fn eval_point_indices(&self) -> Range { @@ -670,7 +642,7 @@ where &self, subcube_vars: usize, subcube_index: usize, - batch_query: &[&[PackedSubfield]], + batch_query: &[&[P]], ) -> P { // If the composition is a linear polynomial, then the composite multivariate polynomial // is multilinear. If the prover is honest, then this multilinear is identically zero, @@ -681,7 +653,7 @@ where let row_len = batch_query.first().map_or(0, |row| row.len()); stackalloc_with_default(row_len, |evals| { - self.composition_base + self.composition .batch_evaluate(batch_query, evals) .expect("correct by query construction invariant"); @@ -705,13 +677,12 @@ where } } -impl SumcheckInterpolator - for ZerocheckFirstRoundEvaluator<'_, P, FBase, FDomain, CompositionBase, Composition> +impl SumcheckInterpolator + for ZerocheckFirstRoundEvaluator<'_, P, FDomain, Composition> where - F: Field + ExtensionField + ExtensionField, - FBase: Field, - FDomain: Field, + F: Field + ExtensionField, P: PackedField, + FDomain: Field, { fn round_evals_to_coeffs( &self, @@ -741,11 +712,10 @@ where round_zerocheck_challenge: P::Scalar, } -impl SumcheckEvaluator +impl SumcheckEvaluator for ZerocheckLaterRoundEvaluator<'_, P, FDomain, Composition> where - F: Field + ExtensionField, - P: PackedField + PackedExtension + PackedExtension, + P: PackedField>, FDomain: Field, Composition: CompositionPolyOS

, { diff --git a/crates/hal/src/backend.rs b/crates/hal/src/backend.rs index 05bd2e87a..02a443e7e 100644 --- a/crates/hal/src/backend.rs +++ b/crates/hal/src/backend.rs @@ -42,28 +42,8 @@ pub trait ComputationBackend: Send + Sync + Debug { query: &[P::Scalar], ) -> Result, Error>; - /// Calculate the accumulated evaluations for the first round of zerocheck. - fn sumcheck_compute_first_round_evals( - &self, - n_vars: usize, - multilinears: &[SumcheckMultilinear], - evaluators: &[Evaluator], - evaluation_points: &[FDomain], - ) -> Result>, Error> - where - FDomain: Field, - FBase: ExtensionField, - F: Field + ExtensionField + ExtensionField, - P: PackedField - + PackedExtension - + PackedExtension - + PackedExtension, - M: MultilinearPoly

+ Send + Sync, - Evaluator: SumcheckEvaluator + Sync, - Composition: CompositionPolyOS

; - /// Calculate the accumulated evaluations for an arbitrary round of zerocheck. - fn sumcheck_compute_later_round_evals( + fn sumcheck_compute_round_evals( &self, n_vars: usize, tensor_query: Option>, @@ -73,12 +53,9 @@ pub trait ComputationBackend: Send + Sync + Debug { ) -> Result>, Error> where FDomain: Field, - F: Field + ExtensionField, - P: PackedField - + PackedExtension - + PackedExtension, + P: PackedField> + PackedExtension, M: MultilinearPoly

+ Send + Sync, - Evaluator: SumcheckEvaluator + Sync, + Evaluator: SumcheckEvaluator + Sync, Composition: CompositionPolyOS

; /// Partially evaluate the polynomial with assignment to the high-indexed variables. @@ -108,35 +85,7 @@ where T::tensor_product_full_query(self, query) } - fn sumcheck_compute_first_round_evals( - &self, - n_vars: usize, - multilinears: &[SumcheckMultilinear], - evaluators: &[Evaluator], - evaluation_points: &[FDomain], - ) -> Result>, Error> - where - FDomain: Field, - FBase: ExtensionField, - F: Field + ExtensionField + ExtensionField, - P: PackedField - + PackedExtension - + PackedExtension - + PackedExtension, - M: MultilinearPoly

+ Send + Sync, - Evaluator: SumcheckEvaluator + Sync, - Composition: CompositionPolyOS

, - { - T::sumcheck_compute_first_round_evals::<_, FBase, _, _, _, _, _>( - self, - n_vars, - multilinears, - evaluators, - evaluation_points, - ) - } - - fn sumcheck_compute_later_round_evals( + fn sumcheck_compute_round_evals( &self, n_vars: usize, tensor_query: Option>, @@ -146,15 +95,12 @@ where ) -> Result>, Error> where FDomain: Field, - F: Field + ExtensionField, - P: PackedField - + PackedExtension - + PackedExtension, + P: PackedField> + PackedExtension, M: MultilinearPoly

+ Send + Sync, - Evaluator: SumcheckEvaluator + Sync, + Evaluator: SumcheckEvaluator + Sync, Composition: CompositionPolyOS

, { - T::sumcheck_compute_later_round_evals( + T::sumcheck_compute_round_evals( self, n_vars, tensor_query, diff --git a/crates/hal/src/cpu.rs b/crates/hal/src/cpu.rs index 3d9eac8fe..b9f079f75 100644 --- a/crates/hal/src/cpu.rs +++ b/crates/hal/src/cpu.rs @@ -10,8 +10,8 @@ use binius_math::{ use tracing::instrument; use crate::{ - sumcheck_round_calculator::{calculate_first_round_evals, calculate_later_round_evals}, - ComputationBackend, Error, RoundEvals, SumcheckEvaluator, SumcheckMultilinear, + sumcheck_round_calculator::calculate_round_evals, ComputationBackend, Error, RoundEvals, + SumcheckEvaluator, SumcheckMultilinear, }; /// Implementation of ComputationBackend for the default Backend that uses the CPU for all computations. @@ -37,31 +37,7 @@ impl ComputationBackend for CpuBackend { Ok(eq_ind_partial_eval(query)) } - fn sumcheck_compute_first_round_evals( - &self, - n_vars: usize, - multilinears: &[SumcheckMultilinear], - evaluators: &[Evaluator], - evaluation_points: &[FDomain], - ) -> Result>, Error> - where - FDomain: Field, - FBase: ExtensionField, - F: Field + ExtensionField + ExtensionField, - P: PackedField + PackedExtension + PackedExtension, - M: MultilinearPoly

+ Send + Sync, - Evaluator: SumcheckEvaluator + Sync, - Composition: CompositionPolyOS

, - { - calculate_first_round_evals::<_, FBase, _, _, _, _, _>( - n_vars, - multilinears, - evaluators, - evaluation_points, - ) - } - - fn sumcheck_compute_later_round_evals( + fn sumcheck_compute_round_evals( &self, n_vars: usize, tensor_query: Option>, @@ -71,21 +47,12 @@ impl ComputationBackend for CpuBackend { ) -> Result>, Error> where FDomain: Field, - F: Field + ExtensionField, - P: PackedField - + PackedExtension - + PackedExtension, + P: PackedField> + PackedExtension, M: MultilinearPoly

+ Send + Sync, - Evaluator: SumcheckEvaluator + Sync, + Evaluator: SumcheckEvaluator + Sync, Composition: CompositionPolyOS

, { - calculate_later_round_evals( - n_vars, - tensor_query, - multilinears, - evaluators, - evaluation_points, - ) + calculate_round_evals(n_vars, tensor_query, multilinears, evaluators, evaluation_points) } #[instrument(skip_all, name = "CpuBackend::evaluate_partial_high")] diff --git a/crates/hal/src/sumcheck_evaluator.rs b/crates/hal/src/sumcheck_evaluator.rs index 982b9c6c7..31e0721a2 100644 --- a/crates/hal/src/sumcheck_evaluator.rs +++ b/crates/hal/src/sumcheck_evaluator.rs @@ -2,17 +2,13 @@ use std::ops::Range; -use binius_field::{ExtensionField, Field, PackedExtension, PackedField, PackedSubfield}; +use binius_field::{Field, PackedField}; /// Evaluations of a polynomial at a set of evaluation points. #[derive(Debug, Clone)] pub struct RoundEvals(pub Vec); -pub trait SumcheckEvaluator -where - FBase: Field, - P: PackedField> + PackedExtension, -{ +pub trait SumcheckEvaluator { /// The range of eval point indices over which composition evaluation and summation should happen. /// Returned range must equal the result of `n_round_evals()` in length. fn eval_point_indices(&self) -> Range; @@ -27,7 +23,7 @@ where &self, subcube_vars: usize, subcube_index: usize, - batch_query: &[&[PackedSubfield]], + batch_query: &[&[P]], ) -> P; /// Returns the composition evaluated by this object. diff --git a/crates/hal/src/sumcheck_round_calculator.rs b/crates/hal/src/sumcheck_round_calculator.rs index f71d1f846..8fe276662 100644 --- a/crates/hal/src/sumcheck_round_calculator.rs +++ b/crates/hal/src/sumcheck_round_calculator.rs @@ -4,19 +4,16 @@ //! //! This is one of the core computational tasks in the sumcheck proving algorithm. -use std::{iter, marker::PhantomData}; +use std::iter; -use binius_field::{ - recast_packed, ExtensionField, Field, PackedExtension, PackedField, PackedSubfield, - RepackedExtension, -}; +use binius_field::{ExtensionField, Field, PackedExtension, PackedField, PackedSubfield}; use binius_math::{ deinterleave, extrapolate_lines, CompositionPolyOS, MultilinearPoly, MultilinearQuery, MultilinearQueryRef, }; use binius_maybe_rayon::prelude::*; use bytemuck::zeroed_vec; -use itertools::izip; +use itertools::{izip, Itertools}; use stackalloc::stackalloc_with_iter; use crate::{Error, RoundEvals, SumcheckEvaluator, SumcheckMultilinear}; @@ -44,39 +41,11 @@ trait SumcheckMultilinearAccess { ) -> Result<(), Error>; } -/// Calculate the accumulated evaluations for the first sumcheck round. -pub(crate) fn calculate_first_round_evals( - n_vars: usize, - multilinears: &[SumcheckMultilinear], - evaluators: &[Evaluator], - evaluation_points: &[FDomain], -) -> Result>, Error> -where - FDomain: Field, - FBase: ExtensionField, - F: Field + ExtensionField + ExtensionField, - P: PackedField + PackedExtension + PackedExtension, - M: MultilinearPoly

+ Send + Sync, - Evaluator: SumcheckEvaluator + Sync, - Composition: CompositionPolyOS

, -{ - let accesses = multilinears - .iter() - .map(FirstRoundAccess::new) - .collect::>(); - calculate_round_evals::<_, FBase, _, _, _, _, _>( - n_vars, - &accesses, - evaluators, - evaluation_points, - ) -} - /// Calculate the accumulated evaluations for an arbitrary sumcheck round. /// /// See [`calculate_first_round_evals`] for an optimized version of this method /// that works over small fields in the first round. -pub(crate) fn calculate_later_round_evals( +pub(crate) fn calculate_round_evals( n_vars: usize, tensor_query: Option>, multilinears: &[SumcheckMultilinear], @@ -86,25 +55,26 @@ pub(crate) fn calculate_later_round_evals, - P: PackedField + PackedExtension + PackedExtension, + P: PackedField + PackedExtension, M: MultilinearPoly

+ Send + Sync, - Evaluator: SumcheckEvaluator + Sync, + Evaluator: SumcheckEvaluator + Sync, Composition: CompositionPolyOS

, { let empty_query = MultilinearQuery::with_capacity(0); - let query = tensor_query.unwrap_or_else(|| empty_query.to_ref()); + let tensor_query = tensor_query.unwrap_or_else(|| empty_query.to_ref()); - let accesses = multilinears + let later_rounds_accesses = multilinears .iter() - .map(|multilinear| LaterRoundAccess { + .map(|multilinear| LargeFieldAccess { multilinear, - tensor_query: query, + tensor_query, }) - .collect::>(); - calculate_round_evals::<_, F, _, _, _, _, _>(n_vars, &accesses, evaluators, evaluation_points) + .collect_vec(); + + calculate_round_evals_with_access(n_vars, &later_rounds_accesses, evaluators, evaluation_points) } -fn calculate_round_evals( +fn calculate_round_evals_with_access( n_vars: usize, multilinears: &[Access], evaluators: &[Evaluator], @@ -112,11 +82,10 @@ fn calculate_round_evals( ) -> Result>, Error> where FDomain: Field, - FBase: ExtensionField, - F: Field + ExtensionField + ExtensionField, - P: PackedField + PackedExtension + PackedExtension, - Evaluator: SumcheckEvaluator + Sync, - Access: SumcheckMultilinearAccess> + Sync, + F: ExtensionField, + P: PackedField + PackedExtension, + Evaluator: SumcheckEvaluator + Sync, + Access: SumcheckMultilinearAccess

+ Sync, Composition: CompositionPolyOS

, { let n_multilinears = multilinears.len(); @@ -182,9 +151,9 @@ where // `binius_math::univariate::extrapolate_line`, except that we do // not repeat the broadcast of the subfield element to a packed // subfield. - *eval_z = recast_packed::(extrapolate_lines( - recast_packed::(eval_0), - recast_packed::(eval_1), + *eval_z = P::cast_ext(extrapolate_lines( + P::cast_base(eval_0), + P::cast_base(eval_1), eval_point_broadcast, )); } @@ -316,57 +285,7 @@ impl ParFoldStates { } #[derive(Debug)] -struct FirstRoundAccess<'a, PBase, P, M> -where - P: PackedField, - M: MultilinearPoly

+ Send + Sync, -{ - multilinear: &'a SumcheckMultilinear, - _marker: PhantomData, -} - -impl<'a, PBase, P, M> FirstRoundAccess<'a, PBase, P, M> -where - P: PackedField, - M: MultilinearPoly

+ Send + Sync, -{ - const fn new(multilinear: &'a SumcheckMultilinear) -> Self { - Self { - multilinear, - _marker: PhantomData, - } - } -} - -impl SumcheckMultilinearAccess for FirstRoundAccess<'_, PBase, P, M> -where - PBase: PackedField, - P: RepackedExtension, - P::Scalar: ExtensionField, - M: MultilinearPoly

+ Send + Sync, -{ - fn subcube_evaluations( - &self, - subcube_vars: usize, - subcube_index: usize, - evals: &mut [PBase], - ) -> Result<(), Error> { - if let SumcheckMultilinear::Transparent { multilinear, .. } = self.multilinear { - let evals =

>::cast_exts_mut(evals); - Ok(multilinear.subcube_evals( - subcube_vars, - subcube_index, - >::LOG_DEGREE, - evals, - )?) - } else { - panic!("precondition: no folded multilinears in the first round"); - } - } -} - -#[derive(Debug)] -struct LaterRoundAccess<'a, P, M> +struct LargeFieldAccess<'a, P, M> where P: PackedField, M: MultilinearPoly

+ Send + Sync, @@ -375,7 +294,7 @@ where tensor_query: MultilinearQueryRef<'a, P>, } -impl SumcheckMultilinearAccess

for LaterRoundAccess<'_, P, M> +impl SumcheckMultilinearAccess

for LargeFieldAccess<'_, P, M> where P: PackedField, M: MultilinearPoly

+ Send + Sync, @@ -388,28 +307,28 @@ where ) -> Result<(), Error> { match self.multilinear { SumcheckMultilinear::Transparent { multilinear, .. } => { - // TODO: Stop using LaterRoundAccess for first round in RegularSumcheckProver and - // GPASumcheckProver, then remove this conditional. if self.tensor_query.n_vars() == 0 { - Ok(multilinear.subcube_evals(subcube_vars, subcube_index, 0, evals)?) + multilinear.subcube_evals(subcube_vars, subcube_index, 0, evals)? } else { - Ok(multilinear.subcube_inner_products( + multilinear.subcube_inner_products( self.tensor_query, subcube_vars, subcube_index, evals, - )?) + )? } } SumcheckMultilinear::Folded { large_field_folded_multilinear, - } => Ok(large_field_folded_multilinear.subcube_evals( + } => large_field_folded_multilinear.subcube_evals( subcube_vars, subcube_index, 0, evals, - )?), + )?, } + + Ok(()) } } From eb52576f53745c38ecdc15f9d1dd4bf826c03c9d Mon Sep 17 00:00:00 2001 From: Dmitry Gordon Date: Fri, 31 Jan 2025 21:50:23 +0000 Subject: [PATCH 03/50] [ring_switch] Optimize RingSwitchEqInd::multilinear_extension --- crates/core/src/ring_switch/eq_ind.rs | 92 +++++- crates/core/src/ring_switch/prove.rs | 12 +- crates/core/src/ring_switch/verify.rs | 12 +- crates/field/src/aes_field.rs | 1 - crates/field/src/binary_field.rs | 24 +- crates/field/src/byte_iteration.rs | 440 ++++++++++++++++++++++++++ crates/field/src/extension.rs | 16 +- crates/field/src/lib.rs | 1 + crates/field/src/packed_extension.rs | 2 +- crates/field/src/polyval.rs | 17 +- crates/math/src/fold.rs | 269 ++++++---------- 11 files changed, 657 insertions(+), 229 deletions(-) create mode 100644 crates/field/src/byte_iteration.rs diff --git a/crates/core/src/ring_switch/eq_ind.rs b/crates/core/src/ring_switch/eq_ind.rs index daffe0ba8..c6c448c0e 100644 --- a/crates/core/src/ring_switch/eq_ind.rs +++ b/crates/core/src/ring_switch/eq_ind.rs @@ -1,10 +1,14 @@ // Copyright 2024-2025 Irreducible Inc. -use std::{iter, marker::PhantomData, sync::Arc}; +use std::{any::TypeId, iter, marker::PhantomData, sync::Arc}; use binius_field::{ - util::inner_product_unchecked, ExtensionField, Field, PackedExtension, PackedField, - PackedFieldIndexable, TowerField, + byte_iteration::{ + can_iterate_bytes, create_partial_sums_lookup_tables, iterate_bytes, ByteIteratorCallback, + }, + util::inner_product_unchecked, + BinaryField1b, ExtensionField, Field, PackedExtension, PackedField, PackedFieldIndexable, + TowerField, }; use binius_math::{tensor_prod_eq_ind, MultilinearExtension}; use binius_maybe_rayon::prelude::*; @@ -17,6 +21,34 @@ use crate::{ tensor_algebra::TensorAlgebra, }; +/// Information about the row-batching coefficients. +#[derive(Debug)] +pub struct RowBatchCoeffs { + coeffs: Vec, + /// This is a lookup table for the partial sums of the coefficients + /// that is used to efficiently fold with 1-bit coefficients. + partial_sums_lookup_table: Vec, +} + +impl RowBatchCoeffs { + pub fn new(coeffs: Vec) -> Self { + let partial_sums_lookup_table = if coeffs.len() >= 8 { + create_partial_sums_lookup_tables(coeffs.as_slice()) + } else { + Vec::new() + }; + + Self { + coeffs, + partial_sums_lookup_table, + } + } + + pub fn coeffs(&self) -> &[F] { + &self.coeffs + } +} + /// The multilinear function $A$ from [DP24] Section 5. /// /// The function $A$ is $\ell':= \ell - \kappa$-variate and depends on the last $\ell'$ coordinates @@ -27,7 +59,7 @@ use crate::{ pub struct RingSwitchEqInd { /// $z_{\kappa}, \ldots, z_{\ell-1}$ z_vals: Arc<[F]>, - row_batch_coeffs: Arc<[F]>, + row_batch_coeffs: Arc>, mixing_coeff: F, _marker: PhantomData, } @@ -39,10 +71,10 @@ where { pub fn new( z_vals: Arc<[F]>, - row_batch_coeffs: Arc<[F]>, + row_batch_coeffs: Arc>, mixing_coeff: F, ) -> Result { - if row_batch_coeffs.len() < F::DEGREE { + if row_batch_coeffs.coeffs.len() < F::DEGREE { bail!(Error::InvalidArgs( "RingSwitchEqInd::new expects row_batch_coeffs length greater than or equal to \ the extension degree" @@ -67,16 +99,49 @@ where P::unpack_scalars_mut(&mut evals) .par_iter_mut() .for_each(|val| { - let vert = *val; - *val = inner_product_unchecked( - self.row_batch_coeffs.iter().copied(), - ExtensionField::::iter_bases(&vert), - ); + *val = inner_product_subfield(*val, &self.row_batch_coeffs); }); Ok(MultilinearExtension::from_values(evals)?) } } +#[inline(always)] +fn inner_product_subfield(value: F, row_batch_coeffs: &RowBatchCoeffs) -> F +where + FSub: Field, + F: ExtensionField, +{ + if TypeId::of::() == TypeId::of::() && can_iterate_bytes::() { + // Special case when we are folding with 1-bit coefficients. + // Use partial sums lookup table to speed up the computation. + + struct Callback<'a, F> { + partial_sums_lookup: &'a [F], + result: F, + } + + impl ByteIteratorCallback for Callback<'_, F> { + #[inline(always)] + fn call(&mut self, iter: impl Iterator) { + for (byte_index, byte) in iter.enumerate() { + self.result += self.partial_sums_lookup[(byte_index << 8) + byte as usize]; + } + } + } + + let mut callback = Callback { + partial_sums_lookup: &row_batch_coeffs.partial_sums_lookup_table, + result: F::ZERO, + }; + iterate_bytes(std::slice::from_ref(&value), &mut callback); + + callback.result + } else { + // fall back to the general case + inner_product_unchecked(row_batch_coeffs.coeffs.iter().copied(), F::iter_bases(&value)) + } +} + impl MultivariatePoly for RingSwitchEqInd where FSub: TowerField, @@ -108,7 +173,7 @@ where }, ); - let folded_eval = tensor_eval.fold_vertical(&self.row_batch_coeffs); + let folded_eval = tensor_eval.fold_vertical(&self.row_batch_coeffs.coeffs); Ok(folded_eval) } @@ -141,7 +206,8 @@ mod tests { let row_batch_coeffs = repeat_with(|| ::random(&mut rng)) .take(1 << kappa) - .collect::>(); + .collect::>(); + let row_batch_coeffs = Arc::new(RowBatchCoeffs::new(row_batch_coeffs)); let eval_point = repeat_with(|| ::random(&mut rng)) .take(n_vars) diff --git a/crates/core/src/ring_switch/prove.rs b/crates/core/src/ring_switch/prove.rs index ab2d6aee8..5a7423aad 100644 --- a/crates/core/src/ring_switch/prove.rs +++ b/crates/core/src/ring_switch/prove.rs @@ -11,6 +11,7 @@ use tracing::instrument; use super::{ common::{EvalClaimPrefixDesc, EvalClaimSystem, PIOPSumcheckClaimDesc}, + eq_ind::RowBatchCoeffs, error::Error, tower_tensor_algebra::TowerTensorAlgebra, }; @@ -75,11 +76,12 @@ where // Sample the row-batching randomness. let row_batch_challenges = transcript.sample_vec(system.max_claim_kappa()); - let row_batch_coeffs = - Arc::from(MultilinearQuery::::expand(&row_batch_challenges).into_expansion()); + let row_batch_coeffs = Arc::new(RowBatchCoeffs::new( + MultilinearQuery::::expand(&row_batch_challenges).into_expansion(), + )); let row_batched_evals = - compute_row_batched_sumcheck_evals(scaled_tensor_elems, &row_batch_coeffs); + compute_row_batched_sumcheck_evals(scaled_tensor_elems, row_batch_coeffs.coeffs()); transcript.message().write_scalar_slice(&row_batched_evals); // Create the reduced PIOP sumcheck witnesses. @@ -217,7 +219,7 @@ where fn make_ring_switch_eq_inds( sumcheck_claim_descs: &[PIOPSumcheckClaimDesc], suffix_descs: &[EvalClaimSuffixDesc], - row_batch_coeffs: Arc<[F]>, + row_batch_coeffs: Arc>, mixing_coeffs: &[F], ) -> Result>, Error> where @@ -238,7 +240,7 @@ where fn make_ring_switch_eq_ind( suffix_desc: &EvalClaimSuffixDesc>, - row_batch_coeffs: Arc<[FExt]>, + row_batch_coeffs: Arc>>, mixing_coeff: FExt, ) -> Result, Error> where diff --git a/crates/core/src/ring_switch/verify.rs b/crates/core/src/ring_switch/verify.rs index 330b0ffce..9c2cf2255 100644 --- a/crates/core/src/ring_switch/verify.rs +++ b/crates/core/src/ring_switch/verify.rs @@ -8,6 +8,7 @@ use binius_utils::checked_arithmetics::log2_ceil_usize; use bytes::Buf; use itertools::izip; +use super::eq_ind::RowBatchCoeffs; use crate::{ fiat_shamir::{CanSample, Challenger}, piop::PIOPSumcheckClaim, @@ -50,8 +51,9 @@ where // Sample the row-batching randomness. let row_batch_challenges = transcript.sample_vec(system.max_claim_kappa()); - let row_batch_coeffs = - Arc::from(MultilinearQuery::::expand(&row_batch_challenges).into_expansion()); + let row_batch_coeffs = Arc::new(RowBatchCoeffs::new( + MultilinearQuery::::expand(&row_batch_challenges).into_expansion(), + )); // For each original evaluation claim, receive the row-batched evaluation claim. let row_batched_evals = transcript @@ -66,7 +68,7 @@ where &system.eval_claim_to_prefix_desc_index, ); for (expected, tensor_elem) in iter::zip(mixed_row_batched_evals, tensor_elems) { - if tensor_elem.fold_vertical(&row_batch_coeffs) != expected { + if tensor_elem.fold_vertical(row_batch_coeffs.coeffs()) != expected { return Err(VerificationError::IncorrectRowBatchedSum.into()); } } @@ -173,7 +175,7 @@ fn accumulate_evaluations_by_prefixes( fn make_ring_switch_eq_inds( sumcheck_claim_descs: &[PIOPSumcheckClaimDesc], suffix_descs: &[EvalClaimSuffixDesc], - row_batch_coeffs: Arc<[F]>, + row_batch_coeffs: Arc>, mixing_coeffs: &[F], ) -> Result>>, Error> where @@ -190,7 +192,7 @@ where fn make_ring_switch_eq_ind( suffix_desc: &EvalClaimSuffixDesc>, - row_batch_coeffs: Arc<[FExt]>, + row_batch_coeffs: Arc>>, mixing_coeff: FExt, ) -> Result>>, Error> where diff --git a/crates/field/src/aes_field.rs b/crates/field/src/aes_field.rs index c69b665ab..ab650f4cd 100644 --- a/crates/field/src/aes_field.rs +++ b/crates/field/src/aes_field.rs @@ -2,7 +2,6 @@ use std::{ any::TypeId, - array, fmt::{Debug, Display, Formatter}, iter::{Product, Sum}, marker::PhantomData, diff --git a/crates/field/src/binary_field.rs b/crates/field/src/binary_field.rs index be6f8ea80..f38a43739 100644 --- a/crates/field/src/binary_field.rs +++ b/crates/field/src/binary_field.rs @@ -2,7 +2,6 @@ use std::{ any::TypeId, - array, fmt::{Debug, Display, Formatter}, iter::{Product, Sum}, ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}, @@ -564,7 +563,6 @@ macro_rules! impl_field_extension { } impl ExtensionField<$subfield_name> for $name { - type Iterator = <[$subfield_name; 1 << $log_degree] as IntoIterator>::IntoIter; const LOG_DEGREE: usize = $log_degree; #[inline] @@ -597,15 +595,21 @@ macro_rules! impl_field_extension { } #[inline] - fn iter_bases(&self) -> Self::Iterator { - use $crate::underlier::NumCast; + fn iter_bases(&self) -> impl Iterator { + use $crate::underlier::{WithUnderlier, IterationMethods, IterationStrategy}; + use binius_utils::iter::IterExtensions; + + IterationMethods::<<$subfield_name as WithUnderlier>::Underlier, Self::Underlier>::ref_iter(&self.0) + .map_skippable($subfield_name::from) + } + + #[inline] + fn into_iter_bases(self) -> impl Iterator { + use $crate::underlier::{WithUnderlier, IterationMethods, IterationStrategy}; + use binius_utils::iter::IterExtensions; - let base_elems = array::from_fn(|i| { - <$subfield_name>::new(<$subfield_typ>::num_cast_from( - (self.0 >> (i * $subfield_name::N_BITS)), - )) - }); - base_elems.into_iter() + IterationMethods::<<$subfield_name as WithUnderlier>::Underlier, Self::Underlier>::value_iter(self.0) + .map_skippable($subfield_name::from) } } }; diff --git a/crates/field/src/byte_iteration.rs b/crates/field/src/byte_iteration.rs new file mode 100644 index 000000000..e7acda990 --- /dev/null +++ b/crates/field/src/byte_iteration.rs @@ -0,0 +1,440 @@ +// Copyright 2023-2025 Irreducible Inc. + +use std::any::TypeId; + +use bytemuck::Pod; + +use crate::{ + packed::get_packed_slice, AESTowerField128b, AESTowerField16b, AESTowerField32b, + AESTowerField64b, AESTowerField8b, BinaryField128b, BinaryField128bPolyval, BinaryField16b, + BinaryField32b, BinaryField64b, BinaryField8b, ByteSlicedAES32x128b, ByteSlicedAES32x16b, + ByteSlicedAES32x32b, ByteSlicedAES32x64b, ByteSlicedAES32x8b, Field, + PackedAESBinaryField16x16b, PackedAESBinaryField16x32b, PackedAESBinaryField16x8b, + PackedAESBinaryField1x128b, PackedAESBinaryField1x16b, PackedAESBinaryField1x32b, + PackedAESBinaryField1x64b, PackedAESBinaryField1x8b, PackedAESBinaryField2x128b, + PackedAESBinaryField2x16b, PackedAESBinaryField2x32b, PackedAESBinaryField2x64b, + PackedAESBinaryField2x8b, PackedAESBinaryField32x16b, PackedAESBinaryField32x8b, + PackedAESBinaryField4x128b, PackedAESBinaryField4x16b, PackedAESBinaryField4x32b, + PackedAESBinaryField4x64b, PackedAESBinaryField4x8b, PackedAESBinaryField64x8b, + PackedAESBinaryField8x16b, PackedAESBinaryField8x64b, PackedAESBinaryField8x8b, + PackedBinaryField128x1b, PackedBinaryField128x2b, PackedBinaryField128x4b, + PackedBinaryField16x16b, PackedBinaryField16x1b, PackedBinaryField16x2b, + PackedBinaryField16x32b, PackedBinaryField16x4b, PackedBinaryField16x8b, + PackedBinaryField1x128b, PackedBinaryField1x16b, PackedBinaryField1x32b, + PackedBinaryField1x64b, PackedBinaryField1x8b, PackedBinaryField256x1b, + PackedBinaryField256x2b, PackedBinaryField2x128b, PackedBinaryField2x16b, + PackedBinaryField2x32b, PackedBinaryField2x4b, PackedBinaryField2x64b, PackedBinaryField2x8b, + PackedBinaryField32x16b, PackedBinaryField32x1b, PackedBinaryField32x2b, + PackedBinaryField32x4b, PackedBinaryField32x8b, PackedBinaryField4x128b, + PackedBinaryField4x16b, PackedBinaryField4x2b, PackedBinaryField4x32b, PackedBinaryField4x4b, + PackedBinaryField4x64b, PackedBinaryField4x8b, PackedBinaryField512x1b, PackedBinaryField64x1b, + PackedBinaryField64x2b, PackedBinaryField64x4b, PackedBinaryField64x8b, PackedBinaryField8x16b, + PackedBinaryField8x1b, PackedBinaryField8x2b, PackedBinaryField8x32b, PackedBinaryField8x4b, + PackedBinaryField8x64b, PackedBinaryField8x8b, PackedBinaryPolyval1x128b, + PackedBinaryPolyval2x128b, PackedBinaryPolyval4x128b, PackedField, +}; + +/// A marker trait that the slice of packed values can be iterated as a sequence of bytes. +/// The order of the iteration by BinaryField1b subfield elements and bits within iterated bytes must +/// be the same. +/// +/// # Safety +/// The implementor must ensure that the cast of the slice of packed values to the slice of bytes +/// is safe and preserves the order of the 1-bit elements. +#[allow(unused)] +unsafe trait SequentialBytes: Pod {} + +unsafe impl SequentialBytes for BinaryField8b {} +unsafe impl SequentialBytes for BinaryField16b {} +unsafe impl SequentialBytes for BinaryField32b {} +unsafe impl SequentialBytes for BinaryField64b {} +unsafe impl SequentialBytes for BinaryField128b {} + +unsafe impl SequentialBytes for PackedBinaryField8x1b {} +unsafe impl SequentialBytes for PackedBinaryField16x1b {} +unsafe impl SequentialBytes for PackedBinaryField32x1b {} +unsafe impl SequentialBytes for PackedBinaryField64x1b {} +unsafe impl SequentialBytes for PackedBinaryField128x1b {} +unsafe impl SequentialBytes for PackedBinaryField256x1b {} +unsafe impl SequentialBytes for PackedBinaryField512x1b {} + +unsafe impl SequentialBytes for PackedBinaryField4x2b {} +unsafe impl SequentialBytes for PackedBinaryField8x2b {} +unsafe impl SequentialBytes for PackedBinaryField16x2b {} +unsafe impl SequentialBytes for PackedBinaryField32x2b {} +unsafe impl SequentialBytes for PackedBinaryField64x2b {} +unsafe impl SequentialBytes for PackedBinaryField128x2b {} +unsafe impl SequentialBytes for PackedBinaryField256x2b {} + +unsafe impl SequentialBytes for PackedBinaryField2x4b {} +unsafe impl SequentialBytes for PackedBinaryField4x4b {} +unsafe impl SequentialBytes for PackedBinaryField8x4b {} +unsafe impl SequentialBytes for PackedBinaryField16x4b {} +unsafe impl SequentialBytes for PackedBinaryField32x4b {} +unsafe impl SequentialBytes for PackedBinaryField64x4b {} +unsafe impl SequentialBytes for PackedBinaryField128x4b {} + +unsafe impl SequentialBytes for PackedBinaryField1x8b {} +unsafe impl SequentialBytes for PackedBinaryField2x8b {} +unsafe impl SequentialBytes for PackedBinaryField4x8b {} +unsafe impl SequentialBytes for PackedBinaryField8x8b {} +unsafe impl SequentialBytes for PackedBinaryField16x8b {} +unsafe impl SequentialBytes for PackedBinaryField32x8b {} +unsafe impl SequentialBytes for PackedBinaryField64x8b {} + +unsafe impl SequentialBytes for PackedBinaryField1x16b {} +unsafe impl SequentialBytes for PackedBinaryField2x16b {} +unsafe impl SequentialBytes for PackedBinaryField4x16b {} +unsafe impl SequentialBytes for PackedBinaryField8x16b {} +unsafe impl SequentialBytes for PackedBinaryField16x16b {} +unsafe impl SequentialBytes for PackedBinaryField32x16b {} + +unsafe impl SequentialBytes for PackedBinaryField1x32b {} +unsafe impl SequentialBytes for PackedBinaryField2x32b {} +unsafe impl SequentialBytes for PackedBinaryField4x32b {} +unsafe impl SequentialBytes for PackedBinaryField8x32b {} +unsafe impl SequentialBytes for PackedBinaryField16x32b {} + +unsafe impl SequentialBytes for PackedBinaryField1x64b {} +unsafe impl SequentialBytes for PackedBinaryField2x64b {} +unsafe impl SequentialBytes for PackedBinaryField4x64b {} +unsafe impl SequentialBytes for PackedBinaryField8x64b {} + +unsafe impl SequentialBytes for PackedBinaryField1x128b {} +unsafe impl SequentialBytes for PackedBinaryField2x128b {} +unsafe impl SequentialBytes for PackedBinaryField4x128b {} + +unsafe impl SequentialBytes for AESTowerField8b {} +unsafe impl SequentialBytes for AESTowerField16b {} +unsafe impl SequentialBytes for AESTowerField32b {} +unsafe impl SequentialBytes for AESTowerField64b {} +unsafe impl SequentialBytes for AESTowerField128b {} + +unsafe impl SequentialBytes for PackedAESBinaryField1x8b {} +unsafe impl SequentialBytes for PackedAESBinaryField2x8b {} +unsafe impl SequentialBytes for PackedAESBinaryField4x8b {} +unsafe impl SequentialBytes for PackedAESBinaryField8x8b {} +unsafe impl SequentialBytes for PackedAESBinaryField16x8b {} +unsafe impl SequentialBytes for PackedAESBinaryField32x8b {} +unsafe impl SequentialBytes for PackedAESBinaryField64x8b {} + +unsafe impl SequentialBytes for PackedAESBinaryField1x16b {} +unsafe impl SequentialBytes for PackedAESBinaryField2x16b {} +unsafe impl SequentialBytes for PackedAESBinaryField4x16b {} +unsafe impl SequentialBytes for PackedAESBinaryField8x16b {} +unsafe impl SequentialBytes for PackedAESBinaryField16x16b {} +unsafe impl SequentialBytes for PackedAESBinaryField32x16b {} + +unsafe impl SequentialBytes for PackedAESBinaryField1x32b {} +unsafe impl SequentialBytes for PackedAESBinaryField2x32b {} +unsafe impl SequentialBytes for PackedAESBinaryField4x32b {} +unsafe impl SequentialBytes for PackedAESBinaryField16x32b {} + +unsafe impl SequentialBytes for PackedAESBinaryField1x64b {} +unsafe impl SequentialBytes for PackedAESBinaryField2x64b {} +unsafe impl SequentialBytes for PackedAESBinaryField4x64b {} +unsafe impl SequentialBytes for PackedAESBinaryField8x64b {} + +unsafe impl SequentialBytes for PackedAESBinaryField1x128b {} +unsafe impl SequentialBytes for PackedAESBinaryField2x128b {} +unsafe impl SequentialBytes for PackedAESBinaryField4x128b {} + +unsafe impl SequentialBytes for BinaryField128bPolyval {} + +unsafe impl SequentialBytes for PackedBinaryPolyval1x128b {} +unsafe impl SequentialBytes for PackedBinaryPolyval2x128b {} +unsafe impl SequentialBytes for PackedBinaryPolyval4x128b {} + +/// Returns true if T implements `SequentialBytes` trait. +/// Use a hack that exploits that array copying is optimized for the `Copy` types. +/// Unfortunately there is no more proper way to perform this check this in Rust at runtime. +#[inline(always)] +#[allow(clippy::redundant_clone)] // this is intentional in this method +pub fn is_sequential_bytes() -> bool { + struct X(bool, std::marker::PhantomData); + + impl Clone for X { + fn clone(&self) -> Self { + Self(false, std::marker::PhantomData) + } + } + + impl Copy for X {} + + let value = [X::(true, std::marker::PhantomData)]; + let cloned = value.clone(); + + cloned[0].0 +} + +/// Returns if we can iterate over bytes, each representing 8 1-bit values. +pub fn can_iterate_bytes() -> bool { + // Packed fields with sequential byte order + if is_sequential_bytes::

() { + return true; + } + + // Byte-sliced fields + // Note: add more byte sliced types here as soon as they are added + match TypeId::of::

() { + x if x == TypeId::of::() => true, + x if x == TypeId::of::() => true, + x if x == TypeId::of::() => true, + x if x == TypeId::of::() => true, + x if x == TypeId::of::() => true, + _ => false, + } +} + +/// Helper macro to generate the iteration over bytes for byte-sliced types. +macro_rules! iterate_byte_sliced { + ($packed_type:ty, $data:ident, $callback:ident) => { + assert_eq!(TypeId::of::<$packed_type>(), TypeId::of::

()); + + // Safety: the cast is safe because the type is checked by arm statement + let data = unsafe { + std::slice::from_raw_parts($data.as_ptr() as *const $packed_type, $data.len()) + }; + let iter = data.iter().flat_map(|value| { + (0..<$packed_type>::BYTES).map(move |i| unsafe { value.get_byte_unchecked(i) }) + }); + + $callback.call(iter); + }; +} + +/// Callback for byte iteration. +/// We can't return different types from the `iterate_bytes` and Fn traits don't support associated types +/// that's why we use a callback with a generic function. +pub trait ByteIteratorCallback { + fn call(&mut self, iter: impl Iterator); +} + +/// Iterate over bytes of a slice of the packed values. +/// The method panics if the packed field doesn't support byte iteration, so use `can_iterate_bytes` to check it. +#[inline(always)] +pub fn iterate_bytes(data: &[P], callback: &mut impl ByteIteratorCallback) { + if is_sequential_bytes::

() { + // Safety: `P` implements `SequentialBytes` trait, so the following cast is safe + // and preserves the order. + let bytes = unsafe { + std::slice::from_raw_parts(data.as_ptr() as *const u8, std::mem::size_of_val(data)) + }; + callback.call(bytes.iter().copied()); + } else { + // Note: add more byte sliced types here as soon as they are added + match TypeId::of::

() { + x if x == TypeId::of::() => { + iterate_byte_sliced!(ByteSlicedAES32x128b, data, callback); + } + x if x == TypeId::of::() => { + iterate_byte_sliced!(ByteSlicedAES32x64b, data, callback); + } + x if x == TypeId::of::() => { + iterate_byte_sliced!(ByteSlicedAES32x32b, data, callback); + } + x if x == TypeId::of::() => { + iterate_byte_sliced!(ByteSlicedAES32x16b, data, callback); + } + x if x == TypeId::of::() => { + iterate_byte_sliced!(ByteSlicedAES32x8b, data, callback); + } + _ => unreachable!("packed field doesn't support byte iteration"), + } + } +} + +/// Scalars collection abstraction. +/// This trait is used to abstract over different types of collections of field elements. +pub trait ScalarsCollection { + fn len(&self) -> usize; + fn get(&self, i: usize) -> T; + fn is_empty(&self) -> bool { + self.len() == 0 + } +} + +impl ScalarsCollection for &[F] { + #[inline(always)] + fn len(&self) -> usize { + <[F]>::len(self) + } + + #[inline(always)] + fn get(&self, i: usize) -> F { + self[i] + } +} + +pub struct PackedSlice<'a, P: PackedField> { + slice: &'a [P], + len: usize, +} + +impl<'a, P: PackedField> PackedSlice<'a, P> { + #[inline(always)] + pub const fn new(slice: &'a [P], len: usize) -> Self { + Self { slice, len } + } +} + +impl ScalarsCollection for PackedSlice<'_, P> { + #[inline(always)] + fn len(&self) -> usize { + self.len + } + + #[inline(always)] + fn get(&self, i: usize) -> P::Scalar { + get_packed_slice(self.slice, i) + } +} + +/// Create a lookup table for partial sums of 8 consequent elements with coefficients corresponding to bits in a byte. +/// The lookup table has the following structure: +/// [ +/// partial_sum_chunk_0_7_byte_0, partial_sum_chunk_0_7_byte_1, ..., partial_sum_chunk_0_7_byte_255, +/// partial_sum_chunk_8_15_byte_0, partial_sum_chunk_8_15_byte_1, ..., partial_sum_chunk_8_15_byte_255, +/// ... +/// ] +pub fn create_partial_sums_lookup_tables( + values: impl ScalarsCollection

, +) -> Vec

{ + let len = values.len(); + assert!(len % 8 == 0); + + let mut result = Vec::with_capacity(len * 256 / 8); + for chunk_i in 0..len / 8 { + let offset = chunk_i * 8; + for i in 0..256 { + let mut sum = P::zero(); + for j in 0..8 { + if i & (1 << j) != 0 { + sum += values.get(offset + j); + } + } + result.push(sum); + } + } + + result +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{PackedBinaryField1x1b, PackedBinaryField2x1b, PackedBinaryField4x1b}; + + #[test] + fn test_sequential_bits() { + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + + assert!(is_sequential_bytes::()); + + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + assert!(is_sequential_bytes::()); + + assert!(!is_sequential_bytes::()); + assert!(!is_sequential_bytes::()); + assert!(!is_sequential_bytes::()); + assert!(!is_sequential_bytes::()); + assert!(!is_sequential_bytes::()); + assert!(!is_sequential_bytes::()); + assert!(!is_sequential_bytes::()); + assert!(!is_sequential_bytes::()); + } +} diff --git a/crates/field/src/extension.rs b/crates/field/src/extension.rs index 45ca2e368..84a660aaa 100644 --- a/crates/field/src/extension.rs +++ b/crates/field/src/extension.rs @@ -18,9 +18,6 @@ pub trait ExtensionField: + SubAssign + MulAssign { - /// Iterator returned by `iter_bases`. - type Iterator: Iterator; - /// Base-2 logarithm of the extension degree. const LOG_DEGREE: usize; @@ -47,12 +44,13 @@ pub trait ExtensionField: fn from_bases_sparse(base_elems: &[F], log_stride: usize) -> Result; /// Iterator over base field elements. - fn iter_bases(&self) -> Self::Iterator; + fn iter_bases(&self) -> impl Iterator; + + /// Convert into an iterator over base field elements. + fn into_iter_bases(self) -> impl Iterator; } impl ExtensionField for F { - type Iterator = iter::Once; - const LOG_DEGREE: usize = 0; fn basis(i: usize) -> Result { @@ -74,7 +72,11 @@ impl ExtensionField for F { } } - fn iter_bases(&self) -> Self::Iterator { + fn iter_bases(&self) -> impl Iterator { iter::once(*self) } + + fn into_iter_bases(self) -> impl Iterator { + iter::once(self) + } } diff --git a/crates/field/src/lib.rs b/crates/field/src/lib.rs index cae1070da..76414c038 100644 --- a/crates/field/src/lib.rs +++ b/crates/field/src/lib.rs @@ -20,6 +20,7 @@ pub mod arithmetic_traits; pub mod as_packed_field; pub mod binary_field; mod binary_field_arithmetic; +pub mod byte_iteration; pub mod error; pub mod extension; pub mod field; diff --git a/crates/field/src/packed_extension.rs b/crates/field/src/packed_extension.rs index 3a2b11a2a..3ecd7b662 100644 --- a/crates/field/src/packed_extension.rs +++ b/crates/field/src/packed_extension.rs @@ -54,7 +54,7 @@ where /// PE: PackedField>, /// F: Field, /// { -/// packed.iter().flat_map(|ext| ext.iter_bases()) +/// packed.iter().flat_map(|ext| ext.into_iter_bases()) /// } /// /// fn cast_then_iter<'a, F, PE>(packed: &'a PE) -> impl Iterator + 'a diff --git a/crates/field/src/polyval.rs b/crates/field/src/polyval.rs index 989e12eec..7b535fbbf 100644 --- a/crates/field/src/polyval.rs +++ b/crates/field/src/polyval.rs @@ -4,12 +4,12 @@ use std::{ any::TypeId, - array, fmt::{self, Debug, Display, Formatter}, iter::{Product, Sum}, ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}, }; +use binius_utils::iter::IterExtensions; use bytemuck::{Pod, TransparentWrapper, Zeroable}; use rand::{Rng, RngCore}; use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption}; @@ -29,7 +29,7 @@ use crate::{ invert_or_zero_using_packed, multiple_using_packed, square_using_packed, }, linear_transformation::{FieldLinearTransformation, Transformation}, - underlier::UnderlierWithBitOps, + underlier::{IterationMethods, IterationStrategy, UnderlierWithBitOps, U1}, Field, }; @@ -414,7 +414,6 @@ impl Mul for BinaryField1b { } impl ExtensionField for BinaryField128bPolyval { - type Iterator = <[BinaryField1b; 128] as IntoIterator>::IntoIter; const LOG_DEGREE: usize = 7; #[inline] @@ -439,9 +438,15 @@ impl ExtensionField for BinaryField128bPolyval { } #[inline] - fn iter_bases(&self) -> Self::Iterator { - let base_elems = array::from_fn(|i| BinaryField1b::from((self.0 >> i) as u8)); - base_elems.into_iter() + fn iter_bases(&self) -> impl Iterator { + IterationMethods::::value_iter(self.0) + .map_skippable(BinaryField1b::from) + } + + #[inline] + fn into_iter_bases(self) -> impl Iterator { + IterationMethods::::value_iter(self.0) + .map_skippable(BinaryField1b::from) } } diff --git a/crates/math/src/fold.rs b/crates/math/src/fold.rs index a7002a50e..96621f3ab 100644 --- a/crates/math/src/fold.rs +++ b/crates/math/src/fold.rs @@ -4,21 +4,22 @@ use core::slice; use std::{any::TypeId, cmp::min, mem::MaybeUninit}; use binius_field::{ - arch::{byte_sliced::ByteSlicedAES32x128b, ArchOptimal, OptimalUnderlier}, + arch::{ArchOptimal, OptimalUnderlier}, + byte_iteration::{ + can_iterate_bytes, create_partial_sums_lookup_tables, is_sequential_bytes, iterate_bytes, + ByteIteratorCallback, PackedSlice, + }, packed::{get_packed_slice, set_packed_slice_unchecked}, underlier::{UnderlierWithBitOps, WithUnderlier}, - AESTowerField128b, BinaryField128b, BinaryField128bPolyval, BinaryField1b, ByteSlicedAES32x16b, - ByteSlicedAES32x32b, ByteSlicedAES32x64b, ByteSlicedAES32x8b, ExtensionField, Field, - PackedBinaryField128x1b, PackedBinaryField16x1b, PackedBinaryField256x1b, - PackedBinaryField32x1b, PackedBinaryField512x1b, PackedBinaryField64x1b, PackedBinaryField8x1b, - PackedField, + AESTowerField128b, BinaryField128b, BinaryField128bPolyval, BinaryField1b, ExtensionField, + Field, PackedField, }; use binius_maybe_rayon::{ iter::{IndexedParallelIterator, ParallelIterator}, slice::ParallelSliceMut, }; use binius_utils::bail; -use bytemuck::{fill_zeroes, Pod}; +use bytemuck::fill_zeroes; use itertools::max; use lazy_static::lazy_static; use stackalloc::helpers::slice_assume_init_mut; @@ -129,115 +130,6 @@ where Ok(()) } -/// A marker trait that the slice of packed values can be iterated as a sequence of bytes. -/// The order of the iteration by BinaryField1b subfield elements and bits within iterated bytes must -/// be the same. -/// -/// # Safety -/// The implementor must ensure that the cast of the slice of packed values to the slice of bytes -/// is safe and preserves the order of the 1-bit elements. -#[allow(unused)] -unsafe trait SequentialBytes: Pod {} - -unsafe impl SequentialBytes for PackedBinaryField8x1b {} -unsafe impl SequentialBytes for PackedBinaryField16x1b {} -unsafe impl SequentialBytes for PackedBinaryField32x1b {} -unsafe impl SequentialBytes for PackedBinaryField64x1b {} -unsafe impl SequentialBytes for PackedBinaryField128x1b {} -unsafe impl SequentialBytes for PackedBinaryField256x1b {} -unsafe impl SequentialBytes for PackedBinaryField512x1b {} - -/// Returns true if T implements `SequentialBytes` trait. -/// Use a hack that exploits that array copying is optimized for the `Copy` types. -/// Unfortunately there is no more proper way to perform this check this in Rust at runtime. -#[allow(clippy::redundant_clone)] -fn is_sequential_bytes() -> bool { - struct X(bool, std::marker::PhantomData); - - impl Clone for X { - fn clone(&self) -> Self { - Self(false, std::marker::PhantomData) - } - } - - impl Copy for X {} - - let value = [X::(true, std::marker::PhantomData)]; - let cloned = value.clone(); - - cloned[0].0 -} - -/// Returns if we can iterate over bytes, each representing 8 1-bit values. -fn can_iterate_bytes() -> bool { - // Packed fields with sequential byte order - if is_sequential_bytes::

() { - return true; - } - - // Byte-sliced fields - // Note: add more byte sliced types here as soon as they are added - match TypeId::of::

() { - x if x == TypeId::of::() => true, - x if x == TypeId::of::() => true, - x if x == TypeId::of::() => true, - x if x == TypeId::of::() => true, - x if x == TypeId::of::() => true, - _ => false, - } -} - -/// Helper macro to generate the iteration over bytes for byte-sliced types. -macro_rules! iterate_byte_sliced { - ($packed_type:ty, $data:ident, $f:ident) => { - assert_eq!(TypeId::of::<$packed_type>(), TypeId::of::

()); - - // Safety: the cast is safe because the type is checked by arm statement - let data = - unsafe { slice::from_raw_parts($data.as_ptr() as *const $packed_type, $data.len()) }; - for value in data.iter() { - for i in 0..<$packed_type>::BYTES { - // Safety: j is less than `ByteSlicedAES32x128b::BYTES` - $f(unsafe { value.get_byte_unchecked(i) }); - } - } - }; -} - -/// Iterate over bytes of a slice of the packed values. -fn iterate_bytes(data: &[P], mut f: impl FnMut(u8)) { - if is_sequential_bytes::

() { - // Safety: `P` implements `SequentialBytes` trait, so the following cast is safe - // and preserves the order. - let bytes = unsafe { - std::slice::from_raw_parts(data.as_ptr() as *const u8, std::mem::size_of_val(data)) - }; - for byte in bytes { - f(*byte); - } - } else { - // Note: add more byte sliced types here as soon as they are added - match TypeId::of::

() { - x if x == TypeId::of::() => { - iterate_byte_sliced!(ByteSlicedAES32x128b, data, f); - } - x if x == TypeId::of::() => { - iterate_byte_sliced!(ByteSlicedAES32x64b, data, f); - } - x if x == TypeId::of::() => { - iterate_byte_sliced!(ByteSlicedAES32x32b, data, f); - } - x if x == TypeId::of::() => { - iterate_byte_sliced!(ByteSlicedAES32x16b, data, f); - } - x if x == TypeId::of::() => { - iterate_byte_sliced!(ByteSlicedAES32x8b, data, f); - } - _ => unreachable!("packed field doesn't support byte iteration"), - } - } -} - /// Optimized version for 1-bit values with query size 0-2 fn fold_right_1bit_evals_small_query( evals: &[P], @@ -257,7 +149,7 @@ where (P::LOG_WIDTH + LOG_QUERY_SIZE).saturating_sub(PE::LOG_WIDTH), PE::LOG_WIDTH, ]) - .unwrap(); + .expect("sequence of max values is not empty"); if out.len() % chunk_size != 0 { return false; } @@ -287,24 +179,43 @@ where let input_end = (((index + 1) * chunk_size) << (LOG_QUERY_SIZE + PE::LOG_WIDTH)) / P::WIDTH; - let mut current_index = 0; - iterate_bytes(&evals[input_offset..input_end], |byte| { - let mask = (1 << (1 << LOG_QUERY_SIZE)) - 1; - let values_in_byte = 1 << (3 - LOG_QUERY_SIZE); - for k in 0..values_in_byte { - let index = (byte >> (k * (1 << LOG_QUERY_SIZE))) & mask; - // Safety: `i` is less than `chunk_size` - unsafe { - set_packed_slice_unchecked( - chunk, - current_index + k, - cached_table[index as usize], - ); + struct Callback<'a, PE: PackedField, const LOG_QUERY_SIZE: usize> { + chunk: &'a mut [PE], + cached_table: &'a [PE::Scalar], + } + + impl ByteIteratorCallback + for Callback<'_, PE, LOG_QUERY_SIZE> + { + #[inline(always)] + fn call(&mut self, iterator: impl Iterator) { + let mask = (1 << (1 << LOG_QUERY_SIZE)) - 1; + let values_in_byte = 1 << (3 - LOG_QUERY_SIZE); + let mut current_index = 0; + for byte in iterator { + for k in 0..values_in_byte { + let index = (byte >> (k * (1 << LOG_QUERY_SIZE))) & mask; + // Safety: `i` is less than `chunk_size` + unsafe { + set_packed_slice_unchecked( + self.chunk, + current_index + k, + self.cached_table[index as usize], + ); + } + } + + current_index += values_in_byte; } } + } + + let mut callback = Callback::<'_, PE, LOG_QUERY_SIZE> { + chunk, + cached_table: &cached_table, + }; - current_index += values_in_byte; - }); + iterate_bytes(&evals[input_offset..input_end], &mut callback); }); true @@ -329,29 +240,13 @@ where (P::LOG_WIDTH + LOG_QUERY_SIZE).saturating_sub(PE::LOG_WIDTH), PE::LOG_WIDTH, ]) - .unwrap(); + .expect("sequence of max values is not empty"); if out.len() % chunk_size != 0 { return false; } - let log_tables_count = LOG_QUERY_SIZE - 3; - let tables_count = 1 << log_tables_count; - let cached_tables = (0..tables_count) - .map(|i| { - (0..256) - .map(|j| { - let mut result = PE::Scalar::ZERO; - for k in 0..8 { - if j >> k & 1 == 1 { - result += get_packed_slice(query, (i << 3) | k); - } - } - result - }) - .collect::>() - }) - .collect::>(); - + let cached_tables = + create_partial_sums_lookup_tables(PackedSlice::new(query, 1 << LOG_QUERY_SIZE)); out.par_chunks_mut(chunk_size) .enumerate() .for_each(|(index, chunk)| { @@ -360,23 +255,49 @@ where let input_end = (((index + 1) * chunk_size) << (LOG_QUERY_SIZE + PE::LOG_WIDTH)) / P::WIDTH; - let mut current_value = PE::Scalar::ZERO; - let mut current_table = 0; - let mut current_index = 0; - iterate_bytes(&evals[input_offset..input_end], |byte| { - current_value += cached_tables[current_table][byte as usize]; - current_table += 1; + struct Callback<'a, PE: PackedField, const LOG_QUERY_SIZE: usize> { + chunk: &'a mut [PE], + cached_tables: &'a [PE::Scalar], + current_value: PE::Scalar, + current_table: usize, + } - if current_table == tables_count { - // Safety: `i` is less than `chunk_size` - unsafe { - set_packed_slice_unchecked(chunk, current_index, current_value); + impl ByteIteratorCallback + for Callback<'_, PE, LOG_QUERY_SIZE> + { + #[inline(always)] + fn call(&mut self, iterator: impl Iterator) { + let log_tables_count = LOG_QUERY_SIZE - 3; + let tables_count = 1 << log_tables_count; + for (current_index, byte) in iterator.enumerate() { + self.current_value += + self.cached_tables[(self.current_table << 8) + byte as usize]; + self.current_table += 1; + + if self.current_table == tables_count { + // Safety: `i` is less than `chunk_size` + unsafe { + set_packed_slice_unchecked( + self.chunk, + current_index, + self.current_value, + ); + } + self.current_table = 0; + self.current_value = PE::Scalar::ZERO; + } } - current_table = 0; - current_index += 1; - current_value = PE::Scalar::ZERO; } - }); + } + + let mut callback = Callback::<'_, _, LOG_QUERY_SIZE> { + chunk, + cached_tables: &cached_tables, + current_value: PE::Scalar::ZERO, + current_table: 0, + }; + + iterate_bytes(&evals[input_offset..input_end], &mut callback); }); true @@ -685,27 +606,13 @@ mod tests { use std::iter::repeat_with; use binius_field::{ - packed::set_packed_slice, PackedBinaryField16x32b, PackedBinaryField16x8b, - PackedBinaryField4x1b, PackedBinaryField512x1b, + packed::set_packed_slice, PackedBinaryField128x1b, PackedBinaryField16x32b, + PackedBinaryField16x8b, PackedBinaryField512x1b, }; use rand::{rngs::StdRng, SeedableRng}; use super::*; - #[test] - fn test_sequential_bits() { - assert!(is_sequential_bytes::()); - assert!(is_sequential_bytes::()); - assert!(is_sequential_bytes::()); - assert!(is_sequential_bytes::()); - assert!(is_sequential_bytes::()); - assert!(is_sequential_bytes::()); - assert!(is_sequential_bytes::()); - - assert!(!is_sequential_bytes::()); - assert!(!is_sequential_bytes::()); - } - fn fold_right_reference( evals: &[P], log_evals_size: usize, From a12f2a4fb6429c86b9fc4b1457f7aa08975cb142 Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Tue, 4 Feb 2025 10:05:35 +0000 Subject: [PATCH 04/50] [clippy]: avoid needless pass by value --- crates/circuits/src/keccakf.rs | 2 +- crates/circuits/src/lib.rs | 2 +- crates/core/src/constraint_system/prove.rs | 4 ++-- crates/core/src/constraint_system/verify.rs | 6 +++--- crates/core/src/piop/prove.rs | 8 ++++---- crates/core/src/piop/verify.rs | 2 +- crates/core/src/protocols/evalcheck/error.rs | 2 +- crates/core/src/protocols/fri/common.rs | 6 +++--- crates/core/src/protocols/fri/tests.rs | 4 ++-- crates/core/src/protocols/gkr_gpa/verify.rs | 4 ++-- .../src/protocols/sumcheck/prove/zerocheck.rs | 10 +++++----- crates/core/src/protocols/test_utils.rs | 2 +- crates/core/src/reed_solomon/reed_solomon.rs | 6 +++--- crates/core/src/ring_switch/common.rs | 2 +- crates/core/src/ring_switch/tests.rs | 4 ++-- crates/core/src/ring_switch/verify.rs | 4 ++-- .../src/arch/portable/packed_arithmetic.rs | 4 ++-- crates/field/src/packed.rs | 4 ++-- crates/ntt/src/dynamic_dispatch.rs | 18 +++++++++--------- crates/ntt/src/single_threaded.rs | 6 +++--- examples/keccakf_circuit.rs | 2 +- 21 files changed, 51 insertions(+), 51 deletions(-) diff --git a/crates/circuits/src/keccakf.rs b/crates/circuits/src/keccakf.rs index a73c5cea8..638a89e4f 100644 --- a/crates/circuits/src/keccakf.rs +++ b/crates/circuits/src/keccakf.rs @@ -27,7 +27,7 @@ pub struct KeccakfOracles { pub fn keccakf( builder: &mut ConstraintSystemBuilder, - input_witness: Option>, + input_witness: &Option>, log_size: usize, ) -> Result where diff --git a/crates/circuits/src/lib.rs b/crates/circuits/src/lib.rs index c76857c67..4be63828e 100644 --- a/crates/circuits/src/lib.rs +++ b/crates/circuits/src/lib.rs @@ -377,7 +377,7 @@ mod tests { let mut rng = StdRng::seed_from_u64(0); let input_states = vec![KeccakfState(rng.gen())]; - let _state_out = keccakf(&mut builder, Some(input_states), log_size); + let _state_out = keccakf(&mut builder, &Some(input_states), log_size); let witness = builder.take_witness().unwrap(); diff --git a/crates/core/src/constraint_system/prove.rs b/crates/core/src/constraint_system/prove.rs index d405405e1..ef167b671 100644 --- a/crates/core/src/constraint_system/prove.rs +++ b/crates/core/src/constraint_system/prove.rs @@ -241,7 +241,7 @@ where let (flush_oracle_ids, flush_selectors, flush_final_layer_claims) = reorder_for_flushing_by_n_vars( &oracles, - flush_oracle_ids, + &flush_oracle_ids, flush_selectors, flush_final_layer_claims, ); @@ -423,7 +423,7 @@ where let system = ring_switch::EvalClaimSystem::new( &oracles, &commit_meta, - oracle_to_commit_index, + &oracle_to_commit_index, &eval_claims, )?; diff --git a/crates/core/src/constraint_system/verify.rs b/crates/core/src/constraint_system/verify.rs index 467efd938..50c528b72 100644 --- a/crates/core/src/constraint_system/verify.rs +++ b/crates/core/src/constraint_system/verify.rs @@ -155,7 +155,7 @@ where let (flush_oracle_ids, flush_selectors, flush_final_layer_claims) = reorder_for_flushing_by_n_vars( &oracles, - flush_oracle_ids, + &flush_oracle_ids, flush_selectors, flush_final_layer_claims, ); @@ -284,7 +284,7 @@ where let system = ring_switch::EvalClaimSystem::new( &oracles, &commit_meta, - oracle_to_commit_index, + &oracle_to_commit_index, &eval_claims, )?; @@ -698,7 +698,7 @@ pub fn get_flush_dedup_sumcheck_claims( pub fn reorder_for_flushing_by_n_vars( oracles: &MultilinearOracleSet, - flush_oracle_ids: Vec, + flush_oracle_ids: &[OracleId], flush_selectors: Vec, flush_final_layer_claims: Vec>, ) -> (Vec, Vec, Vec>) { diff --git a/crates/core/src/piop/prove.rs b/crates/core/src/piop/prove.rs index 9faed21c9..25ea62296 100644 --- a/crates/core/src/piop/prove.rs +++ b/crates/core/src/piop/prove.rs @@ -133,7 +133,7 @@ where let rs_code = ReedSolomonCode::new( fri_params.rs_code().log_dim(), fri_params.rs_code().log_inv_rate(), - NTTOptions { + &NTTOptions { precompute_twiddles: true, thread_settings: ThreadingSettings::MultithreadedDefault, }, @@ -234,7 +234,7 @@ where merkle_prover, sumcheck_provers, codeword, - committed, + &committed, transcript, )?; @@ -247,7 +247,7 @@ fn prove_interleaved_fri_sumcheck>, codeword: &[P], - committed: MTProver::Committed, + committed: &MTProver::Committed, transcript: &mut ProverTranscript, ) -> Result<(), Error> where @@ -259,7 +259,7 @@ where Challenger_: Challenger, { let mut fri_prover = - FRIFolder::new(fri_params, merkle_prover, P::unpack_scalars(codeword), &committed)?; + FRIFolder::new(fri_params, merkle_prover, P::unpack_scalars(codeword), committed)?; let mut sumcheck_batch_prover = SumcheckBatchProver::new(sumcheck_provers, transcript)?; diff --git a/crates/core/src/piop/verify.rs b/crates/core/src/piop/verify.rs index 96858f7fd..c7618538d 100644 --- a/crates/core/src/piop/verify.rs +++ b/crates/core/src/piop/verify.rs @@ -137,7 +137,7 @@ where let log_batch_size = fold_arities.first().copied().unwrap_or(0); let log_dim = commit_meta.total_vars - log_batch_size; - let rs_code = ReedSolomonCode::new(log_dim, log_inv_rate, NTTOptions::default())?; + let rs_code = ReedSolomonCode::new(log_dim, log_inv_rate, &NTTOptions::default())?; let n_test_queries = fri::calculate_n_test_queries::(security_bits, &rs_code)?; let fri_params = FRIParams::new(rs_code, log_batch_size, fold_arities, n_test_queries)?; Ok(fri_params) diff --git a/crates/core/src/protocols/evalcheck/error.rs b/crates/core/src/protocols/evalcheck/error.rs index 54426fe8f..d5bf9447b 100644 --- a/crates/core/src/protocols/evalcheck/error.rs +++ b/crates/core/src/protocols/evalcheck/error.rs @@ -45,7 +45,7 @@ pub enum VerificationError { impl VerificationError { pub fn incorrect_composite_poly_evaluation( - oracle: CompositePolyOracle, + oracle: &CompositePolyOracle, ) -> Self { let names = oracle .inner_polys() diff --git a/crates/core/src/protocols/fri/common.rs b/crates/core/src/protocols/fri/common.rs index e8e7318fd..9d74660b6 100644 --- a/crates/core/src/protocols/fri/common.rs +++ b/crates/core/src/protocols/fri/common.rs @@ -343,13 +343,13 @@ mod tests { #[test] fn test_calculate_n_test_queries() { let security_bits = 96; - let rs_code = ReedSolomonCode::new(28, 1, NTTOptions::default()).unwrap(); + let rs_code = ReedSolomonCode::new(28, 1, &NTTOptions::default()).unwrap(); let n_test_queries = calculate_n_test_queries::(security_bits, &rs_code) .unwrap(); assert_eq!(n_test_queries, 232); - let rs_code = ReedSolomonCode::new(28, 2, NTTOptions::default()).unwrap(); + let rs_code = ReedSolomonCode::new(28, 2, &NTTOptions::default()).unwrap(); let n_test_queries = calculate_n_test_queries::(security_bits, &rs_code) .unwrap(); @@ -359,7 +359,7 @@ mod tests { #[test] fn test_calculate_n_test_queries_unsatisfiable() { let security_bits = 128; - let rs_code = ReedSolomonCode::new(28, 1, NTTOptions::default()).unwrap(); + let rs_code = ReedSolomonCode::new(28, 1, &NTTOptions::default()).unwrap(); assert_matches!( calculate_n_test_queries::(security_bits, &rs_code), Err(Error::ParameterError) diff --git a/crates/core/src/protocols/fri/tests.rs b/crates/core/src/protocols/fri/tests.rs index df4ecf22a..1798b3c14 100644 --- a/crates/core/src/protocols/fri/tests.rs +++ b/crates/core/src/protocols/fri/tests.rs @@ -46,14 +46,14 @@ fn test_commit_prove_verify_success( let committed_rs_code_packed = ReedSolomonCode::>::new( log_dimension, log_inv_rate, - NTTOptions::default(), + &NTTOptions::default(), ) .unwrap(); let merkle_prover = BinaryMerkleTreeProver::<_, Groestl256, _>::new(Groestl256ByteCompression); let committed_rs_code = - ReedSolomonCode::::new(log_dimension, log_inv_rate, NTTOptions::default()).unwrap(); + ReedSolomonCode::::new(log_dimension, log_inv_rate, &NTTOptions::default()).unwrap(); let n_test_queries = 3; let params = diff --git a/crates/core/src/protocols/gkr_gpa/verify.rs b/crates/core/src/protocols/gkr_gpa/verify.rs index 056f4cb76..a9d96366d 100644 --- a/crates/core/src/protocols/gkr_gpa/verify.rs +++ b/crates/core/src/protocols/gkr_gpa/verify.rs @@ -54,7 +54,7 @@ where &mut reverse_sorted_evalcheck_claims, ); - layer_claims = reduce_layer_claim_batch(layer_claims, transcript)?; + layer_claims = reduce_layer_claim_batch(&layer_claims, transcript)?; } process_finished_claims( n_claims, @@ -102,7 +102,7 @@ fn process_finished_claims( /// * `proof` - The batch layer proof that reduces the kth layer claims of the product circuits to the (k+1)th /// * `transcript` - The verifier transcript fn reduce_layer_claim_batch( - claims: Vec>, + claims: &[LayerClaim], transcript: &mut VerifierTranscript, ) -> Result>, Error> where diff --git a/crates/core/src/protocols/sumcheck/prove/zerocheck.rs b/crates/core/src/protocols/sumcheck/prove/zerocheck.rs index 23d4a573e..90df740ea 100644 --- a/crates/core/src/protocols/sumcheck/prove/zerocheck.rs +++ b/crates/core/src/protocols/sumcheck/prove/zerocheck.rs @@ -229,7 +229,7 @@ where // This is a regular multilinear zerocheck constructor, split over two creation stages. ZerocheckProver::new( multilinears, - self.switchover_rounds, + &self.switchover_rounds, compositions, partial_eq_ind_evals, self.zerocheck_challenges, @@ -375,7 +375,7 @@ where .switchover_rounds .into_iter() .map(|switchover_round| switchover_round.saturating_sub(skip_rounds)) - .collect(); + .collect::>(); let zerocheck_challenges = self.zerocheck_challenges.clone(); @@ -392,7 +392,7 @@ where // to use later round evaluator (as this _is_ a "later" round, albeit numbered at zero) let regular_prover = ZerocheckProver::new( partial_low_multilinears, - switchover_rounds, + &switchover_rounds, compositions, partial_eq_ind_evals, zerocheck_challenges, @@ -457,7 +457,7 @@ where #[allow(clippy::too_many_arguments)] fn new( multilinears: Vec, - switchover_rounds: Vec, + switchover_rounds: &[usize], compositions: Vec, partial_eq_ind_evals: Backend::Vec

, zerocheck_challenges: Vec, @@ -477,7 +477,7 @@ where let state = ProverState::new_with_switchover_rounds( multilinears, - &switchover_rounds, + switchover_rounds, claimed_prime_sums, evaluation_points, backend, diff --git a/crates/core/src/protocols/test_utils.rs b/crates/core/src/protocols/test_utils.rs index 1ad370ad1..892a5c276 100644 --- a/crates/core/src/protocols/test_utils.rs +++ b/crates/core/src/protocols/test_utils.rs @@ -122,7 +122,7 @@ where } pub fn transform_poly( - multilin: MultilinearExtension, + multilin: &MultilinearExtension, ) -> Result, PolynomialError> where F: Field, diff --git a/crates/core/src/reed_solomon/reed_solomon.rs b/crates/core/src/reed_solomon/reed_solomon.rs index 771a306f1..292744223 100644 --- a/crates/core/src/reed_solomon/reed_solomon.rs +++ b/crates/core/src/reed_solomon/reed_solomon.rs @@ -40,7 +40,7 @@ where pub fn new( log_dimension: usize, log_inv_rate: usize, - ntt_options: NTTOptions, + ntt_options: &NTTOptions, ) -> Result { // Since we split work between log_inv_rate threads, we need to decrease the number of threads per each NTT transformation. let ntt_log_threads = ntt_options @@ -49,11 +49,11 @@ where .saturating_sub(log_inv_rate); let ntt = DynamicDispatchNTT::new( log_dimension + log_inv_rate, - NTTOptions { + &NTTOptions { thread_settings: ThreadingSettings::ExplicitThreadsCount { log_threads: ntt_log_threads, }, - ..ntt_options + precompute_twiddles: ntt_options.precompute_twiddles, }, )?; diff --git a/crates/core/src/ring_switch/common.rs b/crates/core/src/ring_switch/common.rs index e86876e49..fb76fe80d 100644 --- a/crates/core/src/ring_switch/common.rs +++ b/crates/core/src/ring_switch/common.rs @@ -72,7 +72,7 @@ impl<'a, F: TowerField> EvalClaimSystem<'a, F> { pub fn new( oracles: &MultilinearOracleSet, commit_meta: &'a CommitMeta, - oracle_to_commit_index: SparseIndex, + oracle_to_commit_index: &SparseIndex, eval_claims: &'a [EvalcheckMultilinearClaim], ) -> Result { // Sort evaluation claims in ascending order by number of packed variables. This must diff --git a/crates/core/src/ring_switch/tests.rs b/crates/core/src/ring_switch/tests.rs index 5546e786e..be7d3fafc 100644 --- a/crates/core/src/ring_switch/tests.rs +++ b/crates/core/src/ring_switch/tests.rs @@ -208,7 +208,7 @@ fn with_test_instance_from_oracles( // Finish setting up the test case let system = - EvalClaimSystem::new(oracles, &commit_meta, oracle_to_commit_index, &eval_claims).unwrap(); + EvalClaimSystem::new(oracles, &commit_meta, &oracle_to_commit_index, &eval_claims).unwrap(); check_eval_point_consistency(&system); func(rng, system, witnesses) @@ -304,7 +304,7 @@ fn commit_prove_verify_piop( // Finish setting up the test case let system = - EvalClaimSystem::new(oracles, &commit_meta, oracle_to_commit_index, &eval_claims).unwrap(); + EvalClaimSystem::new(oracles, &commit_meta, &oracle_to_commit_index, &eval_claims).unwrap(); check_eval_point_consistency(&system); let mut proof = ProverTranscript::>::new(); diff --git a/crates/core/src/ring_switch/verify.rs b/crates/core/src/ring_switch/verify.rs index 9c2cf2255..8a841b2aa 100644 --- a/crates/core/src/ring_switch/verify.rs +++ b/crates/core/src/ring_switch/verify.rs @@ -77,7 +77,7 @@ where let ring_switch_eq_inds = make_ring_switch_eq_inds::<_, Tower>( &system.sumcheck_claim_descs, &system.suffix_descs, - row_batch_coeffs, + &row_batch_coeffs, &mixing_coeffs, )?; let sumcheck_claims = iter::zip(&system.sumcheck_claim_descs, row_batched_evals) @@ -175,7 +175,7 @@ fn accumulate_evaluations_by_prefixes( fn make_ring_switch_eq_inds( sumcheck_claim_descs: &[PIOPSumcheckClaimDesc], suffix_descs: &[EvalClaimSuffixDesc], - row_batch_coeffs: Arc>, + row_batch_coeffs: &Arc>, mixing_coeffs: &[F], ) -> Result>>, Error> where diff --git a/crates/field/src/arch/portable/packed_arithmetic.rs b/crates/field/src/arch/portable/packed_arithmetic.rs index c15f65fab..fbeea240b 100644 --- a/crates/field/src/arch/portable/packed_arithmetic.rs +++ b/crates/field/src/arch/portable/packed_arithmetic.rs @@ -331,7 +331,7 @@ where OP: PackedBinaryField, { pub fn new + Sync>( - transformation: FieldLinearTransformation, + transformation: &FieldLinearTransformation, ) -> Self { Self { bases: transformation @@ -387,7 +387,7 @@ where fn make_packed_transformation + Sync>( transformation: FieldLinearTransformation, ) -> Self::PackedTransformation { - PackedTransformation::new(transformation) + PackedTransformation::new(&transformation) } } diff --git a/crates/field/src/packed.rs b/crates/field/src/packed.rs index 7c0b9ad8f..5fcddf146 100644 --- a/crates/field/src/packed.rs +++ b/crates/field/src/packed.rs @@ -495,7 +495,7 @@ mod tests { } /// Run the test for all the packed fields defined in this crate. - fn run_for_all_packed_fields(test: impl PackedFieldTest) { + fn run_for_all_packed_fields(test: &impl PackedFieldTest) { // canonical tower test.run::(); @@ -665,6 +665,6 @@ mod tests { #[test] fn test_iteration() { - run_for_all_packed_fields(PackedFieldIterationTest); + run_for_all_packed_fields(&PackedFieldIterationTest); } } diff --git a/crates/ntt/src/dynamic_dispatch.rs b/crates/ntt/src/dynamic_dispatch.rs index f7eb830f6..bd1d0f8f8 100644 --- a/crates/ntt/src/dynamic_dispatch.rs +++ b/crates/ntt/src/dynamic_dispatch.rs @@ -54,7 +54,7 @@ pub enum DynamicDispatchNTT { impl DynamicDispatchNTT { /// Create a new AdditiveNTT based on the given settings. - pub fn new(log_domain_size: usize, options: NTTOptions) -> Result { + pub fn new(log_domain_size: usize, options: &NTTOptions) -> Result { let log_threads = options.thread_settings.log_threads_count(); let result = match (options.precompute_twiddles, log_threads) { (false, 0) => Self::SingleThreaded(SingleThreadedNTT::new(log_domain_size)?), @@ -144,24 +144,24 @@ mod tests { #[test] fn test_creation() { - fn make_ntt(options: NTTOptions) -> DynamicDispatchNTT { + fn make_ntt(options: &NTTOptions) -> DynamicDispatchNTT { DynamicDispatchNTT::::new(6, options).unwrap() } - let ntt = make_ntt(NTTOptions { + let ntt = make_ntt(&NTTOptions { precompute_twiddles: false, thread_settings: ThreadingSettings::SingleThreaded, }); assert!(matches!(ntt, DynamicDispatchNTT::SingleThreaded(_))); - let ntt = make_ntt(NTTOptions { + let ntt = make_ntt(&NTTOptions { precompute_twiddles: true, thread_settings: ThreadingSettings::SingleThreaded, }); assert!(matches!(ntt, DynamicDispatchNTT::SingleThreadedPrecompute(_))); let multithreaded = get_log_max_threads() > 0; - let ntt = make_ntt(NTTOptions { + let ntt = make_ntt(&NTTOptions { precompute_twiddles: false, thread_settings: ThreadingSettings::MultithreadedDefault, }); @@ -171,7 +171,7 @@ mod tests { assert!(matches!(ntt, DynamicDispatchNTT::SingleThreaded(_))); } - let ntt = make_ntt(NTTOptions { + let ntt = make_ntt(&NTTOptions { precompute_twiddles: true, thread_settings: ThreadingSettings::MultithreadedDefault, }); @@ -181,19 +181,19 @@ mod tests { assert!(matches!(ntt, DynamicDispatchNTT::SingleThreadedPrecompute(_))); } - let ntt = make_ntt(NTTOptions { + let ntt = make_ntt(&NTTOptions { precompute_twiddles: false, thread_settings: ThreadingSettings::ExplicitThreadsCount { log_threads: 2 }, }); assert!(matches!(ntt, DynamicDispatchNTT::MultiThreaded(_))); - let ntt = make_ntt(NTTOptions { + let ntt = make_ntt(&NTTOptions { precompute_twiddles: true, thread_settings: ThreadingSettings::ExplicitThreadsCount { log_threads: 0 }, }); assert!(matches!(ntt, DynamicDispatchNTT::SingleThreadedPrecompute(_))); - let ntt = make_ntt(NTTOptions { + let ntt = make_ntt(&NTTOptions { precompute_twiddles: false, thread_settings: ThreadingSettings::ExplicitThreadsCount { log_threads: 0 }, }); diff --git a/crates/ntt/src/single_threaded.rs b/crates/ntt/src/single_threaded.rs index a5723a62e..2448894c6 100644 --- a/crates/ntt/src/single_threaded.rs +++ b/crates/ntt/src/single_threaded.rs @@ -187,7 +187,7 @@ pub fn forward_transform>( // packed twiddles for all packed butterfly units. let log_block_len = i + log_b; let block_twiddle = calculate_twiddle::

( - s_evals[i].coset(log_domain_size - 1 - cutoff, 0), + &s_evals[i].coset(log_domain_size - 1 - cutoff, 0), log_block_len, ); @@ -263,7 +263,7 @@ pub fn inverse_transform>( // packed twiddles for all packed butterfly units. let log_block_len = i + log_b; let block_twiddle = calculate_twiddle::

( - s_evals[i].coset(log_domain_size - 1 - cutoff, 0), + &s_evals[i].coset(log_domain_size - 1 - cutoff, 0), log_block_len, ); @@ -357,7 +357,7 @@ pub const fn check_batch_transform_inputs_and_params( } #[inline] -fn calculate_twiddle

(s_evals: impl TwiddleAccess, log_block_len: usize) -> P +fn calculate_twiddle

(s_evals: &impl TwiddleAccess, log_block_len: usize) -> P where P: PackedField, { diff --git a/examples/keccakf_circuit.rs b/examples/keccakf_circuit.rs index ed87a920a..1fa583e4a 100644 --- a/examples/keccakf_circuit.rs +++ b/examples/keccakf_circuit.rs @@ -52,7 +52,7 @@ fn main() -> Result<()> { let trace_gen_scope = tracing::info_span!("generating trace").entered(); let input_witness = vec![]; let _state_out = - binius_circuits::keccakf::keccakf(&mut builder, Some(input_witness), log_size)?; + binius_circuits::keccakf::keccakf(&mut builder, &Some(input_witness), log_size)?; drop(trace_gen_scope); let witness = builder From 3c97cf98d3225047bcde8375889262e1f3f8b8be Mon Sep 17 00:00:00 2001 From: Dmitry Gordon Date: Tue, 4 Feb 2025 10:41:32 +0000 Subject: [PATCH 05/50] [math] Fix `fold_right` crash on big multilinears and make it single threaded --- crates/math/src/fold.rs | 228 ++++++++++++++++------------------------ 1 file changed, 91 insertions(+), 137 deletions(-) diff --git a/crates/math/src/fold.rs b/crates/math/src/fold.rs index 96621f3ab..1663368ce 100644 --- a/crates/math/src/fold.rs +++ b/crates/math/src/fold.rs @@ -14,13 +14,8 @@ use binius_field::{ AESTowerField128b, BinaryField128b, BinaryField128bPolyval, BinaryField1b, ExtensionField, Field, PackedField, }; -use binius_maybe_rayon::{ - iter::{IndexedParallelIterator, ParallelIterator}, - slice::ParallelSliceMut, -}; use binius_utils::bail; use bytemuck::fill_zeroes; -use itertools::max; use lazy_static::lazy_static; use stackalloc::helpers::slice_assume_init_mut; @@ -30,6 +25,9 @@ use crate::Error; /// /// Every consequent `1 << log_query_size` scalar values are dot-producted with the corresponding /// query elements. The result is stored in the `output` slice of packed values. +/// +/// Please note that this method is single threaded. Currently we always have some +/// parallelism above this level, so it's not a problem. pub fn fold_right( evals: &[P], log_evals_size: usize, @@ -61,7 +59,7 @@ where /// with the corresponding query element. The results is written to the `output` slice of packed values. /// If the function returns `Ok(())`, then `out` can be safely interpreted as initialized. /// -/// Please note that unlike `fold_right`, this method is single threaded. Currently we always have some +/// Please note that this method is single threaded. Currently we always have some /// parallelism above this level, so it's not a problem. Having no parallelism inside allows us to /// use more efficient optimizations for special cases. If we ever need a parallel version of this /// function, we can implement it separately. @@ -143,18 +141,8 @@ where if LOG_QUERY_SIZE >= 3 { return false; } - let chunk_size = 1 - << max(&[ - 10, - (P::LOG_WIDTH + LOG_QUERY_SIZE).saturating_sub(PE::LOG_WIDTH), - PE::LOG_WIDTH, - ]) - .expect("sequence of max values is not empty"); - if out.len() % chunk_size != 0 { - return false; - } - if P::WIDTH << LOG_QUERY_SIZE > chunk_size << PE::LOG_WIDTH { + if P::LOG_WIDTH + LOG_QUERY_SIZE > PE::LOG_WIDTH { return false; } @@ -171,52 +159,43 @@ where }) .collect::>(); - out.par_chunks_mut(chunk_size) - .enumerate() - .for_each(|(index, chunk)| { - let input_offset = - ((index * chunk_size) << (LOG_QUERY_SIZE + PE::LOG_WIDTH)) / P::WIDTH; - let input_end = - (((index + 1) * chunk_size) << (LOG_QUERY_SIZE + PE::LOG_WIDTH)) / P::WIDTH; - - struct Callback<'a, PE: PackedField, const LOG_QUERY_SIZE: usize> { - chunk: &'a mut [PE], - cached_table: &'a [PE::Scalar], - } + struct Callback<'a, PE: PackedField, const LOG_QUERY_SIZE: usize> { + out: &'a mut [PE], + cached_table: &'a [PE::Scalar], + } - impl ByteIteratorCallback - for Callback<'_, PE, LOG_QUERY_SIZE> - { - #[inline(always)] - fn call(&mut self, iterator: impl Iterator) { - let mask = (1 << (1 << LOG_QUERY_SIZE)) - 1; - let values_in_byte = 1 << (3 - LOG_QUERY_SIZE); - let mut current_index = 0; - for byte in iterator { - for k in 0..values_in_byte { - let index = (byte >> (k * (1 << LOG_QUERY_SIZE))) & mask; - // Safety: `i` is less than `chunk_size` - unsafe { - set_packed_slice_unchecked( - self.chunk, - current_index + k, - self.cached_table[index as usize], - ); - } - } - - current_index += values_in_byte; + impl ByteIteratorCallback + for Callback<'_, PE, LOG_QUERY_SIZE> + { + #[inline(always)] + fn call(&mut self, iterator: impl Iterator) { + let mask = (1 << (1 << LOG_QUERY_SIZE)) - 1; + let values_in_byte = 1 << (3 - LOG_QUERY_SIZE); + let mut current_index = 0; + for byte in iterator { + for k in 0..values_in_byte { + let index = (byte >> (k * (1 << LOG_QUERY_SIZE))) & mask; + // Safety: `i` is less than `chunk_size` + unsafe { + set_packed_slice_unchecked( + self.out, + current_index + k, + self.cached_table[index as usize], + ); } } + + current_index += values_in_byte; } + } + } - let mut callback = Callback::<'_, PE, LOG_QUERY_SIZE> { - chunk, - cached_table: &cached_table, - }; + let mut callback = Callback::<'_, PE, LOG_QUERY_SIZE> { + out, + cached_table: &cached_table, + }; - iterate_bytes(&evals[input_offset..input_end], &mut callback); - }); + iterate_bytes(evals, &mut callback); true } @@ -234,71 +213,52 @@ where if LOG_QUERY_SIZE < 3 { return false; } - let chunk_size = 1 - << max(&[ - 10, - (P::LOG_WIDTH + LOG_QUERY_SIZE).saturating_sub(PE::LOG_WIDTH), - PE::LOG_WIDTH, - ]) - .expect("sequence of max values is not empty"); - if out.len() % chunk_size != 0 { + + if P::LOG_WIDTH + LOG_QUERY_SIZE > PE::LOG_WIDTH { return false; } let cached_tables = create_partial_sums_lookup_tables(PackedSlice::new(query, 1 << LOG_QUERY_SIZE)); - out.par_chunks_mut(chunk_size) - .enumerate() - .for_each(|(index, chunk)| { - let input_offset = - ((index * chunk_size) << (LOG_QUERY_SIZE + PE::LOG_WIDTH)) / P::WIDTH; - let input_end = - (((index + 1) * chunk_size) << (LOG_QUERY_SIZE + PE::LOG_WIDTH)) / P::WIDTH; - - struct Callback<'a, PE: PackedField, const LOG_QUERY_SIZE: usize> { - chunk: &'a mut [PE], - cached_tables: &'a [PE::Scalar], - current_value: PE::Scalar, - current_table: usize, - } - impl ByteIteratorCallback - for Callback<'_, PE, LOG_QUERY_SIZE> - { - #[inline(always)] - fn call(&mut self, iterator: impl Iterator) { - let log_tables_count = LOG_QUERY_SIZE - 3; - let tables_count = 1 << log_tables_count; - for (current_index, byte) in iterator.enumerate() { - self.current_value += - self.cached_tables[(self.current_table << 8) + byte as usize]; - self.current_table += 1; - - if self.current_table == tables_count { - // Safety: `i` is less than `chunk_size` - unsafe { - set_packed_slice_unchecked( - self.chunk, - current_index, - self.current_value, - ); - } - self.current_table = 0; - self.current_value = PE::Scalar::ZERO; - } + struct Callback<'a, PE: PackedField, const LOG_QUERY_SIZE: usize> { + out: &'a mut [PE], + cached_tables: &'a [PE::Scalar], + } + + impl ByteIteratorCallback + for Callback<'_, PE, LOG_QUERY_SIZE> + { + #[inline(always)] + fn call(&mut self, iterator: impl Iterator) { + let log_tables_count = LOG_QUERY_SIZE - 3; + let tables_count = 1 << log_tables_count; + let mut current_index = 0; + let mut current_table = 0; + let mut current_value = PE::Scalar::ZERO; + for byte in iterator { + current_value += self.cached_tables[(current_table << 8) + byte as usize]; + current_table += 1; + + if current_table == tables_count { + // Safety: `i` is less than `chunk_size` + unsafe { + set_packed_slice_unchecked(self.out, current_index, current_value); } + current_index += 1; + current_table = 0; + current_value = PE::Scalar::ZERO; } } + } + } - let mut callback = Callback::<'_, _, LOG_QUERY_SIZE> { - chunk, - cached_tables: &cached_tables, - current_value: PE::Scalar::ZERO, - current_table: 0, - }; + let mut callback = Callback::<'_, _, LOG_QUERY_SIZE> { + out, + cached_tables: &cached_tables, + }; - iterate_bytes(&evals[input_offset..input_end], &mut callback); - }); + iterate_bytes(evals, &mut callback); true } @@ -351,34 +311,26 @@ fn fold_right_fallback( P: PackedField, PE: PackedField>, { - const CHUNK_SIZE: usize = 1 << 10; - let packed_result_evals = out; - packed_result_evals - .par_chunks_mut(CHUNK_SIZE) - .enumerate() - .for_each(|(i, packed_result_evals)| { - for (k, packed_result_eval) in packed_result_evals.iter_mut().enumerate() { - let offset = i * CHUNK_SIZE; - for j in 0..min(PE::WIDTH, 1 << (log_evals_size - log_query_size)) { - let index = ((offset + k) << PE::LOG_WIDTH) | j; + for (k, packed_result_eval) in out.iter_mut().enumerate() { + for j in 0..min(PE::WIDTH, 1 << (log_evals_size - log_query_size)) { + let index = (k << PE::LOG_WIDTH) | j; - let offset = index << log_query_size; + let offset = index << log_query_size; - let mut result_eval = PE::Scalar::ZERO; - for (t, query_expansion) in PackedField::iter_slice(query) - .take(1 << log_query_size) - .enumerate() - { - result_eval += query_expansion * get_packed_slice(evals, t + offset); - } + let mut result_eval = PE::Scalar::ZERO; + for (t, query_expansion) in PackedField::iter_slice(query) + .take(1 << log_query_size) + .enumerate() + { + result_eval += query_expansion * get_packed_slice(evals, t + offset); + } - // Safety: `j` < `PE::WIDTH` - unsafe { - packed_result_eval.set_unchecked(j, result_eval); - } - } + // Safety: `j` < `PE::WIDTH` + unsafe { + packed_result_eval.set_unchecked(j, result_eval); } - }); + } + } } type ArchOptimaType = ::OptimalThroughputPacked; @@ -607,7 +559,7 @@ mod tests { use binius_field::{ packed::set_packed_slice, PackedBinaryField128x1b, PackedBinaryField16x32b, - PackedBinaryField16x8b, PackedBinaryField512x1b, + PackedBinaryField16x8b, PackedBinaryField512x1b, PackedBinaryField64x8b, }; use rand::{rngs::StdRng, SeedableRng}; @@ -689,7 +641,9 @@ mod tests { let evals = repeat_with(|| PackedBinaryField128x1b::random(&mut rng)) .take(1 << LOG_EVALS_SIZE) .collect::>(); - let query = vec![PackedBinaryField512x1b::random(&mut rng)]; + let query = repeat_with(|| PackedBinaryField64x8b::random(&mut rng)) + .take(8) + .collect::>(); for log_query_size in 0..10 { check_fold_right( From f76b29b4030a2393b0d090253fca6ea628876344 Mon Sep 17 00:00:00 2001 From: Nikita Lesnikov Date: Tue, 4 Feb 2025 12:57:37 +0000 Subject: [PATCH 06/50] [math] Use specialized zero variable folding in the first sumcheck round. --- .../protocols/sumcheck/prove/prover_state.rs | 5 +- crates/math/src/fold.rs | 53 ++++++++++++++++- crates/math/src/mle_adapters.rs | 58 ------------------- crates/math/src/multilinear_extension.rs | 32 +++++++--- 4 files changed, 78 insertions(+), 70 deletions(-) diff --git a/crates/core/src/protocols/sumcheck/prove/prover_state.rs b/crates/core/src/protocols/sumcheck/prove/prover_state.rs index 606152664..a7e252fe9 100644 --- a/crates/core/src/protocols/sumcheck/prove/prover_state.rs +++ b/crates/core/src/protocols/sumcheck/prove/prover_state.rs @@ -191,8 +191,11 @@ where ref mut large_field_folded_multilinear, } => { // Post-switchover, simply plug in challenge for the zeroth variable. + let single_variable_query = MultilinearQuery::expand(&[challenge]); *large_field_folded_multilinear = MLEDirectAdapter::from( - large_field_folded_multilinear.evaluate_zeroth_variable(challenge)?, + large_field_folded_multilinear + .as_ref() + .evaluate_partial_low(single_variable_query.to_ref())?, ); } }; diff --git a/crates/math/src/fold.rs b/crates/math/src/fold.rs index 1663368ce..a35271a30 100644 --- a/crates/math/src/fold.rs +++ b/crates/math/src/fold.rs @@ -9,7 +9,7 @@ use binius_field::{ can_iterate_bytes, create_partial_sums_lookup_tables, is_sequential_bytes, iterate_bytes, ByteIteratorCallback, PackedSlice, }, - packed::{get_packed_slice, set_packed_slice_unchecked}, + packed::{get_packed_slice, get_packed_slice_unchecked, set_packed_slice_unchecked}, underlier::{UnderlierWithBitOps, WithUnderlier}, AESTowerField128b, BinaryField128b, BinaryField128bPolyval, BinaryField1b, ExtensionField, Field, PackedField, @@ -48,7 +48,16 @@ where return Ok(()); } - fold_right_fallback(evals, log_evals_size, query, log_query_size, out); + // Use linear interpolation for single variable multilinear queries. + let is_lerp = log_query_size == 1 + && get_packed_slice(query, 0) + get_packed_slice(query, 1) == PE::Scalar::ONE; + + if is_lerp { + let lerp_query = get_packed_slice(query, 1); + fold_right_lerp(evals, log_evals_size, lerp_query, out); + } else { + fold_right_fallback(evals, log_evals_size, query, log_query_size, out); + } Ok(()) } @@ -300,6 +309,46 @@ where } } +/// Specialized implementation for a single parameter right fold using linear interpolation +/// instead of tensor expansion resulting in a single multiplication instead of two: +/// f(r||w) = r * (f(1||w) - f(0||w)) + f(0||w). +/// +/// The same approach may be generalized to higher variable counts, with diminishing returns. +fn fold_right_lerp( + evals: &[P], + log_evals_size: usize, + lerp_query: PE::Scalar, + out: &mut [PE], +) where + P: PackedField, + PE: PackedField>, +{ + assert_eq!(1 << log_evals_size.saturating_sub(PE::LOG_WIDTH + 1), out.len()); + + out.iter_mut() + .enumerate() + .for_each(|(i, packed_result_eval)| { + for j in 0..min(PE::WIDTH, 1 << (log_evals_size - 1)) { + let index = (i << PE::LOG_WIDTH) | j; + + let (eval0, eval1) = unsafe { + ( + get_packed_slice_unchecked(evals, index << 1), + get_packed_slice_unchecked(evals, (index << 1) | 1), + ) + }; + + let result_eval = + PE::Scalar::from(eval1 - eval0) * lerp_query + PE::Scalar::from(eval0); + + // Safety: `j` < `PE::WIDTH` + unsafe { + packed_result_eval.set_unchecked(j, result_eval); + } + } + }) +} + /// Fallback implementation for fold that can be executed for any field types and sizes. fn fold_right_fallback( evals: &[P], diff --git a/crates/math/src/mle_adapters.rs b/crates/math/src/mle_adapters.rs index 079fc4ddd..1a8d8fec8 100644 --- a/crates/math/src/mle_adapters.rs +++ b/crates/math/src/mle_adapters.rs @@ -8,7 +8,6 @@ use binius_field::{ }, ExtensionField, Field, PackedField, RepackedExtension, }; -use binius_maybe_rayon::prelude::*; use binius_utils::bail; use super::{Error, MultilinearExtension, MultilinearPoly, MultilinearQueryRef}; @@ -299,44 +298,6 @@ where pub fn upcast_arc_dyn(self) -> Arc + Send + Sync + 'a> { Arc::new(self) } - - /// Given a ($mu$-variate) multilinear function $f$ and an element $r$, - /// return the multilinear function $f(r, X_1, ..., X_{\mu - 1})$. - pub fn evaluate_zeroth_variable(&self, r: P::Scalar) -> Result, Error> { - let multilin = &self.0; - let mu = multilin.n_vars(); - if mu == 0 { - bail!(Error::ConstantFold); - } - let packed_length = 1 << mu.saturating_sub(P::LOG_WIDTH + 1); - // in general, the formula is: f(r||w) = r * (f(1||w) - f(0||w)) + f(0||w). - let result = (0..packed_length) - .into_par_iter() - .map(|i| { - let eval0_minus_eval1 = P::from_fn(|j| { - let index = (i << P::LOG_WIDTH) | j; - // necessary if `mu_minus_one` < `P::LOG_WIDTH` - if index >= 1 << (mu - 1) { - return P::Scalar::ZERO; - } - let eval0 = get_packed_slice(multilin.evals(), index << 1); - let eval1 = get_packed_slice(multilin.evals(), (index << 1) | 1); - eval0 - eval1 - }); - let eval0 = P::from_fn(|j| { - let index = (i << P::LOG_WIDTH) | j; - // necessary if `mu_minus_one` < `P::LOG_WIDTH` - if index >= 1 << (mu - 1) { - return P::Scalar::ZERO; - } - get_packed_slice(multilin.evals(), index << 1) - }); - eval0_minus_eval1 * r + eval0 - }) - .collect::>(); - - MultilinearExtension::new(mu - 1, result) - } } impl From> for MLEDirectAdapter @@ -699,23 +660,4 @@ mod tests { .unwrap(); assert_eq!(evals_out, poly.packed_evals().unwrap()); } - - #[test] - fn test_evaluate_zeroth_evaluate_partial_low_consistent() { - let mut rng = StdRng::seed_from_u64(0); - let values: Vec<_> = repeat_with(|| PackedBinaryField4x32b::random(&mut rng)) - .take(1 << 8) - .collect(); - - let me = MultilinearExtension::from_values(values).unwrap(); - let mled = MLEDirectAdapter::from(me); - let r = ::random(&mut rng); - - let eval_1: MultilinearExtension = - mled.evaluate_zeroth_variable(r).unwrap(); - let eval_2 = mled - .evaluate_partial_low(multilinear_query(&[r]).to_ref()) - .unwrap(); - assert_eq!(eval_1, eval_2); - } } diff --git a/crates/math/src/multilinear_extension.rs b/crates/math/src/multilinear_extension.rs index 91c379c83..8403476ee 100644 --- a/crates/math/src/multilinear_extension.rs +++ b/crates/math/src/multilinear_extension.rs @@ -245,6 +245,7 @@ where PE::Scalar: ExtensionField, { let query = query.into(); + if self.mu < query.n_vars() { bail!(Error::IncorrectQuerySize { expected: self.mu }); } @@ -275,15 +276,6 @@ where PE: PackedField, PE::Scalar: ExtensionField, { - if self.mu < query.n_vars() { - bail!(Error::IncorrectQuerySize { expected: self.mu }); - } - if out.len() != 1 << ((self.mu - query.n_vars()).saturating_sub(PE::LOG_WIDTH)) { - bail!(Error::IncorrectOutputPolynomialSize { - expected: self.mu - query.n_vars(), - }); - } - // This operation is a matrix-vector product of the matrix of multilinear coefficients with // the vector of tensor product-expanded query coefficients. fold_right(&self.evals, self.mu, query.expansion(), query.n_vars(), out) @@ -559,6 +551,28 @@ mod tests { ); } + #[test] + fn test_evaluate_partial_low_single_and_multiple_var_consistent() { + let mut rng = StdRng::seed_from_u64(0); + let values: Vec<_> = repeat_with(|| PackedBinaryField4x32b::random(&mut rng)) + .take(1 << 8) + .collect(); + + let mle = MultilinearExtension::from_values(values).unwrap(); + let r1 = ::random(&mut rng); + let r2 = ::random(&mut rng); + + let eval_1: MultilinearExtension = mle + .evaluate_partial_low::(multilinear_query(&[r1]).to_ref()) + .unwrap() + .evaluate_partial_low(multilinear_query(&[r2]).to_ref()) + .unwrap(); + let eval_2 = mle + .evaluate_partial_low(multilinear_query(&[r1, r2]).to_ref()) + .unwrap(); + assert_eq!(eval_1, eval_2); + } + #[test] fn test_new_mle_with_tiny_nvars() { MultilinearExtension::new( From 3296da5169cef46cfeb461dfe3df165a9f0e13ea Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Tue, 4 Feb 2025 13:01:06 +0000 Subject: [PATCH 07/50] [test]: add test coverage for eq_ind_partial_eval --- .../protocols/sumcheck/prove/univariate.rs | 2 +- crates/math/src/multilinear_query.rs | 2 +- crates/math/src/tensor_prod_eq_ind.rs | 65 +++++++++++++++++-- 3 files changed, 62 insertions(+), 7 deletions(-) diff --git a/crates/core/src/protocols/sumcheck/prove/univariate.rs b/crates/core/src/protocols/sumcheck/prove/univariate.rs index 7c80eff51..7c373f96c 100644 --- a/crates/core/src/protocols/sumcheck/prove/univariate.rs +++ b/crates/core/src/protocols/sumcheck/prove/univariate.rs @@ -388,7 +388,7 @@ where // univariatized subcube. // NB: expansion of the first `skip_rounds` variables is applied to the round evals sum let partial_eq_ind_evals = backend.tensor_product_full_query(zerocheck_challenges)?; - let partial_eq_ind_evals_scalars = P::unpack_scalars(&partial_eq_ind_evals[..]); + let partial_eq_ind_evals_scalars = P::unpack_scalars(&partial_eq_ind_evals); // Evaluate each composition on a minimal packed prefix corresponding to the degree let pbase_prefix_lens = composition_degrees diff --git a/crates/math/src/multilinear_query.rs b/crates/math/src/multilinear_query.rs index fc27a689a..a4b3e5bd6 100644 --- a/crates/math/src/multilinear_query.rs +++ b/crates/math/src/multilinear_query.rs @@ -36,7 +36,7 @@ impl<'a, P: PackedField, Data: DerefMut> From<&'a MultilinearQuery for MultilinearQueryRef<'a, P> { fn from(query: &'a MultilinearQuery) -> Self { - MultilinearQueryRef::new(query) + Self::new(query) } } diff --git a/crates/math/src/tensor_prod_eq_ind.rs b/crates/math/src/tensor_prod_eq_ind.rs index 8bcba36fe..ed48bb1de 100644 --- a/crates/math/src/tensor_prod_eq_ind.rs +++ b/crates/math/src/tensor_prod_eq_ind.rs @@ -62,7 +62,7 @@ pub fn tensor_prod_eq_ind( xs.par_iter_mut() .zip(ys.par_iter_mut()) .with_min_len(64) - .for_each(|(x, y): (&mut P, &mut P)| { + .for_each(|(x, y)| { // x = x * (1 - packed_r_i) = x - x * packed_r_i // y = x * packed_r_i // Notice that we can reuse the multiplication: (x * packed_r_i) @@ -95,8 +95,7 @@ pub fn eq_ind_partial_eval(point: &[P::Scalar]) -> Vec

{ let len = 1 << n.saturating_sub(P::LOG_WIDTH); let mut buffer = zeroed_vec::

(len); buffer[0].set(0, P::Scalar::ONE); - tensor_prod_eq_ind(0, &mut buffer[..], point) - .expect("buffer is allocated with the correct length"); + tensor_prod_eq_ind(0, &mut buffer, point).expect("buffer is allocated with the correct length"); buffer } @@ -107,10 +106,11 @@ mod tests { use super::*; + type P = PackedBinaryField4x32b; + type F =

::Scalar; + #[test] fn test_tensor_prod_eq_ind() { - type P = PackedBinaryField4x32b; - type F =

::Scalar; let v0 = F::new(1); let v1 = F::new(2); let query = vec![v0, v1]; @@ -128,4 +128,59 @@ mod tests { ] ); } + + #[test] + fn test_eq_ind_partial_eval_empty() { + let result = eq_ind_partial_eval::

(&[]); + let expected = vec![P::set_single(F::ONE)]; + assert_eq!(result, expected); + } + + #[test] + fn test_eq_ind_partial_eval_single_var() { + // Only one query coordinate + let r0 = F::new(2); + let result = eq_ind_partial_eval::

(&[r0]); + let expected = vec![(F::ONE - r0), r0, F::ZERO, F::ZERO]; + let result = PackedField::iter_slice(&result).collect_vec(); + assert_eq!(result, expected); + } + + #[test] + fn test_eq_ind_partial_eval_two_vars() { + // Two query coordinates + let r0 = F::new(2); + let r1 = F::new(3); + let result = eq_ind_partial_eval::

(&[r0, r1]); + let result = PackedField::iter_slice(&result).collect_vec(); + let expected = vec![ + (F::ONE - r0) * (F::ONE - r1), + r0 * (F::ONE - r1), + (F::ONE - r0) * r1, + r0 * r1, + ]; + assert_eq!(result, expected); + } + + #[test] + fn test_eq_ind_partial_eval_three_vars() { + // Case with three query coordinates + let r0 = F::new(2); + let r1 = F::new(3); + let r2 = F::new(5); + let result = eq_ind_partial_eval::

(&[r0, r1, r2]); + let result = PackedField::iter_slice(&result).collect_vec(); + + let expected = vec![ + (F::ONE - r0) * (F::ONE - r1) * (F::ONE - r2), + r0 * (F::ONE - r1) * (F::ONE - r2), + (F::ONE - r0) * r1 * (F::ONE - r2), + r0 * r1 * (F::ONE - r2), + (F::ONE - r0) * (F::ONE - r1) * r2, + r0 * (F::ONE - r1) * r2, + (F::ONE - r0) * r1 * r2, + r0 * r1 * r2, + ]; + assert_eq!(result, expected); + } } From 6161f7f79077fa06caa1225c20e3229bc1db9f86 Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Thu, 30 Jan 2025 19:28:21 +0100 Subject: [PATCH 08/50] [test]: add test coverage for inner_product_par --- crates/field/src/util.rs | 97 +++++++++++++++++++++++++++++++++++++--- 1 file changed, 92 insertions(+), 5 deletions(-) diff --git a/crates/field/src/util.rs b/crates/field/src/util.rs index 3bc6ec295..12f99a256 100644 --- a/crates/field/src/util.rs +++ b/crates/field/src/util.rs @@ -16,7 +16,7 @@ where F: Field, FE: ExtensionField, { - iter::zip(a, b).map(|(a_i, b_i)| a_i * b_i).sum::() + iter::zip(a, b).map(|(a_i, b_i)| a_i * b_i).sum() } /// Calculate inner product for potentially big slices of xs and ys. @@ -38,9 +38,8 @@ where return inner_product_unchecked(PackedField::iter_slice(xs), PackedField::iter_slice(ys)); } - let calc_product_by_ys = |x_offset, ys: &[PY]| { + let calc_product_by_ys = |xs: &[PX], ys: &[PY]| { let mut result = FX::ZERO; - let xs = &xs[x_offset..]; for (j, y) in ys.iter().enumerate() { for (k, y) in y.iter().enumerate() { @@ -56,14 +55,14 @@ where // For different field sizes, the numbers may need to be adjusted. const CHUNK_SIZE: usize = 64; if ys.len() < 16 * CHUNK_SIZE { - calc_product_by_ys(0, ys) + calc_product_by_ys(xs, ys) } else { // According to benchmark results iterating by chunks here is more efficient than using `par_iter` with `min_length` directly. ys.par_chunks(CHUNK_SIZE) .enumerate() .map(|(i, ys)| { let offset = i * checked_int_div(CHUNK_SIZE * PY::WIDTH, PX::WIDTH); - calc_product_by_ys(offset, ys) + calc_product_by_ys(&xs[offset..], ys) }) .sum() } @@ -79,3 +78,91 @@ pub fn eq(x: F, y: F) -> F { pub fn powers(val: F) -> impl Iterator { iter::successors(Some(F::ONE), move |&power| Some(power * val)) } + +#[cfg(test)] +mod tests { + use super::*; + use crate::PackedBinaryField4x32b; + + type P = PackedBinaryField4x32b; + type F =

::Scalar; + + #[test] + fn test_inner_product_par_equal_length() { + // xs and ys have the same number of packed elements + let xs1 = F::new(1); + let xs2 = F::new(2); + let xs = vec![P::set_single(xs1), P::set_single(xs2)]; + let ys1 = F::new(3); + let ys2 = F::new(4); + let ys = vec![P::set_single(ys1), P::set_single(ys2)]; + + let result = inner_product_par::(&xs, &ys); + let expected = xs1 * ys1 + xs2 * ys2; + + assert_eq!(result, expected); + } + + #[test] + fn test_inner_product_par_unequal_length() { + // ys is larger than xs due to packing differences + let xs1 = F::new(1); + let xs = vec![P::set_single(xs1)]; + let ys1 = F::new(2); + let ys2 = F::new(3); + let ys = vec![P::set_single(ys1), P::set_single(ys2)]; + + let result = inner_product_par::(&xs, &ys); + let expected = xs1 * ys1; + + assert_eq!(result, expected); + } + + #[test] + fn test_inner_product_par_large_input_single_threaded() { + // Large input but not enough to trigger parallel execution + let size = 256; + let xs: Vec

= (0..size).map(|i| P::set_single(F::new(i as u32))).collect(); + let ys: Vec

= (0..size) + .map(|i| P::set_single(F::new((i + 1) as u32))) + .collect(); + + let result = inner_product_par::(&xs, &ys); + + let expected = (0..size) + .map(|i| F::new(i as u32) * F::new((i + 1) as u32)) + .sum::(); + + assert_eq!(result, expected); + } + + #[test] + fn test_inner_product_par_large_input_par() { + // Large input to test parallel execution + let size = 2000; + let xs: Vec

= (0..size).map(|i| P::set_single(F::new(i as u32))).collect(); + let ys: Vec

= (0..size) + .map(|i| P::set_single(F::new((i + 1) as u32))) + .collect(); + + let result = inner_product_par::(&xs, &ys); + + let expected = (0..size) + .map(|i| F::new(i as u32) * F::new((i + 1) as u32)) + .sum::(); + + assert_eq!(result, expected); + } + + #[test] + fn test_inner_product_par_empty() { + // Case: Empty input should return 0 + let xs: Vec

= vec![]; + let ys: Vec

= vec![]; + + let result = inner_product_par::(&xs, &ys); + let expected = F::ZERO; + + assert_eq!(result, expected); + } +} From 78f7d5b7120df30674e48ce1f8bf9fdad220a787 Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Tue, 4 Feb 2025 13:20:05 +0000 Subject: [PATCH 09/50] [test]: test coverage for MultilinearQuery update --- crates/math/src/multilinear_query.rs | 102 ++++++++++++++++++++++++++- 1 file changed, 100 insertions(+), 2 deletions(-) diff --git a/crates/math/src/multilinear_query.rs b/crates/math/src/multilinear_query.rs index a4b3e5bd6..9593367ac 100644 --- a/crates/math/src/multilinear_query.rs +++ b/crates/math/src/multilinear_query.rs @@ -73,7 +73,7 @@ impl MultilinearQuery> { } pub fn expand(query: &[P::Scalar]) -> Self { - let expanded_query = eq_ind_partial_eval::

(query); + let expanded_query = eq_ind_partial_eval(query); Self { expanded_query, n_vars: query.len(), @@ -148,12 +148,16 @@ impl> MultilinearQuery { #[cfg(test)] mod tests { - use binius_field::{Field, PackedField}; + use binius_field::{Field, PackedBinaryField4x32b, PackedField}; use binius_utils::felts; + use itertools::Itertools; use super::*; use crate::tensor_prod_eq_ind; + type P = PackedBinaryField4x32b; + type F =

::Scalar; + fn tensor_prod(p: &[P::Scalar]) -> Vec

{ let mut result = vec![P::default(); 1 << p.len().saturating_sub(P::LOG_WIDTH)]; result[0] = P::set_single(P::Scalar::ONE); @@ -252,4 +256,98 @@ mod tests { felts!(BinaryField16b[3, 2, 2, 1, 2, 1, 1, 3, 2, 1, 1, 3, 1, 3, 3, 2]) ); } + + #[test] + fn test_update_single_var() { + let query = MultilinearQuery::

::with_capacity(2); + let r0 = F::new(2); + let extra_query = [r0]; + + let updated_query = query.update(&extra_query).unwrap(); + + assert_eq!(updated_query.n_vars(), 1); + + let expansion = updated_query.into_expansion(); + let expansion = PackedField::iter_slice(&expansion).collect_vec(); + + assert_eq!(expansion, vec![(F::ONE - r0), r0, F::ZERO, F::ZERO]); + } + + #[test] + fn test_update_two_vars() { + let query = MultilinearQuery::

::with_capacity(3); + let r0 = F::new(2); + let r1 = F::new(3); + let extra_query = [r0, r1]; + + let updated_query = query.update(&extra_query).unwrap(); + assert_eq!(updated_query.n_vars(), 2); + + let expansion = updated_query.expansion(); + let expansion = PackedField::iter_slice(expansion).collect_vec(); + + assert_eq!( + expansion, + vec![ + (F::ONE - r0) * (F::ONE - r1), + r0 * (F::ONE - r1), + (F::ONE - r0) * r1, + r0 * r1, + ] + ); + } + + #[test] + fn test_update_three_vars() { + let query = MultilinearQuery::

::with_capacity(4); + let r0 = F::new(2); + let r1 = F::new(3); + let r2 = F::new(5); + let extra_query = [r0, r1, r2]; + + let updated_query = query.update(&extra_query).unwrap(); + assert_eq!(updated_query.n_vars(), 3); + + let expansion = updated_query.expansion(); + let expansion = PackedField::iter_slice(expansion).collect_vec(); + + assert_eq!( + expansion, + vec![ + (F::ONE - r0) * (F::ONE - r1) * (F::ONE - r2), + r0 * (F::ONE - r1) * (F::ONE - r2), + (F::ONE - r0) * r1 * (F::ONE - r2), + r0 * r1 * (F::ONE - r2), + (F::ONE - r0) * (F::ONE - r1) * r2, + r0 * (F::ONE - r1) * r2, + (F::ONE - r0) * r1 * r2, + r0 * r1 * r2, + ] + ); + } + + #[test] + fn test_update_exceeds_capacity() { + let query = MultilinearQuery::

::with_capacity(2); + // More than allowed capacity + let extra_query = [F::new(2), F::new(3), F::new(5)]; + + let result = query.update(&extra_query); + // Expecting an error due to exceeding max_query_vars + assert!(result.is_err()); + } + + #[test] + fn test_update_empty() { + let query = MultilinearQuery::

::with_capacity(2); + // Updating with no new coordinates should be fine + let updated_query = query.update(&[]).unwrap(); + + assert_eq!(updated_query.n_vars(), 0); + + let expansion = updated_query.expansion(); + let expansion = PackedField::iter_slice(expansion).collect_vec(); + + assert_eq!(expansion, vec![F::ONE, F::ZERO, F::ZERO, F::ZERO]); + } } From ce7f5308c9303cd418461ec9ff37888cb9fad277 Mon Sep 17 00:00:00 2001 From: Aliaksei Dziadziuk Date: Tue, 4 Feb 2025 13:27:37 +0000 Subject: [PATCH 10/50] [tracing] Display proof size in graph --- Cargo.toml | 2 +- crates/core/src/transcript/mod.rs | 38 +++++++++++++++++++++---------- 2 files changed, 27 insertions(+), 13 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 9cf5f0969..5c6ff5406 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -118,7 +118,7 @@ thread_local = "1.1.7" tiny-keccak = { version = "2.0.2", features = ["keccak"] } trait-set = "0.3.0" tracing = "0.1.38" -tracing-profile = "0.9.0" +tracing-profile = "0.10.1" transpose = "0.2.2" [profile.release] diff --git a/crates/core/src/transcript/mod.rs b/crates/core/src/transcript/mod.rs index e6081b6c5..17d6ec021 100644 --- a/crates/core/src/transcript/mod.rs +++ b/crates/core/src/transcript/mod.rs @@ -334,20 +334,25 @@ impl TranscriptWriter<'_, B> { } pub fn write(&mut self, value: &T) { - value - .serialize(self.buffer()) - .expect("TODO: propagate error") + self.proof_size_event_wrapper(|buffer| { + value.serialize(buffer).expect("TODO: propagate error"); + }); } pub fn write_slice(&mut self, values: &[T]) { - let mut buffer = self.buffer(); - for value in values { - value.serialize(&mut buffer).expect("TODO: propagate error") - } + self.proof_size_event_wrapper(|buffer| { + for value in values { + value + .serialize(&mut *buffer) + .expect("TODO: propagate error"); + } + }); } pub fn write_bytes(&mut self, data: &[u8]) { - self.buffer().put_slice(data); + self.proof_size_event_wrapper(|buffer| { + buffer.put_slice(data); + }); } pub fn write_scalar(&mut self, f: F) { @@ -355,10 +360,11 @@ impl TranscriptWriter<'_, B> { } pub fn write_scalar_slice(&mut self, elems: &[F]) { - let mut buffer = self.buffer(); - for elem in elems { - serialize_canonical(*elem, &mut buffer).expect("TODO: propagate error"); - } + self.proof_size_event_wrapper(|buffer| { + for elem in elems { + serialize_canonical(*elem, &mut *buffer).expect("TODO: propagate error"); + } + }); } pub fn write_packed>(&mut self, packed: P) { @@ -378,6 +384,14 @@ impl TranscriptWriter<'_, B> { self.write_bytes(msg.as_bytes()) } } + + fn proof_size_event_wrapper(&mut self, f: F) { + let buffer = self.buffer(); + let start_bytes = buffer.remaining_mut(); + f(buffer); + let end_bytes = buffer.remaining_mut(); + tracing::event!(name: "proof_size", tracing::Level::INFO, counter=true, incremental=true, value=start_bytes - end_bytes); + } } impl CanSample for VerifierTranscript From 7ae9bad51ba75e69798e00873c62e6b02edd4ce2 Mon Sep 17 00:00:00 2001 From: Milos Backonja Date: Tue, 4 Feb 2025 15:02:49 +0100 Subject: [PATCH 11/50] [ci]: Setting Up GitHub Pipelines --- .github/workflows/benchmark.yml | 85 ++++++++++++++++++++ .github/workflows/ci.yml | 138 ++++++++++++++++++++++++++++++++ 2 files changed, 223 insertions(+) create mode 100644 .github/workflows/benchmark.yml create mode 100644 .github/workflows/ci.yml diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml new file mode 100644 index 000000000..f50554c1a --- /dev/null +++ b/.github/workflows/benchmark.yml @@ -0,0 +1,85 @@ +name: Nightly Benchmark + +on: + push: + branches: [ main ] + workflow_dispatch: + inputs: + ec2_instance_type: + description: 'Select EC2 instance type' + required: true + default: 'c7a-4xlarge' + type: choice + options: + - c7a-2xlarge + - c7a-4xlarge + - c8g-2xlarge + +permissions: + contents: write + checks: write + pull-requests: write + +jobs: + benchmark: + name: Continuous Benchmarking with Bencher + container: rustlang/rust:nightly + permissions: + checks: write + actions: write + runs-on: ${{ github.event_name == 'push' && github.ref_name == 'main' && 'c7a-4xlarge' || github.event.inputs.ec2_instance_type }} + steps: + - name: Checkout Private GitLab Repository # Will be replaced with actual repository + uses: actions/checkout@v4 + with: + repository: ulvetanna/binius + github-server-url: https://gitlab.com + ref: anexj/benchmark_script + ssh-key: ${{ secrets.GITLAB_SSH_KEY }} + ssh-known-hosts: | + gitlab.com ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQCsj2bNKTBSpIYDEGk9KxsGh3mySTRgMtXL583qmBpzeQ+jqCMRgBqB98u3z++J1sKlXHWfM9dyhSevkMwSbhoR8XIq/U0tCNyokEi/ueaBMCvbcTHhO7FcwzY92WK4Yt0aGROY5qX2UKSeOvuP4D6TPqKF1onrSzH9bx9XUf2lEdWT/ia1NEKjunUqu1xOB/StKDHMoX4/OKyIzuS0q/T1zOATthvasJFoPrAjkohTyaDUz2LN5JoH839hViyEG82yB+MjcFV5MU3N1l1QL3cVUCh93xSaua1N85qivl+siMkPGbO5xR/En4iEY6K2XPASUEMaieWVNTRCtJ4S8H+9 + - name: Setup Bencher + uses: bencherdev/bencher@main + - name: Create Output Directory + run: mkdir output + - name: Execute Benchmark Tests + run: ./scripts/nightly_benchmarks.py --export-file output/result.json + - name: Track base branch benchmarks with Bencher + run: | + bencher run \ + --project ben \ + --token '${{ secrets.BENCHER_API_TOKEN }}' \ + --branch main \ + --testbed c7a-4xlarge \ + --threshold-measure latency \ + --threshold-test t_test \ + --threshold-max-sample-size 64 \ + --threshold-upper-boundary 0.99 \ + --thresholds-reset \ + --err \ + --adapter json \ + --github-actions '${{ secrets.GITHUB_TOKEN }}' \ + --file output/result.json + - name: Upload artifact + uses: actions/upload-artifact@v4 + with: + name: gh-pages + path: output/ + publish_results: + name: Publish Results to Github Page + needs: [benchmark] + runs-on: ubuntu-latest + steps: + - name: Download artifact + uses: actions/download-artifact@v4 + with: + name: gh-pages + - name: Deploy to GitHub Pages + uses: crazy-max/ghaction-github-pages@v4 + with: + repo: irreducibleoss/binius-benchmark + fqdn: benchmark.binius.xyz + target_branch: main + build_dir: ./ + env: + GITHUB_TOKEN: ${{ secrets.GH_TOKEN }} diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 000000000..940efa679 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,138 @@ +name: Rust CI + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + +concurrency: + group: ${{ github.event_name }}-${{ github.ref }} + cancel-in-progress: ${{ github.event_name == 'pull_request' }} + +jobs: + lint: + name: ${{ matrix.expand.name }} + runs-on: ${{ matrix.expand.runner }} + container: rustlang/rust:nightly + strategy: + matrix: + expand: + - runner: "ubuntu-latest" + name: "copyright-check" + cmd: "./scripts/check_copyright_notice.sh" + - runner: "ubuntu-latest" + name: "cargofmt" + cmd: "cargo fmt --check" + - runner: "ubuntu-latest" + name: "clippy" + cmd: "cargo clippy --all --all-features --tests --benches --examples -- -D warnings" + continue-on-error: true + steps: + - name: Checkout Repository + uses: actions/checkout@v4 + + - name: Run Command + run: ${{ matrix.expand.cmd }} + build: + name: build-${{ matrix.expand.name }} + needs: [lint] + runs-on: ${{ matrix.expand.runner }} + env: + RUST_VERSION: 1.83.0 + container: rustlang/rust:nightly + strategy: + matrix: + expand: + - runner: "c7a-2xlarge" + name: "debug-wasm" + cmd: "rustup target add wasm32-unknown-unknown && cargo build --package binius_field --target wasm32-unknown-unknown" + - runner: "c7a-2xlarge" + name: "debug-amd" + cmd: "cargo build --tests --benches --examples" + - runner: "c7a-2xlarge" + name: "debug-amd-no-default-features" + cmd: "cargo build --tests --benches --examples --no-default-features" + - runner: "c7a-2xlarge" + name: "debug-amd-stable" + cmd: "cargo +$RUST_VERSION build --tests --benches --examples -p binius_core --features stable_only" + - runner: "c8g-2xlarge" + name: "debug-arm" + cmd: "cargo build --tests --benches --examples" + - runner: "c7a-2xlarge" + name: "docs" + cmd: 'cargo doc --no-deps; echo "" > target/doc/index.html' + continue-on-error: true + steps: + - name: Checkout Repository + uses: actions/checkout@v4 + - name: AMD job configuration template with stable Rust + if: ${{ matrix.expand.name == 'debug-amd-stable' }} + run: | + rustup set auto-self-update disable + rustup toolchain install $RUST_VERSION + - name: Run Command + run: ${{ matrix.expand.cmd }} + - name: Upload static files as artifact + if: ${{ matrix.expand.name == 'docs' }} + id: deployment + uses: actions/upload-pages-artifact@v3 + with: + path: "target/doc" + test: + name: unit-test-${{ matrix.expand.name }} + needs: [build] + runs-on: ${{ matrix.expand.runner }} + env: + RUST_VERSION: 1.83.0 + container: rustlang/rust:nightly + strategy: + matrix: + expand: + - runner: "c7a-2xlarge" + name: "amd" + cmd: 'RUSTFLAGS="-C target-cpu=native" ./scripts/run_tests_and_examples.sh' + - runner: "c7a-2xlarge" + name: "amd-portable" + cmd: 'RUSTFLAGS="-C target-cpu=generic" ./scripts/run_tests_and_examples.sh' + - runner: "c7a-2xlarge" + name: "amd-stable" + cmd: 'RUSTFLAGS="-C target-cpu=native" CARGO_STABLE=true ./scripts/run_tests_and_examples.sh' + - runner: "c7a-2xlarge" + name: "single-threaded" + cmd: 'RAYON_NUM_THREADS=1 RUSTFLAGS="-C target-cpu=native" ./scripts/run_tests_and_examples.sh' + - runner: "c7a-2xlarge" + name: "no-default-features" + cmd: 'CARGO_EXTRA_FLAGS="--no-default-features" RUSTFLAGS="-C target-cpu=native" ./scripts/run_tests_and_examples.sh' + - runner: "c8g-2xlarge" + name: "arm" + cmd: 'RUSTFLAGS="-C target-cpu=native -C target-feature=+aes" ./scripts/run_tests_and_examples.sh' + - runner: "c8g-2xlarge" + name: "arm-portable" + cmd: 'RUSTFLAGS="-C target-cpu=generic" ./scripts/run_tests_and_examples.sh' + continue-on-error: true + steps: + - name: Checkout Repository + uses: actions/checkout@v4 + - name: AMD job configuration template with stable Rust + if: ${{ matrix.expand.name == 'amd-stable' }} + run: | + rustup set auto-self-update disable + rustup toolchain install $RUST_VERSION + - name: Run Command + run: ${{ matrix.expand.cmd }} + deploy: + name: deploy-pages + needs: [build] + runs-on: ubuntu-latest + if: github.ref_name == 'main' + permissions: + pages: write + id-token: write + environment: + name: github-pages + url: ${{ steps.deployment.outputs.page_url }} + steps: + - name: Deploy to GitHub Pages + id: deployment + uses: actions/deploy-pages@v4 From 612e77a1800b36738c55529ce35be2472fcfd86d Mon Sep 17 00:00:00 2001 From: Milos Backonja <35807060+milosbackonja@users.noreply.github.com> Date: Thu, 6 Feb 2025 14:57:26 +0100 Subject: [PATCH 12/50] [ci]: Setting Up Mirror to GitLab (#8) --- .github/workflows/mirror.yml | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 .github/workflows/mirror.yml diff --git a/.github/workflows/mirror.yml b/.github/workflows/mirror.yml new file mode 100644 index 000000000..7d10296ee --- /dev/null +++ b/.github/workflows/mirror.yml @@ -0,0 +1,26 @@ +name: Mirror Repository + +on: + push: + branches: [ main ] + +permissions: + contents: read + +jobs: + mirror: + name: Mirror Repository to GitLab + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + - uses: shimataro/ssh-key-action@v2 + with: + key: ${{ secrets.GIT_SSH_PRIVATE_KEY }} + name: id_rsa + known_hosts: ${{ secrets.GIT_SSH_KNOWN_HOSTS }} + - name: Mirror current ref to GitLab + run: | + git remote add gitlab ssh://git@gitlab.com/IrreducibleOSS/binius.git + git push gitlab ${{ github.ref }} From 47081af6a161a3a9ccd2a59d5cf2023fe4a5031a Mon Sep 17 00:00:00 2001 From: chloefeal <188809157+chloefeal@users.noreply.github.com> Date: Fri, 7 Feb 2025 20:45:51 +0800 Subject: [PATCH 13/50] Fix typos (#2) [nicetohave] fix typos --- crates/core/src/lib.rs | 2 +- crates/core/src/protocols/sumcheck/error.rs | 2 +- crates/core/src/protocols/sumcheck/prove/zerocheck.rs | 2 +- crates/hash/src/groestl/hasher.rs | 2 +- crates/utils/src/thread_local_mut.rs | 2 +- scripts/run_tests_and_examples.sh | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/crates/core/src/lib.rs b/crates/core/src/lib.rs index c7ebeab0e..4ff8e8ff0 100644 --- a/crates/core/src/lib.rs +++ b/crates/core/src/lib.rs @@ -7,7 +7,7 @@ //! performance, while verifier-side functions are optimized for auditability and security. // This is to silence clippy errors around suspicious usage of XOR -// in our arithmetic. This is safe to do becasue we're operating +// in our arithmetic. This is safe to do because we're operating // over binary fields. #![allow(clippy::suspicious_arithmetic_impl)] #![allow(clippy::suspicious_op_assign_impl)] diff --git a/crates/core/src/protocols/sumcheck/error.rs b/crates/core/src/protocols/sumcheck/error.rs index a0195e68b..848123a2a 100644 --- a/crates/core/src/protocols/sumcheck/error.rs +++ b/crates/core/src/protocols/sumcheck/error.rs @@ -45,7 +45,7 @@ pub enum Error { oracle: String, hypercube_index: usize, }, - #[error("constraint set containts multilinears of different heights")] + #[error("constraint set contains multilinears of different heights")] ConstraintSetNumberOfVariablesMismatch, #[error("batching sumchecks and zerochecks is not supported yet")] MixedBatchingNotSupported, diff --git a/crates/core/src/protocols/sumcheck/prove/zerocheck.rs b/crates/core/src/protocols/sumcheck/prove/zerocheck.rs index 90df740ea..bc6cbd852 100644 --- a/crates/core/src/protocols/sumcheck/prove/zerocheck.rs +++ b/crates/core/src/protocols/sumcheck/prove/zerocheck.rs @@ -387,7 +387,7 @@ where // This is also regular multilinear zerocheck constructor, but "jump started" in round // `skip_rounds` while using witness with a projected univariate round. - // NB: first round evaluator has to be overriden due to issues proving + // NB: first round evaluator has to be overridden due to issues proving // `P: RepackedExtension

` relation in the generic context, as well as the need // to use later round evaluator (as this _is_ a "later" round, albeit numbered at zero) let regular_prover = ZerocheckProver::new( diff --git a/crates/hash/src/groestl/hasher.rs b/crates/hash/src/groestl/hasher.rs index 74f9e49c9..ae8f3144c 100644 --- a/crates/hash/src/groestl/hasher.rs +++ b/crates/hash/src/groestl/hasher.rs @@ -300,7 +300,7 @@ mod tests { } #[test] - fn test_aes_binary_convertion() { + fn test_aes_binary_conversion() { let mut rng = thread_rng(); let input_aes: [PackedAESBinaryField32x8b; 90] = array::from_fn(|_| PackedAESBinaryField32x8b::random(&mut rng)); diff --git a/crates/utils/src/thread_local_mut.rs b/crates/utils/src/thread_local_mut.rs index 0f196c5b2..9bf80ad32 100644 --- a/crates/utils/src/thread_local_mut.rs +++ b/crates/utils/src/thread_local_mut.rs @@ -6,7 +6,7 @@ use thread_local::ThreadLocal; /// Creates a "scratch space" within each thread with mutable access. /// -/// This is mainly meant to be used as an optimization to avoid unneccesary allocs/frees within rayon code. +/// This is mainly meant to be used as an optimization to avoid unnecessary allocs/frees within rayon code. /// You only pay for allocation of this scratch space once per thread. /// /// Since the space is local to each thread you also don't have to worry about atomicity. diff --git a/scripts/run_tests_and_examples.sh b/scripts/run_tests_and_examples.sh index 85b0ffec2..70da6e396 100755 --- a/scripts/run_tests_and_examples.sh +++ b/scripts/run_tests_and_examples.sh @@ -19,7 +19,7 @@ if [ -z "$CARGO_STABLE" ]; then # Execute examples. # Unfortunately there cargo doesn't support executing all examples with a single command. - # Cargo plugins such as "cargo-examples" do suport it but without a possibility to specify "release" profile. + # Cargo plugins such as "cargo-examples" do support it but without a possibility to specify "release" profile. for example in examples/*.rs do cargo run --profile $CARGO_PROFILE --example "$(basename "${example%.rs}")" $CARGO_EXTRA_FLAGS From 08b11a94c026e0e995e23d66f57dacb9a731efc1 Mon Sep 17 00:00:00 2001 From: Milos Backonja <35807060+milosbackonja@users.noreply.github.com> Date: Mon, 10 Feb 2025 18:00:36 +0100 Subject: [PATCH 14/50] [ci]: Improvements (#17) [ci]: Removing continue or error, and depricating Gitlab pipelines --- .github/workflows/ci.yml | 3 - .gitlab-ci.yml | 203 --------------------------------------- 2 files changed, 206 deletions(-) delete mode 100644 .gitlab-ci.yml diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 940efa679..3341a164f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -27,7 +27,6 @@ jobs: - runner: "ubuntu-latest" name: "clippy" cmd: "cargo clippy --all --all-features --tests --benches --examples -- -D warnings" - continue-on-error: true steps: - name: Checkout Repository uses: actions/checkout@v4 @@ -62,7 +61,6 @@ jobs: - runner: "c7a-2xlarge" name: "docs" cmd: 'cargo doc --no-deps; echo "" > target/doc/index.html' - continue-on-error: true steps: - name: Checkout Repository uses: actions/checkout@v4 @@ -110,7 +108,6 @@ jobs: - runner: "c8g-2xlarge" name: "arm-portable" cmd: 'RUSTFLAGS="-C target-cpu=generic" ./scripts/run_tests_and_examples.sh' - continue-on-error: true steps: - name: Checkout Repository uses: actions/checkout@v4 diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml deleted file mode 100644 index d2a03abf4..000000000 --- a/.gitlab-ci.yml +++ /dev/null @@ -1,203 +0,0 @@ -workflow: - rules: - - if: $CI_PIPELINE_SOURCE == "merge_request_event" - - if: $CI_COMMIT_BRANCH == 'main' - -variables: - CARGO_HOME: "$CI_PROJECT_DIR/toolchains/cargo" - RUSTUP_HOME: "$CI_PROJECT_DIR/toolchains" - GIT_CLEAN_FLAGS: "-ffdx --exclude toolchains" - FF_TIMESTAMPS: true - - -stages: - - lint - - build - - test - - deploy - -# AMD job configuration template -.job_template_amd: - image: rustlang/rust:nightly - variables: - KUBERNETES_NODE_SELECTOR_INSTANCE_TYPE: "ulvt-node-pool=ulvt-c7i-2xlarge" - KUBERNETES_CPU_REQUEST: "6" - KUBERNETES_MEMORY_REQUEST: "14Gi" - GIT_CLONE_PATH: "$CI_BUILDS_DIR/binius_amd" - tags: - - k8s - -# AMD job configuration template with stable Rust -.job_template_amd_stable: - extends: .test_job_template_amd - variables: - RUST_VERSION: "1.83.0" - before_script: - # workaround for https://github.com/rust-lang/rustup/issues/2886 - - rustup set auto-self-update disable - - rustup toolchain install $RUST_VERSION - -# ARM job configuration template -.job_template_arm: - image: rustlang/rust:nightly - variables: - KUBERNETES_NODE_SELECTOR_INSTANCE_TYPE: "ulvt-node-pool=ulvt-c8g-2xlarge" - KUBERNETES_NODE_SELECTOR_ARCH: 'kubernetes.io/arch=arm64' - KUBERNETES_CPU_REQUEST: "6" - KUBERNETES_MEMORY_REQUEST: "14Gi" - GIT_CLONE_PATH: "$CI_BUILDS_DIR/binius_arm" - before_script: - - if [ "$(uname -m)" != "aarch64" ]; then echo "This job is intended to run on ARM architecture only."; exit 1; fi - tags: - - k8s - -# Linting jobs -copyright-check: - extends: .job_template_amd - stage: lint - script: - - ./scripts/check_copyright_notice.sh - -cargofmt: - extends: .job_template_amd - stage: lint - script: - - cargo fmt --check - -clippy: - extends: .job_template_amd - stage: lint - script: - - cargo clippy --all --all-features --tests --benches --examples -- -D warnings - -# Building jobs - -# TODO: use a docker image with `wasm32-unknown-unknown` target preinstalled -build-debug-wasm: - extends: .job_template_amd - stage: build - script: - - rustup target add wasm32-unknown-unknown - - cargo build --package binius_field --target wasm32-unknown-unknown - artifacts: - paths: - - Cargo.lock - expire_in: 1 day - -build-debug-amd: - extends: .job_template_amd - stage: build - script: - - cargo build --tests --benches --examples - artifacts: - paths: - - Cargo.lock - expire_in: 1 day - -# Build without default features -# This checks if build without `rayon` feature works. -build-debug-amd-no-default-features: - extends: .job_template_amd - stage: build - script: - - cargo build --tests --benches --examples --no-default-features - artifacts: - paths: - - Cargo.lock - expire_in: 1 day - -build-debug-amd-stable: - extends: .job_template_amd_stable - stage: build - script: - - cargo +$RUST_VERSION build --tests --benches --examples -p binius_core --features stable_only - artifacts: - paths: - - Cargo.lock - expire_in: 1 day - -build-debug-arm: - extends: .job_template_arm - stage: build - script: - - cargo build --tests --benches --examples - artifacts: - paths: - - Cargo.lock - expire_in: 1 day - -.test_job_template_amd: - extends: .job_template_amd - dependencies: - - build-debug-amd - -.test_job_template_amd_stable: - extends: .job_template_amd_stable - dependencies: - - build-debug-amd-stable - -.test_job_template_arm: - extends: .job_template_arm - dependencies: - - build-debug-arm - -unit-test-amd-portable: - extends: .test_job_template_amd - script: - - RUSTFLAGS="-C target-cpu=generic" ./scripts/run_tests_and_examples.sh - -unit-test-arm-portable: - extends: .test_job_template_arm - script: - - RUSTFLAGS="-C target-cpu=generic" ./scripts/run_tests_and_examples.sh - -unit-test-single-threaded: - extends: .test_job_template_amd - script: - - RAYON_NUM_THREADS=1 RUSTFLAGS="-C target-cpu=native" ./scripts/run_tests_and_examples.sh - -unit-test-no-default-features: - extends: .test_job_template_amd - script: - - CARGO_EXTRA_FLAGS="--no-default-features" RUSTFLAGS="-C target-cpu=native" ./scripts/run_tests_and_examples.sh - -unit-test-amd: - extends: .test_job_template_amd - script: - - RUSTFLAGS="-C target-cpu=native" ./scripts/run_tests_and_examples.sh - -unit-test-amd-stable: - extends: .test_job_template_amd_stable - script: - - RUSTFLAGS="-C target-cpu=native" CARGO_STABLE=true ./scripts/run_tests_and_examples.sh - -unit-test-arm: - extends: .test_job_template_arm - script: - - RUSTFLAGS="-C target-cpu=native -C target-feature=+aes" ./scripts/run_tests_and_examples.sh - -# Documentation and pages jobs -build-docs: - extends: .job_template_amd - stage: build - script: - - cargo doc --no-deps - artifacts: - paths: - - target/doc - expire_in: 1 week - -pages: - extends: .job_template_amd - stage: deploy - dependencies: - - build-docs - script: - - mv target/doc public - - echo "/ /binius_core 302" > public/_redirects - artifacts: - paths: - - public - only: - refs: - - main # Deploy for every push to the main branch, for now From e4c41eda1598296763430ee67773d818c6b0ea36 Mon Sep 17 00:00:00 2001 From: Dmytro Gordon Date: Wed, 12 Feb 2025 14:02:37 +0200 Subject: [PATCH 15/50] Improve test compilation time (#10) Co-authored-by: Dmytro Gordon This MR addresses tow issues that make cargo test slow: Thin LTO slows down compilation of all the crates a bit. It takes quite a time to compile and link all the examples with test profile which are not actually executed. So I've added an alias to compile and run tests only for fast local usage. --- .cargo/config.toml | 3 +++ Cargo.toml | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/.cargo/config.toml b/.cargo/config.toml index 7acd84b05..6287bb569 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -1,2 +1,5 @@ [build] rustdocflags = ["-Dwarnings", "--html-in-header", "doc/katex-header.html"] + +[alias] +fast_test = "test --tests" diff --git a/Cargo.toml b/Cargo.toml index 5c6ff5406..adf326766 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -137,4 +137,4 @@ opt-level = 1 debug = true debug-assertions = true overflow-checks = true -lto = false +lto = "off" From 33f94b8e0ac17a70bb66429ebb936a7343f6a570 Mon Sep 17 00:00:00 2001 From: Tobias Bergkvist Date: Wed, 12 Feb 2025 13:04:37 +0100 Subject: [PATCH 16/50] [serialization] Add canonical serialize/deserialize traits + derive macros Introduces the following traits: SerializeCanonical (which replaces most uses of SerializeBytes) DeserializeCanonical (which replaces most uses of DeserializeBytes) Conveniently, this also comes with proc-macros for deriving these traits for an arbitrary struct/enum (unions are not supported). --- Cargo.toml | 2 +- crates/core/Cargo.toml | 1 + crates/core/src/constraint_system/channel.rs | 28 +-- crates/core/src/constraint_system/prove.rs | 7 +- crates/core/src/constraint_system/verify.rs | 7 +- .../src/merkle_tree/binary_merkle_tree.rs | 4 +- crates/core/src/merkle_tree/scheme.rs | 7 +- crates/core/src/piop/prove.rs | 8 +- crates/core/src/piop/tests.rs | 8 +- crates/core/src/piop/verify.rs | 8 +- crates/core/src/protocols/fri/prove.rs | 8 +- crates/core/src/protocols/fri/verify.rs | 6 +- crates/core/src/ring_switch/tests.rs | 6 +- crates/core/src/transcript/error.rs | 2 +- crates/core/src/transcript/mod.rs | 30 +-- crates/field/Cargo.toml | 1 + crates/field/src/aes_field.rs | 40 ++- crates/field/src/binary_field.rs | 18 +- crates/field/src/lib.rs | 3 + crates/field/src/polyval.rs | 15 +- .../src/serialization/bytes.rs} | 8 +- crates/field/src/serialization/canonical.rs | 234 ++++++++++++++++++ crates/field/src/serialization/error.rs | 13 + crates/field/src/serialization/mod.rs | 9 + crates/hash/src/serialization.rs | 8 +- crates/macros/src/lib.rs | 227 ++++++++++++++++- crates/utils/Cargo.toml | 1 - crates/utils/src/lib.rs | 1 - 28 files changed, 569 insertions(+), 141 deletions(-) rename crates/{utils/src/serialization.rs => field/src/serialization/bytes.rs} (89%) create mode 100644 crates/field/src/serialization/canonical.rs create mode 100644 crates/field/src/serialization/error.rs create mode 100644 crates/field/src/serialization/mod.rs diff --git a/Cargo.toml b/Cargo.toml index adf326766..b6916ac36 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -112,7 +112,7 @@ seq-macro = "0.3.5" sha2 = "0.10.8" stackalloc = "1.2.1" subtle = "2.5.0" -syn = { version = "2.0.60", features = ["full"] } +syn = { version = "2.0.98", features = ["extra-traits"] } thiserror = "2.0.3" thread_local = "1.1.7" tiny-keccak = { version = "2.0.2", features = ["keccak"] } diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index 113cac7a9..e0033c614 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -10,6 +10,7 @@ workspace = true [dependencies] assert_matches.workspace = true auto_impl.workspace = true +binius_macros = { path = "../macros" } binius_field = { path = "../field" } binius_hal = { path = "../hal" } binius_hash = { path = "../hash" } diff --git a/crates/core/src/constraint_system/channel.rs b/crates/core/src/constraint_system/channel.rs index 93a28cb09..e73d83067 100644 --- a/crates/core/src/constraint_system/channel.rs +++ b/crates/core/src/constraint_system/channel.rs @@ -52,10 +52,10 @@ use std::collections::HashMap; use binius_field::{as_packed_field::PackScalar, underlier::UnderlierType, TowerField}; -use bytes::BufMut; +use binius_macros::{DeserializeCanonical, SerializeCanonical}; use super::error::{Error, VerificationError}; -use crate::{oracle::OracleId, transcript::TranscriptWriter, witness::MultilinearExtensionIndex}; +use crate::{oracle::OracleId, witness::MultilinearExtensionIndex}; pub type ChannelId = usize; @@ -68,7 +68,7 @@ pub struct Flush { pub multiplicity: u64, } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, SerializeCanonical, DeserializeCanonical)] pub struct Boundary { pub values: Vec, pub channel_id: ChannelId, @@ -76,7 +76,7 @@ pub struct Boundary { pub multiplicity: u64, } -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, SerializeCanonical, DeserializeCanonical)] pub enum FlushDirection { Push, Pull, @@ -220,26 +220,6 @@ impl Channel { } } -impl Boundary { - pub fn write_to(&self, writer: &mut TranscriptWriter) { - writer.buffer().put_u64(self.values.len() as u64); - writer.write_slice( - &self - .values - .iter() - .copied() - .map(F::Canonical::from) - .collect::>(), - ); - writer.buffer().put_u64(self.channel_id as u64); - writer.buffer().put_u64(self.multiplicity); - writer.buffer().put_u64(match self.direction { - FlushDirection::Pull => 0, - FlushDirection::Push => 1, - }); - } -} - #[cfg(test)] mod tests { use binius_field::BinaryField64b; diff --git a/crates/core/src/constraint_system/prove.rs b/crates/core/src/constraint_system/prove.rs index ef167b671..2a8be85fa 100644 --- a/crates/core/src/constraint_system/prove.rs +++ b/crates/core/src/constraint_system/prove.rs @@ -104,12 +104,7 @@ where let fast_domain_factory = IsomorphicEvaluationDomainFactory::>::default(); let mut transcript = ProverTranscript::::new(); - { - let mut observer = transcript.observe(); - for boundary in boundaries { - boundary.write_to(&mut observer); - } - } + transcript.observe().write_slice(boundaries); let ConstraintSystem { mut oracles, diff --git a/crates/core/src/constraint_system/verify.rs b/crates/core/src/constraint_system/verify.rs index 50c528b72..57d7822f6 100644 --- a/crates/core/src/constraint_system/verify.rs +++ b/crates/core/src/constraint_system/verify.rs @@ -75,12 +75,7 @@ where let Proof { transcript } = proof; let mut transcript = VerifierTranscript::::new(transcript); - { - let mut observer = transcript.observe(); - for boundary in boundaries { - boundary.write_to(&mut observer); - } - } + transcript.observe().write_slice(boundaries); let merkle_scheme = BinaryMerkleTreeScheme::<_, Hash, _>::new(Compress::default()); let (commit_meta, oracle_to_commit_index) = piop::make_oracle_commit_meta(&oracles)?; diff --git a/crates/core/src/merkle_tree/binary_merkle_tree.rs b/crates/core/src/merkle_tree/binary_merkle_tree.rs index f15d689be..7cde01b8d 100644 --- a/crates/core/src/merkle_tree/binary_merkle_tree.rs +++ b/crates/core/src/merkle_tree/binary_merkle_tree.rs @@ -2,7 +2,7 @@ use std::{array, fmt::Debug, mem::MaybeUninit}; -use binius_field::{serialize_canonical, TowerField}; +use binius_field::{SerializeCanonical, TowerField}; use binius_hash::{HashBuffer, PseudoCompressionFunction}; use binius_maybe_rayon::{prelude::*, slice::ParallelSlice}; use binius_utils::{bail, checked_arithmetics::log2_strict_usize}; @@ -210,7 +210,7 @@ where { let mut hash_buffer = HashBuffer::new(hasher); for elem in elems { - serialize_canonical(elem, &mut hash_buffer) + SerializeCanonical::serialize_canonical(&elem, &mut hash_buffer) .expect("HashBuffer has infinite capacity"); } } diff --git a/crates/core/src/merkle_tree/scheme.rs b/crates/core/src/merkle_tree/scheme.rs index c29940cf1..3d27cc32c 100644 --- a/crates/core/src/merkle_tree/scheme.rs +++ b/crates/core/src/merkle_tree/scheme.rs @@ -2,7 +2,7 @@ use std::{array, fmt::Debug, marker::PhantomData}; -use binius_field::{serialize_canonical, TowerField}; +use binius_field::{SerializeCanonical, TowerField}; use binius_hash::{HashBuffer, PseudoCompressionFunction}; use binius_utils::{ bail, @@ -178,8 +178,9 @@ where let mut hasher = H::new(); { let mut buffer = HashBuffer::new(&mut hasher); - for &elem in elems { - serialize_canonical(elem, &mut buffer).expect("HashBuffer has infinite capacity"); + for elem in elems { + SerializeCanonical::serialize_canonical(elem, &mut buffer) + .expect("HashBuffer has infinite capacity"); } } hasher.finalize() diff --git a/crates/core/src/piop/prove.rs b/crates/core/src/piop/prove.rs index 25ea62296..49a24066a 100644 --- a/crates/core/src/piop/prove.rs +++ b/crates/core/src/piop/prove.rs @@ -2,7 +2,7 @@ use binius_field::{ packed::set_packed_slice, BinaryField, ExtensionField, Field, PackedExtension, PackedField, - PackedFieldIndexable, TowerField, + PackedFieldIndexable, SerializeCanonical, TowerField, }; use binius_hal::ComputationBackend; use binius_math::{ @@ -10,7 +10,7 @@ use binius_math::{ }; use binius_maybe_rayon::{iter::IntoParallelIterator, prelude::*}; use binius_ntt::{NTTOptions, ThreadingSettings}; -use binius_utils::{bail, serialization::SerializeBytes, sorting::is_sorted_ascending}; +use binius_utils::{bail, sorting::is_sorted_ascending}; use either::Either; use itertools::{chain, Itertools}; @@ -175,7 +175,7 @@ where + PackedExtension, M: MultilinearPoly

+ Send + Sync, DomainFactory: EvaluationDomainFactory, - MTScheme: MerkleTreeScheme, + MTScheme: MerkleTreeScheme, MTProver: MerkleTreeProver, Challenger_: Challenger, Backend: ComputationBackend, @@ -254,7 +254,7 @@ where F: TowerField + ExtensionField, FEncode: BinaryField, P: PackedFieldIndexable + PackedExtension, - MTScheme: MerkleTreeScheme, + MTScheme: MerkleTreeScheme, MTProver: MerkleTreeProver, Challenger_: Challenger, { diff --git a/crates/core/src/piop/tests.rs b/crates/core/src/piop/tests.rs index 69f8a0548..1c9377a82 100644 --- a/crates/core/src/piop/tests.rs +++ b/crates/core/src/piop/tests.rs @@ -3,15 +3,15 @@ use std::iter::repeat_with; use binius_field::{ - BinaryField, BinaryField16b, BinaryField8b, ExtensionField, Field, PackedBinaryField2x128b, - PackedExtension, PackedField, PackedFieldIndexable, TowerField, + BinaryField, BinaryField16b, BinaryField8b, DeserializeCanonical, ExtensionField, Field, + PackedBinaryField2x128b, PackedExtension, PackedField, PackedFieldIndexable, + SerializeCanonical, TowerField, }; use binius_hal::make_portable_backend; use binius_hash::compress::Groestl256ByteCompression; use binius_math::{ DefaultEvaluationDomainFactory, MLEDirectAdapter, MultilinearExtension, MultilinearPoly, }; -use binius_utils::serialization::{DeserializeBytes, SerializeBytes}; use groestl_crypto::Groestl256; use rand::{rngs::StdRng, Rng, SeedableRng}; @@ -111,7 +111,7 @@ fn commit_prove_verify( + PackedExtension + PackedExtension + PackedExtension, - MTScheme: MerkleTreeScheme, + MTScheme: MerkleTreeScheme, { let merkle_scheme = merkle_prover.scheme(); diff --git a/crates/core/src/piop/verify.rs b/crates/core/src/piop/verify.rs index c7618538d..8761a0a0b 100644 --- a/crates/core/src/piop/verify.rs +++ b/crates/core/src/piop/verify.rs @@ -2,10 +2,10 @@ use std::{borrow::Borrow, cmp::Ordering, iter, ops::Range}; -use binius_field::{BinaryField, ExtensionField, Field, TowerField}; +use binius_field::{BinaryField, DeserializeCanonical, ExtensionField, Field, TowerField}; use binius_math::evaluate_piecewise_multilinear; use binius_ntt::NTTOptions; -use binius_utils::{bail, serialization::DeserializeBytes}; +use binius_utils::bail; use getset::CopyGetters; use tracing::instrument; @@ -291,7 +291,7 @@ where F: TowerField + ExtensionField, FEncode: BinaryField, Challenger_: Challenger, - MTScheme: MerkleTreeScheme, + MTScheme: MerkleTreeScheme, { // Map of n_vars to sumcheck claim descriptions let sumcheck_claim_descs = make_sumcheck_claim_descs( @@ -412,7 +412,7 @@ where F: TowerField + ExtensionField, FEncode: BinaryField, Challenger_: Challenger, - MTScheme: MerkleTreeScheme, + MTScheme: MerkleTreeScheme, { let mut arities_iter = fri_params.fold_arities().iter(); let mut fri_commitments = Vec::with_capacity(fri_params.n_oracles()); diff --git a/crates/core/src/protocols/fri/prove.rs b/crates/core/src/protocols/fri/prove.rs index ba07ebb72..287eef45f 100644 --- a/crates/core/src/protocols/fri/prove.rs +++ b/crates/core/src/protocols/fri/prove.rs @@ -1,9 +1,11 @@ // Copyright 2024-2025 Irreducible Inc. -use binius_field::{BinaryField, ExtensionField, PackedExtension, PackedField, TowerField}; +use binius_field::{ + BinaryField, ExtensionField, PackedExtension, PackedField, SerializeCanonical, TowerField, +}; use binius_hal::{make_portable_backend, ComputationBackend}; use binius_maybe_rayon::prelude::*; -use binius_utils::{bail, serialization::SerializeBytes}; +use binius_utils::bail; use bytemuck::zeroed_vec; use bytes::BufMut; use itertools::izip; @@ -285,7 +287,7 @@ where F: TowerField + ExtensionField, FA: BinaryField, MerkleProver: MerkleTreeProver, - VCS: MerkleTreeScheme, + VCS: MerkleTreeScheme, { /// Constructs a new folder. pub fn new( diff --git a/crates/core/src/protocols/fri/verify.rs b/crates/core/src/protocols/fri/verify.rs index 0abc46044..69c22b88d 100644 --- a/crates/core/src/protocols/fri/verify.rs +++ b/crates/core/src/protocols/fri/verify.rs @@ -2,9 +2,9 @@ use std::iter; -use binius_field::{BinaryField, ExtensionField, TowerField}; +use binius_field::{BinaryField, DeserializeCanonical, ExtensionField, TowerField}; use binius_hal::{make_portable_backend, ComputationBackend}; -use binius_utils::{bail, serialization::DeserializeBytes}; +use binius_utils::bail; use bytes::Buf; use itertools::izip; use tracing::instrument; @@ -44,7 +44,7 @@ impl<'a, F, FA, VCS> FRIVerifier<'a, F, FA, VCS> where F: TowerField + ExtensionField, FA: BinaryField, - VCS: MerkleTreeScheme, + VCS: MerkleTreeScheme, { #[allow(clippy::too_many_arguments)] pub fn new( diff --git a/crates/core/src/ring_switch/tests.rs b/crates/core/src/ring_switch/tests.rs index be7d3fafc..b8c361cff 100644 --- a/crates/core/src/ring_switch/tests.rs +++ b/crates/core/src/ring_switch/tests.rs @@ -6,7 +6,8 @@ use binius_field::{ arch::OptimalUnderlier128b, as_packed_field::{PackScalar, PackedType}, underlier::UnderlierType, - ExtensionField, Field, PackedField, PackedFieldIndexable, TowerField, + DeserializeCanonical, ExtensionField, Field, PackedField, PackedFieldIndexable, + SerializeCanonical, TowerField, }; use binius_hal::make_portable_backend; use binius_hash::compress::Groestl256ByteCompression; @@ -14,7 +15,6 @@ use binius_math::{ DefaultEvaluationDomainFactory, MLEEmbeddingAdapter, MultilinearExtension, MultilinearPoly, MultilinearQuery, }; -use binius_utils::serialization::{DeserializeBytes, SerializeBytes}; use groestl_crypto::Groestl256; use rand::prelude::*; @@ -269,7 +269,7 @@ fn commit_prove_verify_piop( Tower: TowerFamily, PackedType>: PackedFieldIndexable, FExt: PackedTop, - MTScheme: MerkleTreeScheme, Digest: SerializeBytes + DeserializeBytes>, + MTScheme: MerkleTreeScheme, Digest: SerializeCanonical + DeserializeCanonical>, MTProver: MerkleTreeProver, Scheme = MTScheme>, { let mut rng = StdRng::seed_from_u64(0); diff --git a/crates/core/src/transcript/error.rs b/crates/core/src/transcript/error.rs index 68a5eefd7..97b6754fd 100644 --- a/crates/core/src/transcript/error.rs +++ b/crates/core/src/transcript/error.rs @@ -1,6 +1,6 @@ // Copyright 2024-2025 Irreducible Inc. -use binius_utils::serialization::Error as SerializationError; +use binius_field::serialization::Error as SerializationError; #[derive(Debug, thiserror::Error)] pub enum Error { diff --git a/crates/core/src/transcript/mod.rs b/crates/core/src/transcript/mod.rs index 17d6ec021..08089985f 100644 --- a/crates/core/src/transcript/mod.rs +++ b/crates/core/src/transcript/mod.rs @@ -16,8 +16,7 @@ mod error; use std::{iter::repeat_with, slice}; -use binius_field::{deserialize_canonical, serialize_canonical, PackedField, TowerField}; -use binius_utils::serialization::{DeserializeBytes, SerializeBytes}; +use binius_field::{DeserializeCanonical, PackedField, SerializeCanonical, TowerField}; use bytes::{buf::UninitSlice, Buf, BufMut, Bytes, BytesMut}; pub use error::Error; use tracing::warn; @@ -258,13 +257,13 @@ impl TranscriptReader<'_, B> { self.buffer } - pub fn read(&mut self) -> Result { - T::deserialize(self.buffer()).map_err(Into::into) + pub fn read(&mut self) -> Result { + T::deserialize_canonical(self.buffer()).map_err(Into::into) } - pub fn read_vec(&mut self, n: usize) -> Result, Error> { + pub fn read_vec(&mut self, n: usize) -> Result, Error> { let mut buffer = self.buffer(); - repeat_with(move || T::deserialize(&mut buffer).map_err(Into::into)) + repeat_with(move || T::deserialize_canonical(&mut buffer).map_err(Into::into)) .take(n) .collect() } @@ -287,7 +286,7 @@ impl TranscriptReader<'_, B> { pub fn read_scalar_slice_into(&mut self, buf: &mut [F]) -> Result<(), Error> { let mut buffer = self.buffer(); for elem in buf { - *elem = deserialize_canonical(&mut buffer)?; + *elem = DeserializeCanonical::deserialize_canonical(&mut buffer)?; } Ok(()) } @@ -333,17 +332,19 @@ impl TranscriptWriter<'_, B> { self.buffer } - pub fn write(&mut self, value: &T) { + pub fn write(&mut self, value: &T) { self.proof_size_event_wrapper(|buffer| { - value.serialize(buffer).expect("TODO: propagate error"); + value + .serialize_canonical(buffer) + .expect("TODO: propagate error"); }); } - pub fn write_slice(&mut self, values: &[T]) { + pub fn write_slice(&mut self, values: &[T]) { self.proof_size_event_wrapper(|buffer| { for value in values { value - .serialize(&mut *buffer) + .serialize_canonical(&mut *buffer) .expect("TODO: propagate error"); } }); @@ -362,7 +363,8 @@ impl TranscriptWriter<'_, B> { pub fn write_scalar_slice(&mut self, elems: &[F]) { self.proof_size_event_wrapper(|buffer| { for elem in elems { - serialize_canonical(*elem, &mut *buffer).expect("TODO: propagate error"); + SerializeCanonical::serialize_canonical(elem, &mut *buffer) + .expect("TODO: propagate error"); } }); } @@ -400,7 +402,7 @@ where Challenger_: Challenger, { fn sample(&mut self) -> F { - deserialize_canonical(self.combined.challenger.sampler()) + DeserializeCanonical::deserialize_canonical(self.combined.challenger.sampler()) .expect("challenger has infinite buffer") } } @@ -411,7 +413,7 @@ where Challenger_: Challenger, { fn sample(&mut self) -> F { - deserialize_canonical(self.combined.challenger.sampler()) + DeserializeCanonical::deserialize_canonical(self.combined.challenger.sampler()) .expect("challenger has infinite buffer") } } diff --git a/crates/field/Cargo.toml b/crates/field/Cargo.toml index 36de13de8..7b0a1b421 100644 --- a/crates/field/Cargo.toml +++ b/crates/field/Cargo.toml @@ -14,6 +14,7 @@ bytemuck.workspace = true bytes.workspace = true cfg-if.workspace = true derive_more.workspace = true +generic-array.workspace = true rand.workspace = true seq-macro.workspace = true subtle.workspace = true diff --git a/crates/field/src/aes_field.rs b/crates/field/src/aes_field.rs index ab650f4cd..61e261de0 100644 --- a/crates/field/src/aes_field.rs +++ b/crates/field/src/aes_field.rs @@ -288,10 +288,10 @@ mod tests { use super::*; use crate::{ - binary_field::tests::is_binary_field_valid_generator, deserialize_canonical, - serialize_canonical, underlier::WithUnderlier, PackedAESBinaryField16x32b, - PackedAESBinaryField4x32b, PackedAESBinaryField8x32b, PackedBinaryField16x32b, - PackedBinaryField4x32b, PackedBinaryField8x32b, + binary_field::tests::is_binary_field_valid_generator, underlier::WithUnderlier, + DeserializeCanonical, PackedAESBinaryField16x32b, PackedAESBinaryField4x32b, + PackedAESBinaryField8x32b, PackedBinaryField16x32b, PackedBinaryField4x32b, + PackedBinaryField8x32b, SerializeCanonical, }; fn check_square(f: impl Field) { @@ -590,28 +590,22 @@ mod tests { let aes64 = ::random(&mut rng); let aes128 = ::random(&mut rng); - serialize_canonical(aes8, &mut buffer).unwrap(); - serialize_canonical(aes16, &mut buffer).unwrap(); - serialize_canonical(aes32, &mut buffer).unwrap(); - serialize_canonical(aes64, &mut buffer).unwrap(); - serialize_canonical(aes128, &mut buffer).unwrap(); + SerializeCanonical::serialize_canonical(&aes8, &mut buffer).unwrap(); + SerializeCanonical::serialize_canonical(&aes16, &mut buffer).unwrap(); + SerializeCanonical::serialize_canonical(&aes32, &mut buffer).unwrap(); + SerializeCanonical::serialize_canonical(&aes64, &mut buffer).unwrap(); + SerializeCanonical::serialize_canonical(&aes128, &mut buffer).unwrap(); - serialize_canonical(aes128, &mut buffer).unwrap(); + SerializeCanonical::serialize_canonical(&aes128, &mut buffer).unwrap(); let mut read_buffer = buffer.freeze(); - assert_eq!(deserialize_canonical::(&mut read_buffer).unwrap(), aes8); - assert_eq!(deserialize_canonical::(&mut read_buffer).unwrap(), aes16); - assert_eq!(deserialize_canonical::(&mut read_buffer).unwrap(), aes32); - assert_eq!(deserialize_canonical::(&mut read_buffer).unwrap(), aes64); - assert_eq!( - deserialize_canonical::(&mut read_buffer).unwrap(), - aes128 - ); - - assert_eq!( - deserialize_canonical::(&mut read_buffer).unwrap(), - aes128.into() - ) + assert_eq!(AESTowerField8b::deserialize_canonical(&mut read_buffer).unwrap(), aes8); + assert_eq!(AESTowerField16b::deserialize_canonical(&mut read_buffer).unwrap(), aes16); + assert_eq!(AESTowerField32b::deserialize_canonical(&mut read_buffer).unwrap(), aes32); + assert_eq!(AESTowerField64b::deserialize_canonical(&mut read_buffer).unwrap(), aes64); + assert_eq!(AESTowerField128b::deserialize_canonical(&mut read_buffer).unwrap(), aes128); + + assert_eq!(BinaryField128b::deserialize_canonical(&mut read_buffer).unwrap(), aes128.into()) } } diff --git a/crates/field/src/binary_field.rs b/crates/field/src/binary_field.rs index f38a43739..0707d04b6 100644 --- a/crates/field/src/binary_field.rs +++ b/crates/field/src/binary_field.rs @@ -7,7 +7,6 @@ use std::{ ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}, }; -use binius_utils::serialization::{DeserializeBytes, Error as SerializationError, SerializeBytes}; use bytemuck::{Pod, Zeroable}; use bytes::{Buf, BufMut}; use rand::RngCore; @@ -17,6 +16,7 @@ use super::{ binary_field_arithmetic::TowerFieldArithmetic, error::Error, extension::ExtensionField, }; use crate::{ + serialization::{DeserializeBytes, Error as SerializationError, SerializeBytes}, underlier::{SmallU, U1, U2, U4}, Field, }; @@ -766,22 +766,6 @@ serialize_deserialize!(BinaryField32b, u32); serialize_deserialize!(BinaryField64b, u64); serialize_deserialize!(BinaryField128b, u128); -/// Serializes a [`TowerField`] element to a byte buffer with a canonical encoding. -pub fn serialize_canonical( - elem: F, - mut writer: W, -) -> Result<(), SerializationError> { - F::Canonical::from(elem).serialize(&mut writer) -} - -/// Deserializes a [`TowerField`] element from a byte buffer with a canonical encoding. -pub fn deserialize_canonical( - mut reader: R, -) -> Result { - let as_canonical = F::Canonical::deserialize(&mut reader)?; - Ok(F::from(as_canonical)) -} - impl From for Choice { fn from(val: BinaryField1b) -> Self { Self::from(val.val().val()) diff --git a/crates/field/src/lib.rs b/crates/field/src/lib.rs index 76414c038..f09a44b53 100644 --- a/crates/field/src/lib.rs +++ b/crates/field/src/lib.rs @@ -33,6 +33,7 @@ pub mod packed_extension; pub mod packed_extension_ops; mod packed_polyval; pub mod polyval; +pub mod serialization; #[cfg(test)] mod tests; pub mod tower_levels; @@ -44,6 +45,7 @@ pub mod util; pub use aes_field::*; pub use arch::byte_sliced::*; pub use binary_field::*; +pub use bytes; pub use error::*; pub use extension::*; pub use field::Field; @@ -54,4 +56,5 @@ pub use packed_extension::*; pub use packed_extension_ops::*; pub use packed_polyval::*; pub use polyval::*; +pub use serialization::{DeserializeCanonical, SerializeCanonical}; pub use transpose::{square_transpose, transpose_scalars, Error as TransposeError}; diff --git a/crates/field/src/polyval.rs b/crates/field/src/polyval.rs index 7b535fbbf..0337ddede 100644 --- a/crates/field/src/polyval.rs +++ b/crates/field/src/polyval.rs @@ -1035,11 +1035,10 @@ mod tests { packed_polyval_512::PackedBinaryPolyval4x128b, }, binary_field::tests::is_binary_field_valid_generator, - deserialize_canonical, linear_transformation::PackedTransformationFactory, - serialize_canonical, AESTowerField128b, PackedAESBinaryField1x128b, + AESTowerField128b, DeserializeCanonical, PackedAESBinaryField1x128b, PackedAESBinaryField2x128b, PackedAESBinaryField4x128b, PackedBinaryField1x128b, - PackedBinaryField2x128b, PackedBinaryField4x128b, PackedField, + PackedBinaryField2x128b, PackedBinaryField4x128b, PackedField, SerializeCanonical, }; #[test] @@ -1188,19 +1187,17 @@ mod tests { let b128_poly1 = ::random(&mut rng); let b128_poly2 = ::random(&mut rng); - serialize_canonical(b128_poly1, &mut buffer).unwrap(); - serialize_canonical(b128_poly2, &mut buffer).unwrap(); + SerializeCanonical::serialize_canonical(&b128_poly1, &mut buffer).unwrap(); + SerializeCanonical::serialize_canonical(&b128_poly2, &mut buffer).unwrap(); let mut read_buffer = buffer.freeze(); assert_eq!( - deserialize_canonical::(&mut read_buffer).unwrap(), + BinaryField128bPolyval::deserialize_canonical(&mut read_buffer).unwrap(), b128_poly1 ); assert_eq!( - BinaryField128bPolyval::from( - deserialize_canonical::(&mut read_buffer).unwrap() - ), + BinaryField128bPolyval::deserialize_canonical(&mut read_buffer).unwrap(), b128_poly2 ); } diff --git a/crates/utils/src/serialization.rs b/crates/field/src/serialization/bytes.rs similarity index 89% rename from crates/utils/src/serialization.rs rename to crates/field/src/serialization/bytes.rs index 434befccb..59335f69d 100644 --- a/crates/utils/src/serialization.rs +++ b/crates/field/src/serialization/bytes.rs @@ -3,13 +3,7 @@ use bytes::{Buf, BufMut}; use generic_array::{ArrayLength, GenericArray}; -#[derive(Clone, thiserror::Error, Debug)] -pub enum Error { - #[error("Write buffer is full")] - WriteBufferFull, - #[error("Not enough data in read buffer to deserialize")] - NotEnoughBytes, -} +use super::Error; /// Represents type that can be serialized to a byte buffer. pub trait SerializeBytes { diff --git a/crates/field/src/serialization/canonical.rs b/crates/field/src/serialization/canonical.rs new file mode 100644 index 000000000..394a7c409 --- /dev/null +++ b/crates/field/src/serialization/canonical.rs @@ -0,0 +1,234 @@ +// Copyright 2025 Irreducible Inc. + +use bytes::{Buf, BufMut}; +use generic_array::{ArrayLength, GenericArray}; + +use super::{DeserializeBytes, Error, SerializeBytes}; +use crate::TowerField; + +/// Serialization where [`TowerField`] elements are written with canonical encoding. +pub trait SerializeCanonical { + fn serialize_canonical(&self, write_buf: impl BufMut) -> Result<(), Error>; +} + +/// Deserialization where [`TowerField`] elements are read with a canonical encoding. +pub trait DeserializeCanonical { + fn deserialize_canonical(read_buf: impl Buf) -> Result + where + Self: Sized; +} + +impl SerializeCanonical for F { + fn serialize_canonical(&self, mut write_buf: impl BufMut) -> Result<(), Error> { + SerializeBytes::serialize(&F::Canonical::from(*self), &mut write_buf) + } +} + +impl DeserializeCanonical for F { + fn deserialize_canonical(read_buf: impl Buf) -> Result + where + Self: Sized, + { + let canonical: F::Canonical = DeserializeBytes::deserialize(read_buf)?; + Ok(F::from(canonical)) + } +} + +impl SerializeCanonical for usize { + fn serialize_canonical(&self, mut write_buf: impl BufMut) -> Result<(), Error> { + SerializeCanonical::serialize_canonical(&(*self as u64), &mut write_buf) + } +} + +impl DeserializeCanonical for usize { + fn deserialize_canonical(mut read_buf: impl Buf) -> Result + where + Self: Sized, + { + let value: u64 = DeserializeCanonical::deserialize_canonical(&mut read_buf)?; + Ok(value as Self) + } +} + +impl SerializeCanonical for u128 { + fn serialize_canonical(&self, mut write_buf: impl BufMut) -> Result<(), Error> { + assert_enough_space_for(&write_buf, std::mem::size_of::())?; + write_buf.put_u128(*self); + Ok(()) + } +} + +impl DeserializeCanonical for u128 { + fn deserialize_canonical(mut read_buf: impl Buf) -> Result + where + Self: Sized, + { + assert_enough_data_for(&read_buf, std::mem::size_of::())?; + Ok(read_buf.get_u128()) + } +} + +impl SerializeCanonical for u64 { + fn serialize_canonical(&self, mut write_buf: impl BufMut) -> Result<(), Error> { + assert_enough_space_for(&write_buf, std::mem::size_of::())?; + write_buf.put_u64(*self); + Ok(()) + } +} + +impl DeserializeCanonical for u64 { + fn deserialize_canonical(mut read_buf: impl Buf) -> Result + where + Self: Sized, + { + assert_enough_data_for(&read_buf, std::mem::size_of::())?; + Ok(read_buf.get_u64()) + } +} + +impl SerializeCanonical for u32 { + fn serialize_canonical(&self, mut write_buf: impl BufMut) -> Result<(), Error> { + assert_enough_space_for(&write_buf, std::mem::size_of::())?; + write_buf.put_u32(*self); + Ok(()) + } +} + +impl DeserializeCanonical for u32 { + fn deserialize_canonical(mut read_buf: impl Buf) -> Result + where + Self: Sized, + { + assert_enough_data_for(&read_buf, std::mem::size_of::())?; + Ok(read_buf.get_u32()) + } +} + +impl SerializeCanonical for u16 { + fn serialize_canonical(&self, mut write_buf: impl BufMut) -> Result<(), Error> { + assert_enough_space_for(&write_buf, std::mem::size_of::())?; + write_buf.put_u16(*self); + Ok(()) + } +} + +impl DeserializeCanonical for u16 { + fn deserialize_canonical(mut read_buf: impl Buf) -> Result + where + Self: Sized, + { + assert_enough_data_for(&read_buf, std::mem::size_of::())?; + Ok(read_buf.get_u16()) + } +} + +impl SerializeCanonical for u8 { + fn serialize_canonical(&self, mut write_buf: impl BufMut) -> Result<(), Error> { + assert_enough_space_for(&write_buf, std::mem::size_of::())?; + write_buf.put_u8(*self); + Ok(()) + } +} + +impl DeserializeCanonical for u8 { + fn deserialize_canonical(mut read_buf: impl Buf) -> Result + where + Self: Sized, + { + assert_enough_data_for(&read_buf, std::mem::size_of::())?; + Ok(read_buf.get_u8()) + } +} + +impl SerializeCanonical for std::marker::PhantomData { + fn serialize_canonical(&self, _write_buf: impl BufMut) -> Result<(), Error> { + Ok(()) + } +} + +impl DeserializeCanonical for std::marker::PhantomData { + fn deserialize_canonical(_read_buf: impl Buf) -> Result + where + Self: Sized, + { + Ok(Self) + } +} + +impl SerializeCanonical for &str { + fn serialize_canonical(&self, mut write_buf: impl BufMut) -> Result<(), Error> { + let bytes = self.as_bytes(); + SerializeCanonical::serialize_canonical(&bytes.len(), &mut write_buf)?; + assert_enough_space_for(&write_buf, bytes.len())?; + write_buf.put_slice(bytes); + Ok(()) + } +} + +impl SerializeCanonical for String { + fn serialize_canonical(&self, mut write_buf: impl BufMut) -> Result<(), Error> { + SerializeCanonical::serialize_canonical(&self.as_str(), &mut write_buf) + } +} + +impl DeserializeCanonical for String { + fn deserialize_canonical(mut read_buf: impl Buf) -> Result + where + Self: Sized, + { + let len = DeserializeCanonical::deserialize_canonical(&mut read_buf)?; + assert_enough_data_for(&read_buf, len)?; + Ok(Self::from_utf8(read_buf.copy_to_bytes(len).to_vec())?) + } +} + +impl SerializeCanonical for Vec { + fn serialize_canonical(&self, mut write_buf: impl BufMut) -> Result<(), Error> { + SerializeCanonical::serialize_canonical(&self.len(), &mut write_buf)?; + self.iter() + .try_for_each(|item| SerializeCanonical::serialize_canonical(item, &mut write_buf)) + } +} + +impl DeserializeCanonical for Vec { + fn deserialize_canonical(mut read_buf: impl Buf) -> Result + where + Self: Sized, + { + let len: usize = DeserializeCanonical::deserialize_canonical(&mut read_buf)?; + (0..len) + .map(|_| DeserializeCanonical::deserialize_canonical(&mut read_buf)) + .collect() + } +} + +impl> SerializeCanonical for GenericArray { + fn serialize_canonical(&self, mut write_buf: impl BufMut) -> Result<(), Error> { + assert_enough_space_for(&write_buf, N::USIZE)?; + write_buf.put_slice(self); + Ok(()) + } +} + +impl> DeserializeCanonical for GenericArray { + fn deserialize_canonical(mut read_buf: impl Buf) -> Result { + assert_enough_data_for(&read_buf, N::USIZE)?; + let mut ret = Self::default(); + read_buf.copy_to_slice(&mut ret); + Ok(ret) + } +} + +fn assert_enough_space_for(write_buf: &impl BufMut, size: usize) -> Result<(), Error> { + if write_buf.remaining_mut() < size { + return Err(Error::WriteBufferFull); + } + Ok(()) +} + +fn assert_enough_data_for(read_buf: &impl Buf, size: usize) -> Result<(), Error> { + if read_buf.remaining() < size { + return Err(Error::NotEnoughBytes); + } + Ok(()) +} diff --git a/crates/field/src/serialization/error.rs b/crates/field/src/serialization/error.rs new file mode 100644 index 000000000..bbab7b99c --- /dev/null +++ b/crates/field/src/serialization/error.rs @@ -0,0 +1,13 @@ +// Copyright 2024-2025 Irreducible Inc. + +#[derive(Clone, thiserror::Error, Debug)] +pub enum Error { + #[error("Write buffer is full")] + WriteBufferFull, + #[error("Not enough data in read buffer to deserialize")] + NotEnoughBytes, + #[error("Unknown enum variant index {name}::{index}")] + UnknownEnumVariant { name: &'static str, index: u8 }, + #[error("FromUtf8Error: {0}")] + FromUtf8Error(#[from] std::string::FromUtf8Error), +} diff --git a/crates/field/src/serialization/mod.rs b/crates/field/src/serialization/mod.rs new file mode 100644 index 000000000..99e080769 --- /dev/null +++ b/crates/field/src/serialization/mod.rs @@ -0,0 +1,9 @@ +// Copyright 2024-2025 Irreducible Inc. + +mod bytes; +mod canonical; +mod error; + +pub use bytes::{DeserializeBytes, SerializeBytes}; +pub use canonical::{DeserializeCanonical, SerializeCanonical}; +pub use error::Error; diff --git a/crates/hash/src/serialization.rs b/crates/hash/src/serialization.rs index ff64c9756..d2a3699e6 100644 --- a/crates/hash/src/serialization.rs +++ b/crates/hash/src/serialization.rs @@ -2,7 +2,7 @@ use std::{borrow::Borrow, cmp::min}; -use binius_utils::serialization::SerializeBytes; +use binius_field::SerializeCanonical; use bytes::{buf::UninitSlice, BufMut}; use digest::{ core_api::{Block, BlockSizeUser}, @@ -11,7 +11,7 @@ use digest::{ /// Adapter that wraps a [`Digest`] references and exposes the [`BufMut`] interface. /// -/// This adapter is useful so that structs that implement [`SerializeBytes`] can be serialized +/// This adapter is useful so that structs that implement [`SerializeCanonical`] can be serialized /// directly to a hasher. #[derive(Debug)] pub struct HashBuffer<'a, D: Digest + BlockSizeUser> { @@ -67,7 +67,7 @@ impl Drop for HashBuffer<'_, D> { /// Hashes a sequence of serializable items. pub fn hash_serialize(items: impl IntoIterator>) -> Output where - T: SerializeBytes, + T: SerializeCanonical, D: Digest + BlockSizeUser, { let mut hasher = D::new(); @@ -75,7 +75,7 @@ where let mut buffer = HashBuffer::new(&mut hasher); for item in items { item.borrow() - .serialize(&mut buffer) + .serialize_canonical(&mut buffer) .expect("HashBuffer has infinite capacity"); } } diff --git a/crates/macros/src/lib.rs b/crates/macros/src/lib.rs index c2d8f069e..f72cef222 100644 --- a/crates/macros/src/lib.rs +++ b/crates/macros/src/lib.rs @@ -9,7 +9,7 @@ use std::collections::BTreeSet; use proc_macro::TokenStream; use quote::{quote, ToTokens}; -use syn::{parse_macro_input, Data, DeriveInput, Fields}; +use syn::{parse_macro_input, parse_quote, spanned::Spanned, Data, DeriveInput, Fields}; use crate::{ arith_circuit_poly::ArithCircuitPolyItem, arith_expr::ArithExprItem, @@ -76,6 +76,231 @@ pub fn arith_circuit_poly(input: TokenStream) -> TokenStream { .into() } +/// Derives the trait binius_field::SerializeCanonical for a struct or enum +/// +/// See the DeserializeCanonical derive macro docs for examples/tests +#[proc_macro_derive(SerializeCanonical)] +pub fn derive_serialize_canonical(input: TokenStream) -> TokenStream { + let input: DeriveInput = parse_macro_input!(input); + let span = input.span(); + let name = input.ident; + let mut generics = input.generics.clone(); + generics.type_params_mut().for_each(|type_param| { + type_param + .bounds + .push(parse_quote!(binius_field::SerializeCanonical)) + }); + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + let body = match input.data { + Data::Union(_) => syn::Error::new(span, "Unions are not supported").into_compile_error(), + Data::Struct(data) => { + let fields = field_names(data.fields, None); + quote! { + #(binius_field::SerializeCanonical::serialize_canonical(&self.#fields, &mut write_buf)?;)* + } + } + Data::Enum(data) => { + let variants = data + .variants + .into_iter() + .enumerate() + .map(|(i, variant)| { + let variant_ident = &variant.ident; + let variant_index = i as u8; + let fields = field_names(variant.fields.clone(), Some("field_")); + let serialize_variant = quote! { + binius_field::SerializeCanonical::serialize_canonical(&#variant_index, &mut write_buf)?; + #(binius_field::SerializeCanonical::serialize_canonical(#fields, &mut write_buf)?;)* + }; + match variant.fields { + Fields::Named(_) => quote! { + Self::#variant_ident { #(#fields),* } => { + #serialize_variant + } + }, + Fields::Unnamed(_) => quote! { + Self::#variant_ident(#(#fields),*) => { + #serialize_variant + } + }, + Fields::Unit => quote! { + Self::#variant_ident => { + #serialize_variant + } + }, + } + }) + .collect::>(); + + quote! { + match self { + #(#variants)* + } + } + } + }; + quote! { + impl #impl_generics binius_field::SerializeCanonical for #name #ty_generics #where_clause { + fn serialize_canonical(&self, mut write_buf: impl binius_field::bytes::BufMut) -> Result<(), binius_field::serialization::Error> { + #body + Ok(()) + } + } + }.into() +} + +/// Derives the trait binius_field::DeserializeCanonical for a struct or enum +/// +/// ``` +/// use binius_field::{BinaryField128b, SerializeCanonical, DeserializeCanonical}; +/// use binius_macros::{SerializeCanonical, DeserializeCanonical}; +/// +/// #[derive(Debug, PartialEq, SerializeCanonical, DeserializeCanonical)] +/// enum MyEnum { +/// A(usize), +/// B { x: u32, y: u32 }, +/// C +/// } +/// +/// +/// let mut buf = vec![]; +/// let value = MyEnum::B { x: 42, y: 1337 }; +/// MyEnum::serialize_canonical(&value, &mut buf).unwrap(); +/// assert_eq!( +/// MyEnum::deserialize_canonical(buf.as_slice()).unwrap(), +/// value +/// ); +/// +/// +/// #[derive(Debug, PartialEq, SerializeCanonical, DeserializeCanonical)] +/// struct MyStruct { +/// data: Vec +/// } +/// +/// let mut buf = vec![]; +/// let value = MyStruct { +/// data: vec![BinaryField128b::new(1234), BinaryField128b::new(5678)] +/// }; +/// MyStruct::serialize_canonical(&value, &mut buf).unwrap(); +/// assert_eq!( +/// MyStruct::::deserialize_canonical(buf.as_slice()).unwrap(), +/// value +/// ); +/// ``` +#[proc_macro_derive(DeserializeCanonical)] +pub fn derive_deserialize_canonical(input: TokenStream) -> TokenStream { + let input: DeriveInput = parse_macro_input!(input); + let span = input.span(); + let name = input.ident; + let mut generics = input.generics.clone(); + generics.type_params_mut().for_each(|type_param| { + type_param + .bounds + .push(parse_quote!(binius_field::DeserializeCanonical)) + }); + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + let deserialize_value = quote! { + binius_field::DeserializeCanonical::deserialize_canonical(&mut read_buf)? + }; + let body = match input.data { + Data::Union(_) => syn::Error::new(span, "Unions are not supported").into_compile_error(), + Data::Struct(data) => { + let fields = field_names(data.fields, None); + quote! { + Ok(Self { + #(#fields: #deserialize_value,)* + }) + } + } + Data::Enum(data) => { + let variants = data + .variants + .into_iter() + .enumerate() + .map(|(i, variant)| { + let variant_ident = &variant.ident; + let variant_index: u8 = i as u8; + match variant.fields { + Fields::Named(fields) => { + let fields = fields + .named + .into_iter() + .map(|field| field.ident) + .map(|field_name| quote!(#field_name: #deserialize_value)) + .collect::>(); + + quote! { + #variant_index => Self::#variant_ident { #(#fields,)* } + } + } + Fields::Unnamed(fields) => { + let fields = fields + .unnamed + .into_iter() + .map(|_| quote!(#deserialize_value)) + .collect::>(); + + quote! { + #variant_index => Self::#variant_ident(#(#fields,)*) + } + } + Fields::Unit => quote! { + #variant_index => Self::#variant_ident + }, + } + }) + .collect::>(); + + let name = name.to_string(); + quote! { + let variant_index: u8 = #deserialize_value; + Ok(match variant_index { + #(#variants,)* + _ => { + return Err(binius_field::serialization::Error::UnknownEnumVariant { + name: #name, + index: variant_index + }) + } + }) + } + } + }; + quote! { + impl #impl_generics binius_field::DeserializeCanonical for #name #ty_generics #where_clause { + fn deserialize_canonical(mut read_buf: impl binius_field::bytes::Buf) -> Result + where + Self: Sized + { + #body + } + } + } + .into() +} + +fn field_names(fields: Fields, positional_prefix: Option<&str>) -> Vec { + match fields { + Fields::Named(fields) => fields + .named + .into_iter() + .map(|field| field.ident.into_token_stream()) + .collect(), + Fields::Unnamed(fields) => fields + .unnamed + .into_iter() + .enumerate() + .map(|(i, _)| match positional_prefix { + Some(prefix) => { + quote::format_ident!("{}{}", prefix, syn::Index::from(i)).into_token_stream() + } + None => syn::Index::from(i).into_token_stream(), + }) + .collect(), + Fields::Unit => vec![], + } +} + /// Implements `pub fn iter_oracles(&self) -> impl Iterator`. /// /// Detects and includes fields with type `OracleId`, `[OracleId; N]` diff --git a/crates/utils/Cargo.toml b/crates/utils/Cargo.toml index 0d04d2917..5bdc5ad42 100644 --- a/crates/utils/Cargo.toml +++ b/crates/utils/Cargo.toml @@ -9,7 +9,6 @@ workspace = true [dependencies] binius_maybe_rayon = { path = "../maybe_rayon", default-features = false } -bytes.workspace = true bytemuck = { workspace = true, features = ["extern_crate_alloc"] } cfg-if.workspace = true generic-array.workspace = true diff --git a/crates/utils/src/lib.rs b/crates/utils/src/lib.rs index 493fe6e20..70606d3f0 100644 --- a/crates/utils/src/lib.rs +++ b/crates/utils/src/lib.rs @@ -13,7 +13,6 @@ pub mod felts; pub mod graph; pub mod iter; pub mod rayon; -pub mod serialization; pub mod sorting; pub mod sparse_index; pub mod thread_local_mut; From ee410877a94d1022fa9f2518ba4ff82cfabb2938 Mon Sep 17 00:00:00 2001 From: Milos Backonja <35807060+milosbackonja@users.noreply.github.com> Date: Wed, 12 Feb 2025 13:21:48 +0100 Subject: [PATCH 17/50] [security]: Add CODEOWNERS file for GitHub --- CODEOWNERS | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 CODEOWNERS diff --git a/CODEOWNERS b/CODEOWNERS new file mode 100644 index 000000000..29e34188a --- /dev/null +++ b/CODEOWNERS @@ -0,0 +1,5 @@ +# Code owners for the entire repository +* @jimpo-ulvt @onesk + +# Code owners for the .github path +.github/* @IrreducibleOSS/Infrastructure From f54479f844bd5b8fdf026fce0356b3a41c5092ec Mon Sep 17 00:00:00 2001 From: Anex007 Date: Wed, 12 Feb 2025 07:45:53 -0500 Subject: [PATCH 18/50] [scripts] Added benchmarking script This adds the script to benchmark various of our examples, default sampling is set to 5 to reduce total time to benchmark. --- scripts/nightly_benchmarks.py | 273 ++++++++++++++++++++++++++++++++++ 1 file changed, 273 insertions(+) create mode 100755 scripts/nightly_benchmarks.py diff --git a/scripts/nightly_benchmarks.py b/scripts/nightly_benchmarks.py new file mode 100755 index 000000000..7aea6980d --- /dev/null +++ b/scripts/nightly_benchmarks.py @@ -0,0 +1,273 @@ +#!/usr/bin/python3 + +import argparse +import csv +import json +import os +import re +import subprocess +from typing import Union + +ENV_VARS = { + "RUSTFLAGS": "-C target-cpu=native", +} + +SAMPLE_SIZE = 5 + +KECCAKF_PERMS = 1 << 13 +GROESTLP_PERMS = 1 << 14 +VISION32B_PERMS = 1 << 14 +SHA256_PERMS = 1 << 14 +NUM_BINARY_OPS = 1 << 22 +NUM_MULS = 1 << 20 + +HASHER_TO_RUN = { + r"keccakf": { + "type": "hasher", + "display": r"Keccak-f", + "export": "keccakf-report.csv", + "args": ["keccakf_circuit", "--", "--n-permutations"], + "n_ops": KECCAKF_PERMS, + }, + "groestlp": { + "type": "hasher", + "display": r"Grøstl P", + "export": "groestl-report.csv", + "args": ["groestl_circuit", "--", "--n-permutations"], + "n_ops": GROESTLP_PERMS, + }, + "vision32b": { + "type": "hasher", + "display": r"Vision Mark-32", + "export": "vision32b-report.csv", + "args": ["vision32b_circuit", "--", "--n-permutations"], + "n_ops": VISION32B_PERMS, + }, + "sha256": { + "type": "hasher", + "display": "SHA-256", + "export": "sha256-report.csv", + "args": ["sha256_circuit", "--", "--n-compressions"], + "n_ops": SHA256_PERMS, + }, + "b32_mul": { + "type": "binary_ops", + "display": "BinaryField32b mul", + "export": "b32-mul-report.csv", + "args": ["b32_mul", "--", "--n-ops"], + "n_ops": NUM_MULS, + }, + "u32_add": { + "type": "binary_ops", + "display": "u32 add", + "export": "u32-add-report.csv", + "args": ["u32_add", "--", "--n-additions"], + "n_ops": NUM_BINARY_OPS, + }, + "u32_mul": { + "type": "binary_ops", + "display": "u32 mul", + "export": "u32-mul-report.csv", + "args": ["u32_mul", "--", "--n-muls"], + "n_ops": NUM_MULS, + }, + "xor": { + "type": "binary_ops", + "display": "Xor", + "export": "xor-report.csv", + "args": ["bitwise_ops", "--", "--op", "xor", "--n-u32-ops"], + "n_ops": NUM_BINARY_OPS, + }, + "and": { + "type": "binary_ops", + "display": "And", + "export": "and-report.csv", + "args": ["bitwise_ops", "--", "--op", "and", "--n-u32-ops"], + "n_ops": NUM_BINARY_OPS, + }, + "or": { + "type": "binary_ops", + "display": "Or", + "export": "or-report.csv", + "args": ["bitwise_ops", "--", "--op", "or", "--n-u32-ops"], + "n_ops": NUM_BINARY_OPS, + }, +} + +HASHER_BENCHMARKS = {} +BINARY_OPS_BENCHMARKS = {} + + +def run_benchmark(benchmark_args) -> tuple[bytes, bytes]: + command = ( + ["cargo", "run", "--release", "--example"] + + benchmark_args["args"] + + [f"{benchmark_args['n_ops']}"] + ) + env_vars_to_run = { + **os.environ, + **ENV_VARS, + "PROFILE_CSV_FILE": benchmark_args["export"], + } + process = subprocess.run( + command, env=env_vars_to_run, capture_output=True, check=True + ) + return process.stdout, process.stderr + + +def parse_csv_file(file_name) -> dict: + data = {} + with open(file_name) as file: + reader = csv.reader(file) + for row in reader: + if row[0] == "generating trace": + data.update({"trace_gen_time": int(row[2])}) + elif row[0] == "constraint_system::prove": + data.update({"proving_time": int(row[2])}) + elif row[0] == "constraint_system::verify": + data.update({"verification_time": int(row[2])}) + return data + + +KIB_TO_BYTES = 1024.0 +MIB_TO_BYTES = KIB_TO_BYTES * 1024.0 +GIB_TO_BYTES = MIB_TO_BYTES * 1024.0 +KB_TO_BYTES = 1000.0 +MB_TO_BYTES = KB_TO_BYTES * 1000.0 +GB_TO_BYTES = MB_TO_BYTES * 1000.0 + +SIZE_CONVERSIONS = { + "KiB": KIB_TO_BYTES, + "MiB": MIB_TO_BYTES, + "GiB": GIB_TO_BYTES, + " B": 1, + "KB": KB_TO_BYTES, + "MB": MB_TO_BYTES, + "GB": GB_TO_BYTES, +} + + +def parse_proof_size(proof_size: bytes) -> int: + proof_size = proof_size.decode("utf-8").strip() + for unit, factor in SIZE_CONVERSIONS.items(): + if proof_size.endswith(unit): + byte_len = float(proof_size[: -len(unit)]) * factor + break + else: + raise ValueError(f"Unknown proof size format: {proof_size}") + + # Convert to KiB + return int(byte_len / KIB_TO_BYTES) + + +def nano_to_milli(nano) -> float: + return float(nano) / 1000000.0 + + +def nano_to_seconds(nano) -> float: + return float(nano) / 1000000000.0 + + +def run_and_parse_benchmark(benchmark, benchmark_args) -> tuple[dict, int]: + data = {} + stdout = None + print(f"Running benchmark: {benchmark} with {SAMPLE_SIZE} samples") + for _ in range(SAMPLE_SIZE): + stdout, _stderr = run_benchmark(benchmark_args) + result = parse_csv_file(benchmark_args["export"]) + # Parse the csv file + if len(result.keys()) != 3: + print(f"Failed to parse csv file for benchmark: {benchmark}") + exit(1) + + # Append the results to the data + for key, value in result.items(): + if data.get(key) is None: + data[key] = [] + data[key].append(value) + # Get proof sizes + found = re.search(rb"Proof size: (.*)", stdout) + if found: + return data, parse_proof_size(found.group(1)) + else: + print(f"Failed to get proof size for benchmark: {benchmark}") + exit(1) + + +def run_benchmark_group(benchmarks) -> dict: + benchmark_results = {} + for benchmark, benchmark_args in benchmarks.items(): + try: + data, proof_size = run_and_parse_benchmark(benchmark, benchmark_args) + benchmark_results[benchmark] = {"proof_size_kib": proof_size} + data["n_ops"] = benchmark_args["n_ops"] + data["display"] = benchmark_args["display"] + data["type"] = benchmark_args["type"] + benchmark_results[benchmark].update(data) + + except Exception as e: + print(f"Failed to run benchmark: {benchmark} with error {e} \nExiting...") + exit(1) + return benchmark_results + + +def value_to_bencher(value: Union[list[float], int], throughput: bool = False) -> dict: + if isinstance(value, list): + avg_value = sum(value) / len(value) + max_value = max(value) + min_value = min(value) + else: + avg_value = max_value = min_value = value + + metric_type = "throughput" if throughput else "latency" + return { + metric_type: { + "value": avg_value, + "upper_value": max_value, + "lower_value": min_value, + } + } + + +def dict_to_bencher(data: dict) -> dict: + bencher_data = {} + for benchmark, value in data.items(): + # Name is of the following format: ::::(trace_gen_time | proving_time | verification_time | proof_size_kib | n_ops) + common_name = f"{value['type']}::{value['display']}" + for key in [ + "trace_gen_time", + "proving_time", + "verification_time", + "proof_size_kib", + "n_ops", + ]: + bencher_data[f"{common_name}::{key}"] = value_to_bencher(value[key]) + return bencher_data + + +def main(): + parser = argparse.ArgumentParser( + description="Run nightly benchmarks and export results" + ) + parser.add_argument( + "--export-file", + required=False, + type=str, + help="Export benchmarks results to file (defaults to stdout)", + ) + + args = parser.parse_args() + + benchmarks = run_benchmark_group(HASHER_TO_RUN) + + bencher_data = dict_to_bencher(benchmarks) + if args.export_file is None: + print("Couldn't find export file for hashers writing to stdout instead") + print(json.dumps(bencher_data)) + else: + with open(args.export_file, "w") as file: + json.dump(bencher_data, file) + + +if __name__ == "__main__": + main() From 3f2033156cc391ad910c0510e22248b95afcc8ac Mon Sep 17 00:00:00 2001 From: Dmytro Gordon Date: Wed, 12 Feb 2025 15:38:47 +0200 Subject: [PATCH 19/50] [field] Implement PackedField::unzip --- crates/field/src/arch/aarch64/m128.rs | 31 ++++ .../byte_sliced/packed_byte_sliced.rs | 23 ++- crates/field/src/arch/portable/packed.rs | 8 + .../src/arch/portable/packed_arithmetic.rs | 12 ++ .../field/src/arch/portable/packed_scaled.rs | 37 +++++ crates/field/src/arch/x86_64/m128.rs | 63 +++++++- crates/field/src/arch/x86_64/m256.rs | 58 +++++++ crates/field/src/arch/x86_64/m512.rs | 95 ++++++++++++ crates/field/src/packed.rs | 18 +++ crates/field/src/packed_binary_field.rs | 146 ++++++++++++++++++ crates/math/src/deinterleave.rs | 13 +- 11 files changed, 488 insertions(+), 16 deletions(-) diff --git a/crates/field/src/arch/aarch64/m128.rs b/crates/field/src/arch/aarch64/m128.rs index a32a82b7d..8155d6034 100644 --- a/crates/field/src/arch/aarch64/m128.rs +++ b/crates/field/src/arch/aarch64/m128.rs @@ -401,6 +401,37 @@ impl UnderlierWithBitConstants for M128 { } } } + + #[inline] + fn transpose(self, other: Self, log_block_len: usize) -> (Self, Self) { + unsafe { + match log_block_len { + 0..=3 => { + let (a, b) = (self.into(), other.into()); + let (mut a, mut b) = (Self::from(vuzp1q_u8(a, b)), Self::from(vuzp2q_u8(a, b))); + + for log_block_len in (log_block_len..3).rev() { + (a, b) = a.interleave(b, log_block_len); + } + + (a, b) + } + 4 => { + let (a, b) = (self.into(), other.into()); + (vuzp1q_u16(a, b).into(), vuzp2q_u16(a, b).into()) + } + 5 => { + let (a, b) = (self.into(), other.into()); + (vuzp1q_u32(a, b).into(), vuzp2q_u32(a, b).into()) + } + 6 => { + let (a, b) = (self.into(), other.into()); + (vuzp1q_u64(a, b).into(), vuzp2q_u64(a, b).into()) + } + _ => panic!("Unsupported block length"), + } + } + } } impl From for PackedPrimitiveType { diff --git a/crates/field/src/arch/portable/byte_sliced/packed_byte_sliced.rs b/crates/field/src/arch/portable/byte_sliced/packed_byte_sliced.rs index 7746795d1..fd4c9b84d 100644 --- a/crates/field/src/arch/portable/byte_sliced/packed_byte_sliced.rs +++ b/crates/field/src/arch/portable/byte_sliced/packed_byte_sliced.rs @@ -55,6 +55,7 @@ macro_rules! define_byte_sliced { const LOG_WIDTH: usize = 5; + #[inline(always)] unsafe fn get_unchecked(&self, i: usize) -> Self::Scalar { let mut result_underlier = 0; for (byte_index, val) in self.data.iter().enumerate() { @@ -70,6 +71,7 @@ macro_rules! define_byte_sliced { Self::Scalar::from_underlier(result_underlier) } + #[inline(always)] unsafe fn set_unchecked(&mut self, i: usize, scalar: Self::Scalar) { let underlier = scalar.to_underlier(); @@ -86,6 +88,7 @@ macro_rules! define_byte_sliced { Self::from_scalars([Self::Scalar::random(rng); 32]) } + #[inline] fn broadcast(scalar: Self::Scalar) -> Self { Self { data: array::from_fn(|byte_index| { @@ -96,6 +99,7 @@ macro_rules! define_byte_sliced { } } + #[inline] fn from_fn(mut f: impl FnMut(usize) -> Self::Scalar) -> Self { let mut result = Self::default(); @@ -107,6 +111,7 @@ macro_rules! define_byte_sliced { result } + #[inline] fn square(self) -> Self { let mut result = Self::default(); @@ -115,22 +120,34 @@ macro_rules! define_byte_sliced { result } + #[inline] fn invert_or_zero(self) -> Self { let mut result = Self::default(); invert_or_zero::<$tower_level>(&self.data, &mut result.data); result } + #[inline] fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) { let mut result1 = Self::default(); let mut result2 = Self::default(); for byte_num in 0..<$tower_level as TowerLevel>::WIDTH { - let (this_byte_result1, this_byte_result2) = + (result1.data[byte_num], result2.data[byte_num]) = self.data[byte_num].interleave(other.data[byte_num], log_block_len); + } + + (result1, result2) + } - result1.data[byte_num] = this_byte_result1; - result2.data[byte_num] = this_byte_result2; + #[inline] + fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self) { + let mut result1 = Self::default(); + let mut result2 = Self::default(); + + for byte_num in 0..<$tower_level as TowerLevel>::WIDTH { + (result1.data[byte_num], result2.data[byte_num]) = + self.data[byte_num].unzip(other.data[byte_num], log_block_len); } (result1, result2) diff --git a/crates/field/src/arch/portable/packed.rs b/crates/field/src/arch/portable/packed.rs index ad372dd9e..66bbb9650 100644 --- a/crates/field/src/arch/portable/packed.rs +++ b/crates/field/src/arch/portable/packed.rs @@ -336,6 +336,14 @@ where (c.into(), d.into()) } + #[inline] + fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self) { + assert!(log_block_len < Self::LOG_WIDTH); + let log_bit_len = Self::Scalar::N_BITS.ilog2() as usize; + let (c, d) = self.0.transpose(other.0, log_block_len + log_bit_len); + (c.into(), d.into()) + } + #[inline] unsafe fn spread_unchecked(self, log_block_len: usize, block_idx: usize) -> Self { debug_assert!(log_block_len <= Self::LOG_WIDTH, "{} <= {}", log_block_len, Self::LOG_WIDTH); diff --git a/crates/field/src/arch/portable/packed_arithmetic.rs b/crates/field/src/arch/portable/packed_arithmetic.rs index fbeea240b..98d0f1304 100644 --- a/crates/field/src/arch/portable/packed_arithmetic.rs +++ b/crates/field/src/arch/portable/packed_arithmetic.rs @@ -37,6 +37,18 @@ where (c, d) } + + /// Transpose with the given bit size + fn transpose(mut self, mut other: Self, log_block_len: usize) -> (Self, Self) { + // There are 2^7 = 128 bits in a u128 + assert!(log_block_len < Self::INTERLEAVE_EVEN_MASK.len()); + + for log_block_len in (log_block_len..Self::LOG_BITS).rev() { + (self, other) = self.interleave(other, log_block_len); + } + + (self, other) + } } /// Abstraction for a packed tower field of height greater than 0. diff --git a/crates/field/src/arch/portable/packed_scaled.rs b/crates/field/src/arch/portable/packed_scaled.rs index a4b55a2cf..93f693bcd 100644 --- a/crates/field/src/arch/portable/packed_scaled.rs +++ b/crates/field/src/arch/portable/packed_scaled.rs @@ -216,14 +216,17 @@ where Self(array::from_fn(|_| PT::random(&mut rng))) } + #[inline] fn broadcast(scalar: Self::Scalar) -> Self { Self(array::from_fn(|_| PT::broadcast(scalar))) } + #[inline] fn square(self) -> Self { Self(self.0.map(|v| v.square())) } + #[inline] fn invert_or_zero(self) -> Self { Self(self.0.map(|v| v.invert_or_zero())) } @@ -253,6 +256,40 @@ where (Self(first), Self(second)) } + fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self) { + let mut first = [Default::default(); N]; + let mut second = [Default::default(); N]; + + if log_block_len >= PT::LOG_WIDTH { + let block_in_pts = 1 << (log_block_len - PT::LOG_WIDTH); + for i in (0..N / 2).step_by(block_in_pts) { + first[i..i + block_in_pts].copy_from_slice(&self.0[2 * i..2 * i + block_in_pts]); + + second[i..i + block_in_pts] + .copy_from_slice(&self.0[2 * i + block_in_pts..2 * (i + block_in_pts)]); + } + + for i in (0..N / 2).step_by(block_in_pts) { + first[i + N / 2..i + N / 2 + block_in_pts] + .copy_from_slice(&other.0[2 * i..2 * i + block_in_pts]); + + second[i + N / 2..i + N / 2 + block_in_pts] + .copy_from_slice(&other.0[2 * i + block_in_pts..2 * (i + block_in_pts)]); + } + } else { + for i in 0..N / 2 { + (first[i], second[i]) = self.0[2 * i].unzip(self.0[2 * i + 1], log_block_len); + } + + for i in 0..N / 2 { + (first[i + N / 2], second[i + N / 2]) = + other.0[2 * i].unzip(other.0[2 * i + 1], log_block_len); + } + } + + (Self(first), Self(second)) + } + #[inline] unsafe fn spread_unchecked(self, log_block_len: usize, block_idx: usize) -> Self { let log_n = checked_log_2(N); diff --git a/crates/field/src/arch/x86_64/m128.rs b/crates/field/src/arch/x86_64/m128.rs index f75d152d2..8223f63fa 100644 --- a/crates/field/src/arch/x86_64/m128.rs +++ b/crates/field/src/arch/x86_64/m128.rs @@ -779,14 +779,27 @@ impl UnderlierWithBitConstants for M128 { #[inline(always)] fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) { - unsafe { - let (c, d) = interleave_bits( + let (c, d) = unsafe { + interleave_bits( Into::::into(self).into(), Into::::into(other).into(), log_block_len, - ); - (Self::from(c), Self::from(d)) - } + ) + }; + (Self::from(c), Self::from(d)) + } + + #[inline(always)] + fn transpose(self, other: Self, log_block_len: usize) -> (Self, Self) { + let (c, d) = unsafe { + transpose_bits( + Into::::into(self).into(), + Into::::into(other).into(), + log_block_len, + ) + }; + + (Self::from(c), Self::from(d)) } } @@ -882,6 +895,46 @@ unsafe fn interleave_bits(a: __m128i, b: __m128i, log_block_len: usize) -> (__m1 } } +#[inline] +unsafe fn transpose_bits(a: __m128i, b: __m128i, log_block_len: usize) -> (__m128i, __m128i) { + match log_block_len { + 0..=3 => { + let shuffle = _mm_set_epi8(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0); + let (mut a, mut b) = transpose_with_shuffle(a, b, shuffle); + for log_block_len in (log_block_len..3).rev() { + (a, b) = interleave_bits(a, b, log_block_len); + } + + (a, b) + } + 4 => { + let shuffle = _mm_set_epi8(15, 14, 11, 10, 7, 6, 3, 2, 13, 12, 9, 8, 5, 4, 1, 0); + transpose_with_shuffle(a, b, shuffle) + } + 5 => { + let shuffle = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0); + transpose_with_shuffle(a, b, shuffle) + } + 6 => { + let c = _mm_unpacklo_epi64(a, b); + let d = _mm_unpackhi_epi64(a, b); + (c, d) + } + _ => panic!("unsupported block length"), + } +} + +#[inline(always)] +unsafe fn transpose_with_shuffle( + a: __m128i, + b: __m128i, + shuffle_mask: __m128i, +) -> (__m128i, __m128i) { + let a = _mm_shuffle_epi8(a, shuffle_mask); + let b = _mm_shuffle_epi8(b, shuffle_mask); + (_mm_unpacklo_epi64(a, b), _mm_unpackhi_epi64(a, b)) +} + #[inline] unsafe fn interleave_bits_imm( a: __m128i, diff --git a/crates/field/src/arch/x86_64/m256.rs b/crates/field/src/arch/x86_64/m256.rs index 3c36827fc..4f3663d7c 100644 --- a/crates/field/src/arch/x86_64/m256.rs +++ b/crates/field/src/arch/x86_64/m256.rs @@ -910,6 +910,13 @@ impl UnderlierWithBitConstants for M256 { let (a, b) = unsafe { interleave_bits(self.0, other.0, log_block_len) }; (Self(a), Self(b)) } + + fn transpose(mut self, mut other: Self, log_block_len: usize) -> (Self, Self) { + let (a, b) = unsafe { transpose_bits(self.0, other.0, log_block_len) }; + self.0 = a; + other.0 = b; + (self, other) + } } #[inline] @@ -974,6 +981,57 @@ unsafe fn interleave_bits(a: __m256i, b: __m256i, log_block_len: usize) -> (__m2 } } +#[inline] +unsafe fn transpose_bits(a: __m256i, b: __m256i, log_block_len: usize) -> (__m256i, __m256i) { + match log_block_len { + 0..=3 => { + let shuffle = _mm256_set_epi8( + 15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0, 15, 13, 11, 9, 7, 5, 3, 1, + 14, 12, 10, 8, 6, 4, 2, 0, + ); + let (mut a, mut b) = transpose_with_shuffle(a, b, shuffle); + for log_block_len in (log_block_len..3).rev() { + (a, b) = interleave_bits(a, b, log_block_len); + } + + (a, b) + } + 4 => { + let shuffle = _mm256_set_epi8( + 15, 14, 11, 10, 7, 6, 3, 2, 13, 12, 9, 8, 5, 4, 1, 0, 15, 14, 11, 10, 7, 6, 3, 2, + 13, 12, 9, 8, 5, 4, 1, 0, + ); + + transpose_with_shuffle(a, b, shuffle) + } + 5 => { + let shuffle = _mm256_set_epi8( + 15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0, 15, 14, 13, 12, 7, 6, 5, 4, + 11, 10, 9, 8, 3, 2, 1, 0, + ); + + transpose_with_shuffle(a, b, shuffle) + } + 6 => { + let (a, b) = (_mm256_unpacklo_epi64(a, b), _mm256_unpackhi_epi64(a, b)); + + (_mm256_permute4x64_epi64(a, 0b11011000), _mm256_permute4x64_epi64(b, 0b11011000)) + } + 7 => (_mm256_permute2x128_si256(a, b, 0x20), _mm256_permute2x128_si256(a, b, 0x31)), + _ => panic!("unsupported block length"), + } +} + +#[inline(always)] +unsafe fn transpose_with_shuffle(a: __m256i, b: __m256i, shuffle: __m256i) -> (__m256i, __m256i) { + let a = _mm256_shuffle_epi8(a, shuffle); + let b = _mm256_shuffle_epi8(b, shuffle); + + let (a, b) = (_mm256_unpacklo_epi64(a, b), _mm256_unpackhi_epi64(a, b)); + + (_mm256_permute4x64_epi64(a, 0b11011000), _mm256_permute4x64_epi64(b, 0b11011000)) +} + #[inline] unsafe fn interleave_bits_imm( a: __m256i, diff --git a/crates/field/src/arch/x86_64/m512.rs b/crates/field/src/arch/x86_64/m512.rs index caa821c62..80c4bccbf 100644 --- a/crates/field/src/arch/x86_64/m512.rs +++ b/crates/field/src/arch/x86_64/m512.rs @@ -932,6 +932,12 @@ impl UnderlierWithBitConstants for M512 { let (a, b) = unsafe { interleave_bits(self.0, other.0, log_block_len) }; (Self(a), Self(b)) } + + #[inline(always)] + fn transpose(self, other: Self, log_bit_len: usize) -> (Self, Self) { + let (a, b) = unsafe { transpose_bits(self.0, other.0, log_bit_len) }; + (Self(a), Self(b)) + } } #[inline] @@ -1103,6 +1109,95 @@ const fn precompute_spread_mask( m512_masks } +#[inline(always)] +unsafe fn transpose_bits(a: __m512i, b: __m512i, log_block_len: usize) -> (__m512i, __m512i) { + match log_block_len { + 0..=3 => { + let shuffle = _mm512_set_epi8( + 15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0, 15, 13, 11, 9, 7, 5, 3, 1, + 14, 12, 10, 8, 6, 4, 2, 0, 15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0, + 15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0, + ); + let (mut a, mut b) = transpose_with_shuffle(a, b, shuffle); + for log_block_len in (log_block_len..3).rev() { + (a, b) = interleave_bits(a, b, log_block_len); + } + + (a, b) + } + 4 => { + let shuffle = _mm512_set_epi8( + 15, 14, 11, 10, 7, 6, 3, 2, 13, 12, 9, 8, 5, 4, 1, 0, 15, 14, 11, 10, 7, 6, 3, 2, + 13, 12, 9, 8, 5, 4, 1, 0, 15, 14, 11, 10, 7, 6, 3, 2, 13, 12, 9, 8, 5, 4, 1, 0, 15, + 14, 11, 10, 7, 6, 3, 2, 13, 12, 9, 8, 5, 4, 1, 0, + ); + transpose_with_shuffle(a, b, shuffle) + } + 5 => { + let shuffle = _mm512_set_epi8( + 15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0, 15, 14, 13, 12, 7, 6, 5, 4, + 11, 10, 9, 8, 3, 2, 1, 0, 15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0, 15, + 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0, + ); + transpose_with_shuffle(a, b, shuffle) + } + 6 => ( + _mm512_permutex2var_epi64( + a, + _mm512_set_epi64(0b1110, 0b1100, 0b1010, 0b1000, 0b0110, 0b0100, 0b0010, 0b0000), + b, + ), + _mm512_permutex2var_epi64( + a, + _mm512_set_epi64(0b1111, 0b1101, 0b1011, 0b1001, 0b0111, 0b0101, 0b0011, 0b0001), + b, + ), + ), + 7 => ( + _mm512_permutex2var_epi64( + a, + _mm512_set_epi64(0b1101, 0b1100, 0b1001, 0b1000, 0b0101, 0b0100, 0b0001, 0b0000), + b, + ), + _mm512_permutex2var_epi64( + a, + _mm512_set_epi64(0b1111, 0b1110, 0b1011, 0b1010, 0b0111, 0b0110, 0b0011, 0b0010), + b, + ), + ), + 8 => ( + _mm512_permutex2var_epi64( + a, + _mm512_set_epi64(0b1011, 0b1010, 0b1001, 0b1000, 0b0011, 0b0010, 0b0001, 0b0000), + b, + ), + _mm512_permutex2var_epi64( + a, + _mm512_set_epi64(0b1111, 0b1110, 0b1101, 0b1100, 0b0111, 0b0110, 0b0101, 0b0100), + b, + ), + ), + _ => panic!("unsupported block length"), + } +} + +unsafe fn transpose_with_shuffle(a: __m512i, b: __m512i, shuffle: __m512i) -> (__m512i, __m512i) { + let (a, b) = (_mm512_shuffle_epi8(a, shuffle), _mm512_shuffle_epi8(b, shuffle)); + + ( + _mm512_permutex2var_epi64( + a, + _mm512_set_epi64(0b1110, 0b1100, 0b1010, 0b1000, 0b0110, 0b0100, 0b0010, 0b0000), + b, + ), + _mm512_permutex2var_epi64( + a, + _mm512_set_epi64(0b1111, 0b1101, 0b1011, 0b1001, 0b0111, 0b0101, 0b0011, 0b0001), + b, + ), + ) +} + impl_iteration!(M512, @strategy BitIterationStrategy, U1, @strategy FallbackStrategy, U2, U4, diff --git a/crates/field/src/packed.rs b/crates/field/src/packed.rs index 5fcddf146..471121b01 100644 --- a/crates/field/src/packed.rs +++ b/crates/field/src/packed.rs @@ -216,6 +216,20 @@ pub trait PackedField: /// * `log_block_len` must be strictly less than `LOG_WIDTH`. fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self); + /// Unzips interleaved blocks of this packed vector with another packed vector. + /// + /// Consider this example, where `LOG_WIDTH` is 3 and `log_block_len` is 1: + /// A = [a0, a1, b0, b1, a2, a3, b2, b3] + /// B = [a4, a5, b4, b5, a6, a7, b6, b7] + /// + /// The transposed result is + /// A' = [a0, a1, a2, a3, a4, a5, a6, a7] + /// B' = [b0, b1, b2, b3, b4, b5, b6, b7] + /// + /// ## Preconditions + /// * `log_block_len` must be strictly less than `LOG_WIDTH`. + fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self); + /// Spread takes a block of elements within a packed field and repeats them to the full packing /// width. /// @@ -427,6 +441,10 @@ impl PackedField for F { panic!("cannot interleave when WIDTH = 1"); } + fn unzip(self, _other: Self, _log_block_len: usize) -> (Self, Self) { + panic!("cannot transpose when WIDTH = 1"); + } + fn broadcast(scalar: Self::Scalar) -> Self { scalar } diff --git a/crates/field/src/packed_binary_field.rs b/crates/field/src/packed_binary_field.rs index b22f48fef..152cbf6bd 100644 --- a/crates/field/src/packed_binary_field.rs +++ b/crates/field/src/packed_binary_field.rs @@ -807,6 +807,63 @@ pub mod test_utils { check_interleave::

(lhs, rhs, log_block_len); } } + + pub fn check_unzip( + lhs: P::Underlier, + rhs: P::Underlier, + log_block_len: usize, + ) { + let lhs = P::from_underlier(lhs); + let rhs = P::from_underlier(rhs); + let block_len = 1 << log_block_len; + let (a, b) = lhs.unzip(rhs, log_block_len); + for i in (0..P::WIDTH / 2).step_by(block_len) { + for j in 0..block_len { + assert_eq!( + a.get(i + j), + lhs.get(2 * i + j), + "i: {}, j: {}, log_block_len: {}, P: {:?}", + i, + j, + log_block_len, + P::zero() + ); + assert_eq!( + b.get(i + j), + lhs.get(2 * i + j + block_len), + "i: {}, j: {}, log_block_len: {}, P: {:?}", + i, + j, + log_block_len, + P::zero() + ); + } + } + + for i in (0..P::WIDTH / 2).step_by(block_len) { + for j in 0..block_len { + assert_eq!( + a.get(i + j + P::WIDTH / 2), + rhs.get(2 * i + j), + "i: {}, j: {}, log_block_len: {}, P: {:?}", + i, + j, + log_block_len, + P::zero() + ); + assert_eq!(b.get(i + j + P::WIDTH / 2), rhs.get(2 * i + j + block_len)); + } + } + } + + pub fn check_transpose_all_heights( + lhs: P::Underlier, + rhs: P::Underlier, + ) { + for log_block_len in 0..P::LOG_WIDTH { + check_unzip::

(lhs, rhs, log_block_len); + } + } } #[cfg(test)] @@ -831,6 +888,7 @@ mod tests { }, arithmetic_traits::MulAlpha, linear_transformation::PackedTransformationFactory, + test_utils::check_transpose_all_heights, underlier::{U2, U4}, Field, PackedField, PackedFieldIndexable, }; @@ -1206,5 +1264,93 @@ mod tests { check_interleave_all_heights::(a_val.into(), b_val.into()); check_interleave_all_heights::(a_val.into(), b_val.into()); } + + #[test] + fn check_transpose_2b(a_val in 0u8..3, b_val in 0u8..3) { + check_transpose_all_heights::(U2::new(a_val), U2::new(b_val)); + check_transpose_all_heights::(U2::new(a_val), U2::new(b_val)); + } + + #[test] + fn check_transpose_4b(a_val in 0u8..16, b_val in 0u8..16) { + check_transpose_all_heights::(U4::new(a_val), U4::new(b_val)); + check_transpose_all_heights::(U4::new(a_val), U4::new(b_val)); + check_transpose_all_heights::(U4::new(a_val), U4::new(b_val)); + } + + #[test] + fn check_transpose_8b(a_val in 0u8.., b_val in 0u8..) { + check_transpose_all_heights::(a_val, b_val); + check_transpose_all_heights::(a_val, b_val); + check_transpose_all_heights::(a_val, b_val); + check_transpose_all_heights::(a_val, b_val); + } + + #[test] + fn check_transpose_16b(a_val in 0u16.., b_val in 0u16..) { + check_transpose_all_heights::(a_val, b_val); + check_transpose_all_heights::(a_val, b_val); + check_transpose_all_heights::(a_val, b_val); + check_transpose_all_heights::(a_val, b_val); + check_transpose_all_heights::(a_val, b_val); + } + + #[test] + fn check_transpose_32b(a_val in 0u32.., b_val in 0u32..) { + check_transpose_all_heights::(a_val, b_val); + check_transpose_all_heights::(a_val, b_val); + check_transpose_all_heights::(a_val, b_val); + check_transpose_all_heights::(a_val, b_val); + check_transpose_all_heights::(a_val, b_val); + check_transpose_all_heights::(a_val, b_val); + } + + #[test] + fn check_transpose_64b(a_val in 0u64.., b_val in 0u64..) { + check_transpose_all_heights::(a_val, b_val); + check_transpose_all_heights::(a_val, b_val); + check_transpose_all_heights::(a_val, b_val); + check_transpose_all_heights::(a_val, b_val); + check_transpose_all_heights::(a_val, b_val); + check_transpose_all_heights::(a_val, b_val); + check_transpose_all_heights::(a_val, b_val); + } + + #[test] + #[allow(clippy::useless_conversion)] // this warning depends on the target platform + fn check_transpose_128b(a_val in 0u128.., b_val in 0u128..) { + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + } + + #[test] + fn check_transpose_256b(a_val in any::<[u128; 2]>(), b_val in any::<[u128; 2]>()) { + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + } + + #[test] + fn check_transpose_512b(a_val in any::<[u128; 4]>(), b_val in any::<[u128; 4]>()) { + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + check_transpose_all_heights::(a_val.into(), b_val.into()); + } } } diff --git a/crates/math/src/deinterleave.rs b/crates/math/src/deinterleave.rs index 4f01b437d..1a5972906 100644 --- a/crates/math/src/deinterleave.rs +++ b/crates/math/src/deinterleave.rs @@ -30,14 +30,11 @@ pub fn deinterleave( } let deinterleaved = (0..1 << (log_scalar_count - P::LOG_WIDTH)).map(|i| { - let mut even = interleaved[2 * i]; - let mut odd = interleaved[2 * i + 1]; - - for log_block_len in (0..P::LOG_WIDTH).rev() { - let (even_interleaved, odd_interleaved) = even.interleave(odd, log_block_len); - even = even_interleaved; - odd = odd_interleaved; - } + let (even, odd) = if P::LOG_WIDTH > 0 { + P::unzip(interleaved[2 * i], interleaved[2 * i + 1], 0) + } else { + (interleaved[2 * i], interleaved[2 * i + 1]) + }; (i, even, odd) }); From f8436560019aa67da6513b5acf72d9ae0d33e5c7 Mon Sep 17 00:00:00 2001 From: Thomas Coratger <60488569+tcoratger@users.noreply.github.com> Date: Wed, 12 Feb 2025 14:51:31 +0100 Subject: [PATCH 20/50] [cleanup]: Remove some useless checked_log_2 calls --- crates/core/src/reed_solomon/reed_solomon.rs | 5 ++--- crates/core/src/tensor_algebra.rs | 3 +-- crates/ntt/src/additive_ntt.rs | 15 ++------------- 3 files changed, 5 insertions(+), 18 deletions(-) diff --git a/crates/core/src/reed_solomon/reed_solomon.rs b/crates/core/src/reed_solomon/reed_solomon.rs index 292744223..2b080050e 100644 --- a/crates/core/src/reed_solomon/reed_solomon.rs +++ b/crates/core/src/reed_solomon/reed_solomon.rs @@ -15,7 +15,7 @@ use std::marker::PhantomData; use binius_field::{BinaryField, ExtensionField, PackedField, RepackedExtension}; use binius_maybe_rayon::prelude::*; use binius_ntt::{AdditiveNTT, DynamicDispatchNTT, Error, NTTOptions, ThreadingSettings}; -use binius_utils::{bail, checked_arithmetics::checked_log_2}; +use binius_utils::bail; use getset::CopyGetters; use tracing::instrument; @@ -169,7 +169,6 @@ where PE: RepackedExtension

, PE::Scalar: ExtensionField<

::Scalar>, { - let log_degree = checked_log_2(PE::Scalar::DEGREE); - self.encode_batch_inplace(PE::cast_bases_mut(code), log_batch_size + log_degree) + self.encode_batch_inplace(PE::cast_bases_mut(code), log_batch_size + PE::Scalar::LOG_DEGREE) } } diff --git a/crates/core/src/tensor_algebra.rs b/crates/core/src/tensor_algebra.rs index a5d7ea5a7..0832803af 100644 --- a/crates/core/src/tensor_algebra.rs +++ b/crates/core/src/tensor_algebra.rs @@ -10,7 +10,6 @@ use std::{ use binius_field::{ square_transpose, util::inner_product_unchecked, ExtensionField, Field, PackedExtension, }; -use binius_utils::checked_arithmetics::checked_log_2; /// An element of the tensor algebra defined as the tensor product of `FE` and `FE` as fields. /// @@ -64,7 +63,7 @@ where /// Returns $\kappa$, the base-2 logarithm of the extension degree. pub const fn kappa() -> usize { - checked_log_2(FE::DEGREE) + FE::LOG_DEGREE } /// Returns the byte size of an element. diff --git a/crates/ntt/src/additive_ntt.rs b/crates/ntt/src/additive_ntt.rs index a48d7bb9d..a4828d5fa 100644 --- a/crates/ntt/src/additive_ntt.rs +++ b/crates/ntt/src/additive_ntt.rs @@ -1,7 +1,6 @@ // Copyright 2024-2025 Irreducible Inc. use binius_field::{ExtensionField, PackedField, RepackedExtension}; -use binius_utils::checked_arithmetics::log2_strict_usize; use super::error::Error; @@ -51,12 +50,7 @@ pub trait AdditiveNTT { PE: RepackedExtension

, PE::Scalar: ExtensionField, { - if !PE::Scalar::DEGREE.is_power_of_two() { - return Err(Error::PowerOfTwoExtensionDegreeRequired); - } - - let log_batch_size = log2_strict_usize(PE::Scalar::DEGREE); - self.forward_transform(PE::cast_bases_mut(data), coset, log_batch_size) + self.forward_transform(PE::cast_bases_mut(data), coset, PE::Scalar::LOG_DEGREE) } fn inverse_transform_ext(&self, data: &mut [PE], coset: u32) -> Result<(), Error> @@ -64,11 +58,6 @@ pub trait AdditiveNTT { PE: RepackedExtension

, PE::Scalar: ExtensionField, { - if !PE::Scalar::DEGREE.is_power_of_two() { - return Err(Error::PowerOfTwoExtensionDegreeRequired); - } - - let log_batch_size = log2_strict_usize(PE::Scalar::DEGREE); - self.inverse_transform(PE::cast_bases_mut(data), coset, log_batch_size) + self.inverse_transform(PE::cast_bases_mut(data), coset, PE::Scalar::LOG_DEGREE) } } From d4134dd602d137ea8a23533c946b7643bf106111 Mon Sep 17 00:00:00 2001 From: Tobias Bergkvist Date: Wed, 12 Feb 2025 17:45:16 +0100 Subject: [PATCH 21/50] [field] Add TowerField::min_tower_level(self), and use it to derive ArithExpr tower_level from its constants (#6) In contrast to TowerField::TOWER_LEVEL, TowerField::binary_tower_level(self) returns the smallest tower level that can fit the current value. This can be useful for shrinking field values to the smaller container that fits them, for the purpose of making arithmetic operations (in particular multiplication) cheaper. --- crates/core/src/polynomial/arith_circuit.rs | 23 +++++++------ crates/field/src/aes_field.rs | 7 ++++ crates/field/src/binary_field.rs | 36 +++++++++++++++++---- crates/field/src/binary_field_arithmetic.rs | 4 +++ crates/field/src/polyval.rs | 7 ++++ crates/math/src/arith_expr.rs | 15 ++++++++- 6 files changed, 75 insertions(+), 17 deletions(-) diff --git a/crates/core/src/polynomial/arith_circuit.rs b/crates/core/src/polynomial/arith_circuit.rs index 3d7e1862b..69b4a4224 100644 --- a/crates/core/src/polynomial/arith_circuit.rs +++ b/crates/core/src/polynomial/arith_circuit.rs @@ -119,12 +119,14 @@ pub struct ArithCircuitPoly { retval: CircuitStepArgument, degree: usize, n_vars: usize, + tower_level: usize, } -impl ArithCircuitPoly { +impl ArithCircuitPoly { pub fn new(expr: ArithExpr) -> Self { let degree = expr.degree(); let n_vars = expr.n_vars(); + let tower_level = expr.binary_tower_level(); let (exprs, retval) = circuit_steps_for_expr(&expr); Self { @@ -133,6 +135,7 @@ impl ArithCircuitPoly { retval, degree, n_vars, + tower_level, } } @@ -142,6 +145,7 @@ impl ArithCircuitPoly { /// arithmetic expression. pub fn with_n_vars(n_vars: usize, expr: ArithExpr) -> Result { let degree = expr.degree(); + let tower_level = expr.binary_tower_level(); if n_vars < expr.n_vars() { return Err(Error::IncorrectNumberOfVariables { expected: expr.n_vars(), @@ -156,6 +160,7 @@ impl ArithCircuitPoly { retval, n_vars, degree, + tower_level, }) } } @@ -170,7 +175,7 @@ impl CompositionPoly for ArithCircuitPoly { } fn binary_tower_level(&self) -> usize { - F::TOWER_LEVEL + self.tower_level } fn expression>(&self) -> ArithExpr { @@ -501,7 +506,7 @@ mod tests { let circuit = ArithCircuitPoly::::new(expr); let typed_circuit: &dyn CompositionPolyOS

= &circuit; - assert_eq!(typed_circuit.binary_tower_level(), F::TOWER_LEVEL); + assert_eq!(typed_circuit.binary_tower_level(), 0); assert_eq!(typed_circuit.degree(), 1); assert_eq!(typed_circuit.n_vars(), 1); @@ -529,7 +534,7 @@ mod tests { let circuit = ArithCircuitPoly::::new(expr); let typed_circuit: &dyn CompositionPolyOS

= &circuit; - assert_eq!(typed_circuit.binary_tower_level(), F::TOWER_LEVEL); + assert_eq!(typed_circuit.binary_tower_level(), 3); assert_eq!(typed_circuit.degree(), 1); assert_eq!(typed_circuit.n_vars(), 1); @@ -549,7 +554,7 @@ mod tests { let circuit = ArithCircuitPoly::::new(expr); let typed_circuit: &dyn CompositionPolyOS

= &circuit; - assert_eq!(typed_circuit.binary_tower_level(), F::TOWER_LEVEL); + assert_eq!(typed_circuit.binary_tower_level(), 3); assert_eq!(typed_circuit.degree(), 1); assert_eq!(typed_circuit.n_vars(), 1); @@ -575,7 +580,7 @@ mod tests { let circuit = ArithCircuitPoly::::new(expr); let typed_circuit: &dyn CompositionPolyOS

= &circuit; - assert_eq!(typed_circuit.binary_tower_level(), F::TOWER_LEVEL); + assert_eq!(typed_circuit.binary_tower_level(), 0); assert_eq!(typed_circuit.degree(), 13); assert_eq!(typed_circuit.n_vars(), 1); @@ -601,7 +606,7 @@ mod tests { let circuit = ArithCircuitPoly::::new(expr); let typed_circuit: &dyn CompositionPolyOS

= &circuit; - assert_eq!(typed_circuit.binary_tower_level(), F::TOWER_LEVEL); + assert_eq!(typed_circuit.binary_tower_level(), 3); assert_eq!(typed_circuit.degree(), 3); assert_eq!(typed_circuit.n_vars(), 2); @@ -722,7 +727,7 @@ mod tests { assert_eq!(circuit.steps.len(), 1); let typed_circuit: &dyn CompositionPolyOS

= &circuit; - assert_eq!(typed_circuit.binary_tower_level(), F::TOWER_LEVEL); + assert_eq!(typed_circuit.binary_tower_level(), 1); assert_eq!(typed_circuit.degree(), 1); assert_eq!(typed_circuit.n_vars(), 1); @@ -749,7 +754,7 @@ mod tests { assert_eq!(circuit.steps.len(), 1); let typed_circuit: &dyn CompositionPolyOS

= &circuit; - assert_eq!(typed_circuit.binary_tower_level(), F::TOWER_LEVEL); + assert_eq!(typed_circuit.binary_tower_level(), 0); assert_eq!(typed_circuit.degree(), 24); assert_eq!(typed_circuit.n_vars(), 1); diff --git a/crates/field/src/aes_field.rs b/crates/field/src/aes_field.rs index 61e261de0..f37104509 100644 --- a/crates/field/src/aes_field.rs +++ b/crates/field/src/aes_field.rs @@ -78,6 +78,13 @@ impl_arithmetic_using_packed!(AESTowerField128b); impl TowerField for AESTowerField8b { type Canonical = BinaryField8b; + fn min_tower_level(self) -> usize { + match self { + Self::ZERO | Self::ONE => 0, + _ => 3, + } + } + fn mul_primitive(self, iota: usize) -> Result { match iota { 0..=1 => Ok(self * ISOMORPHIC_ALPHAS[iota]), diff --git a/crates/field/src/binary_field.rs b/crates/field/src/binary_field.rs index 0707d04b6..6f2c6e6f0 100644 --- a/crates/field/src/binary_field.rs +++ b/crates/field/src/binary_field.rs @@ -45,6 +45,13 @@ where /// Currently for every tower field, the canonical field is Fan-Paar's binary field of the same degree. type Canonical: TowerField + SerializeBytes + DeserializeBytes; + /// Returns the smallest valid `TOWER_LEVEL` in the tower that can fit the same value. + /// + /// Since which `TOWER_LEVEL` values are valid depends on the tower, + /// `F::Canonical::from(elem).min_tower_level()` can return a different result + /// from `elem.min_tower_level()`. + fn min_tower_level(self) -> usize; + fn basis(iota: usize, i: usize) -> Result { if iota > Self::TOWER_LEVEL { return Err(Error::ExtensionDegreeTooHigh); @@ -625,10 +632,16 @@ pub(super) trait MulPrimitive: Sized { #[macro_export] macro_rules! binary_tower { - ($subfield_name:ident($subfield_typ:ty $(, $canonical_subfield:ident)?) < $name:ident($typ:ty)) => { - binary_tower!($subfield_name($subfield_typ $(, $canonical_subfield)?) < $name($typ, $name)); + (BinaryField1b($subfield_typ:ty $(, $canonical_subfield:ident)?) < $name:ident($typ:ty $(, $canonical:ident)?) $(< $extfield_name:ident($extfield_typ:ty $(, $canonical_ext:ident)?))+) => { + binary_tower!([BinaryField1b::TOWER_LEVEL]; BinaryField1b($subfield_typ $(, $canonical_subfield)?) < $name($typ $(, $canonical)?) $(< $extfield_name($extfield_typ $(, $canonical_ext)?))+); + }; + ($subfield_name:ident($subfield_typ:ty $(, $canonical_subfield:ident)?) < $name:ident($typ:ty $(, $canonical:ident)?) $(< $extfield_name:ident($extfield_typ:ty $(, $canonical_ext:ident)?))+) => { + binary_tower!([BinaryField1b::TOWER_LEVEL, $subfield_name::TOWER_LEVEL]; $subfield_name($subfield_typ $(, $canonical_subfield)?) < $name($typ $(, $canonical)?) $(< $extfield_name($extfield_typ $(, $canonical_ext)?))+); + }; + ([$($valid_tower_levels:tt)*]; $subfield_name:ident($subfield_typ:ty $(, $canonical_subfield:ident)?) < $name:ident($typ:ty)) => { + binary_tower!([$($valid_tower_levels)*]; $subfield_name($subfield_typ $(, $canonical_subfield)?) < $name($typ, $name)); }; - ($subfield_name:ident($subfield_typ:ty $(, $canonical_subfield:ident)?) < $name:ident($typ:ty, $canonical:ident)) => { + ([$($valid_tower_levels:tt)*]; $subfield_name:ident($subfield_typ:ty $(, $canonical_subfield:ident)?) < $name:ident($typ:ty, $canonical:ident)) => { impl From<$name> for ($subfield_name, $subfield_name) { #[inline] fn from(src: $name) -> ($subfield_name, $subfield_name) { @@ -652,6 +665,16 @@ macro_rules! binary_tower { type Canonical = $canonical; + fn min_tower_level(self) -> usize { + let zero = <$typ as $crate::underlier::UnderlierWithBitOps>::ZERO; + for level in [$($valid_tower_levels)*] { + if self.0 >> (1 << level) == zero { + return level; + } + } + Self::TOWER_LEVEL + } + fn mul_primitive(self, iota: usize) -> Result { ::mul_primitive(self, iota) } @@ -663,14 +686,13 @@ macro_rules! binary_tower { binary_tower!($subfield_name($subfield_typ) < @1 => $name($typ)); }; - ($subfield_name:ident($subfield_typ:ty $(, $canonical_subfield:ident)?) < $name:ident($typ:ty $(, $canonical:ident)?) $(< $extfield_name:ident($extfield_typ:ty $(, $canonical_ext:ident)?))+) => { - binary_tower!($subfield_name($subfield_typ $(, $canonical_subfield)?) < $name($typ $(, $canonical)?)); - binary_tower!($name($typ $(, $canonical)?) $(< $extfield_name($extfield_typ $(, $canonical_ext)?))+); + ([$($valid_tower_levels:tt)*]; $subfield_name:ident($subfield_typ:ty $(, $canonical_subfield:ident)?) < $name:ident($typ:ty $(, $canonical:ident)?) $(< $extfield_name:ident($extfield_typ:ty $(, $canonical_ext:ident)?))+) => { + binary_tower!([$($valid_tower_levels)*]; $subfield_name($subfield_typ $(, $canonical_subfield)?) < $name($typ $(, $canonical)?)); + binary_tower!([$($valid_tower_levels)*, $name::TOWER_LEVEL]; $name($typ $(, $canonical)?) $(< $extfield_name($extfield_typ $(, $canonical_ext)?))+); binary_tower!($subfield_name($subfield_typ) < @2 => $($extfield_name($extfield_typ))<+); }; ($subfield_name:ident($subfield_typ:ty) < @$log_degree:expr => $name:ident($typ:ty)) => { $crate::binary_field::impl_field_extension!($subfield_name($subfield_typ) < @$log_degree => $name($typ)); - $crate::binary_field::binary_tower_subfield_mul!($subfield_name, $name); }; ($subfield_name:ident($subfield_typ:ty) < @$log_degree:expr => $name:ident($typ:ty) $(< $extfield_name:ident($extfield_typ:ty))+) => { diff --git a/crates/field/src/binary_field_arithmetic.rs b/crates/field/src/binary_field_arithmetic.rs index 0e913f506..481b0af17 100644 --- a/crates/field/src/binary_field_arithmetic.rs +++ b/crates/field/src/binary_field_arithmetic.rs @@ -61,6 +61,10 @@ pub(crate) use impl_arithmetic_using_packed; impl TowerField for BinaryField1b { type Canonical = Self; + fn min_tower_level(self) -> usize { + 0 + } + #[inline] fn mul_primitive(self, _: usize) -> Result { Err(crate::Error::ExtensionDegreeMismatch) diff --git a/crates/field/src/polyval.rs b/crates/field/src/polyval.rs index 0337ddede..bdc15111f 100644 --- a/crates/field/src/polyval.rs +++ b/crates/field/src/polyval.rs @@ -457,6 +457,13 @@ impl BinaryField for BinaryField128bPolyval { impl TowerField for BinaryField128bPolyval { type Canonical = BinaryField128b; + fn min_tower_level(self) -> usize { + match self { + Self::ZERO | Self::ONE => 0, + _ => 7, + } + } + fn mul_primitive(self, _iota: usize) -> Result { // This method could be implemented by multiplying by isomorphic alpha value // But it's not being used as for now diff --git a/crates/math/src/arith_expr.rs b/crates/math/src/arith_expr.rs index 6607901f2..3c0e4adcb 100644 --- a/crates/math/src/arith_expr.rs +++ b/crates/math/src/arith_expr.rs @@ -7,7 +7,7 @@ use std::{ ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign}, }; -use binius_field::{Field, PackedField}; +use binius_field::{Field, PackedField, TowerField}; use super::error::Error; @@ -197,6 +197,19 @@ impl ArithExpr { } } +impl ArithExpr { + pub fn binary_tower_level(&self) -> usize { + match self { + Self::Const(value) => value.min_tower_level(), + Self::Var(_) => 0, + Self::Add(left, right) | Self::Mul(left, right) => { + max(left.binary_tower_level(), right.binary_tower_level()) + } + Self::Pow(base, _) => base.binary_tower_level(), + } + } +} + impl Default for ArithExpr where F: Field, From d9b4199b1f517b2c2910054725bd26d5b6e88751 Mon Sep 17 00:00:00 2001 From: Thomas Coratger <60488569+tcoratger@users.noreply.github.com> Date: Thu, 13 Feb 2025 09:23:59 +0100 Subject: [PATCH 22/50] [core]: simplify merkle tree `verify_opening` (#14) --- crates/core/src/merkle_tree/scheme.rs | 36 +++++++++++---------------- 1 file changed, 15 insertions(+), 21 deletions(-) diff --git a/crates/core/src/merkle_tree/scheme.rs b/crates/core/src/merkle_tree/scheme.rs index 3d27cc32c..63ef48380 100644 --- a/crates/core/src/merkle_tree/scheme.rs +++ b/crates/core/src/merkle_tree/scheme.rs @@ -108,42 +108,36 @@ where fn verify_opening( &self, - index: usize, + mut index: usize, values: &[F], layer_depth: usize, tree_depth: usize, layer_digests: &[Self::Digest], proof: &mut TranscriptReader, ) -> Result<(), Error> { - if 1 << layer_depth != layer_digests.len() { - bail!(VerificationError::IncorrectVectorLength) + if (1 << layer_depth) != layer_digests.len() { + bail!(VerificationError::IncorrectVectorLength); } - if index > (1 << tree_depth) - 1 { + if index >= (1 << tree_depth) { bail!(Error::IndexOutOfRange { - max: (1 << tree_depth) - 1, + max: (1 << tree_depth) - 1 }); } - let leaf_digest = hash_field_elems::<_, H>(values); - let branch = proof.read_vec(tree_depth - layer_depth)?; - - let mut index = index; - let root = branch.into_iter().fold(leaf_digest, |node, branch_node| { - let next_node = if index & 1 == 0 { - self.compression.compress([node, branch_node]) + let mut leaf_digest = hash_field_elems::<_, H>(values); + for branch_node in proof.read_vec(tree_depth - layer_depth)? { + leaf_digest = self.compression.compress(if index & 1 == 0 { + [leaf_digest, branch_node] } else { - self.compression.compress([branch_node, node]) - }; + [branch_node, leaf_digest] + }); index >>= 1; - next_node - }); - - if root == layer_digests[index] { - Ok(()) - } else { - bail!(VerificationError::InvalidProof) } + + (leaf_digest == layer_digests[index]) + .then_some(()) + .ok_or_else(|| VerificationError::InvalidProof.into()) } } From 8d5ed7a7f38b2a2f343608d3b9a40ea9b1e98e27 Mon Sep 17 00:00:00 2001 From: Milos Backonja <35807060+milosbackonja@users.noreply.github.com> Date: Thu, 13 Feb 2025 13:16:57 +0100 Subject: [PATCH 23/50] [ci] Adjusting nightly benchmark repository (#23) * [ci]: Adjusting nightly benchmark repository * [ci]: Adjusting CODEOWNERS for .github/ subdir --- .github/workflows/benchmark.yml | 9 +-------- CODEOWNERS | 2 +- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index f50554c1a..0c5fef7ea 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -29,15 +29,8 @@ jobs: actions: write runs-on: ${{ github.event_name == 'push' && github.ref_name == 'main' && 'c7a-4xlarge' || github.event.inputs.ec2_instance_type }} steps: - - name: Checkout Private GitLab Repository # Will be replaced with actual repository + - name: Checkout Repository uses: actions/checkout@v4 - with: - repository: ulvetanna/binius - github-server-url: https://gitlab.com - ref: anexj/benchmark_script - ssh-key: ${{ secrets.GITLAB_SSH_KEY }} - ssh-known-hosts: | - gitlab.com ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQCsj2bNKTBSpIYDEGk9KxsGh3mySTRgMtXL583qmBpzeQ+jqCMRgBqB98u3z++J1sKlXHWfM9dyhSevkMwSbhoR8XIq/U0tCNyokEi/ueaBMCvbcTHhO7FcwzY92WK4Yt0aGROY5qX2UKSeOvuP4D6TPqKF1onrSzH9bx9XUf2lEdWT/ia1NEKjunUqu1xOB/StKDHMoX4/OKyIzuS0q/T1zOATthvasJFoPrAjkohTyaDUz2LN5JoH839hViyEG82yB+MjcFV5MU3N1l1QL3cVUCh93xSaua1N85qivl+siMkPGbO5xR/En4iEY6K2XPASUEMaieWVNTRCtJ4S8H+9 - name: Setup Bencher uses: bencherdev/bencher@main - name: Create Output Directory diff --git a/CODEOWNERS b/CODEOWNERS index 29e34188a..61d772514 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -2,4 +2,4 @@ * @jimpo-ulvt @onesk # Code owners for the .github path -.github/* @IrreducibleOSS/Infrastructure +/.github/ @IrreducibleOSS/Infrastructure From 543506ec24e6db426d37207e4f5b4a348bed582d Mon Sep 17 00:00:00 2001 From: Tobias Bergkvist Date: Thu, 13 Feb 2025 15:58:14 +0100 Subject: [PATCH 24/50] [circuits] Simplify usage of ConstraintSystemBuilder by making it less generic (#22) [circuits] Simplfy ConstraintSystemBuilder to only support BinaryField128b for the top field. --- crates/circuits/src/arithmetic/u32.rs | 141 ++--- crates/circuits/src/bitwise.rs | 63 +- .../circuits/src/builder/constraint_system.rs | 27 +- crates/circuits/src/builder/mod.rs | 1 + crates/circuits/src/builder/types.rs | 4 + crates/circuits/src/builder/witness.rs | 108 ++-- crates/circuits/src/collatz.rs | 54 +- crates/circuits/src/groestl.rs | 25 + crates/circuits/src/keccakf.rs | 63 +- crates/circuits/src/lasso/batch.rs | 20 +- .../lasso/big_integer_ops/byte_sliced_add.rs | 35 +- .../byte_sliced_add_carryfree.rs | 35 +- ...yte_sliced_double_conditional_increment.rs | 35 +- .../byte_sliced_modular_mul.rs | 50 +- .../lasso/big_integer_ops/byte_sliced_mul.rs | 45 +- .../big_integer_ops/byte_sliced_test_utils.rs | 77 +-- .../circuits/src/lasso/big_integer_ops/mod.rs | 53 ++ crates/circuits/src/lasso/lasso.rs | 23 +- .../src/lasso/lookups/u8_arithmetic.rs | 213 +++++-- crates/circuits/src/lasso/sha256.rs | 135 +++-- crates/circuits/src/lasso/u32add.rs | 128 ++-- .../lasso/u8_double_conditional_increment.rs | 30 +- crates/circuits/src/lasso/u8add.rs | 30 +- crates/circuits/src/lasso/u8add_carryfree.rs | 30 +- crates/circuits/src/lasso/u8mul.rs | 36 +- crates/circuits/src/lib.rs | 555 +----------------- crates/circuits/src/pack.rs | 17 +- crates/circuits/src/plain_lookup.rs | 128 +++- crates/circuits/src/sha256.rs | 108 +++- crates/circuits/src/transparent.rs | 39 +- crates/circuits/src/u32fib.rs | 42 +- crates/circuits/src/unconstrained.rs | 15 +- crates/circuits/src/vision.rs | 82 +-- examples/Cargo.toml | 5 - examples/b32_mul.rs | 11 +- examples/bitwise_ops.rs | 13 +- examples/collatz.rs | 10 +- ...circuit.rs => groestl_circuit.rs.disabled} | 0 examples/keccakf_circuit.rs | 6 +- examples/modular_mul.rs | 11 +- examples/sha256_circuit.rs | 16 +- examples/sha256_circuit_with_lookup.rs | 18 +- examples/u32_add.rs | 14 +- examples/u32_mul.rs | 20 +- examples/u32add_with_lookup.rs | 14 +- examples/u8mul.rs | 13 +- examples/vision32b_circuit.rs | 9 +- 47 files changed, 1121 insertions(+), 1486 deletions(-) create mode 100644 crates/circuits/src/builder/types.rs rename examples/{groestl_circuit.rs => groestl_circuit.rs.disabled} (100%) diff --git a/crates/circuits/src/arithmetic/u32.rs b/crates/circuits/src/arithmetic/u32.rs index 88a0c9a70..49037356d 100644 --- a/crates/circuits/src/arithmetic/u32.rs +++ b/crates/circuits/src/arithmetic/u32.rs @@ -1,25 +1,17 @@ // Copyright 2024-2025 Irreducible Inc. use binius_core::oracle::{OracleId, ProjectionVariant, ShiftVariant}; -use binius_field::{ - as_packed_field::PackScalar, packed::set_packed_slice, BinaryField1b, BinaryField32b, - ExtensionField, Field, TowerField, -}; +use binius_field::{packed::set_packed_slice, BinaryField1b, BinaryField32b, Field, TowerField}; use binius_macros::arith_expr; use binius_maybe_rayon::prelude::*; -use bytemuck::Pod; use crate::{builder::ConstraintSystemBuilder, transparent}; -pub fn packed( - builder: &mut ConstraintSystemBuilder, +pub fn packed( + builder: &mut ConstraintSystemBuilder, name: impl ToString, input: OracleId, -) -> Result -where - U: PackScalar + PackScalar + PackScalar + Pod, - F: TowerField + ExtensionField, -{ +) -> Result { let packed = builder.add_packed(name, input, 5)?; if let Some(witness) = builder.witness() { witness.set( @@ -32,17 +24,13 @@ where Ok(packed) } -pub fn mul_const( - builder: &mut ConstraintSystemBuilder, +pub fn mul_const( + builder: &mut ConstraintSystemBuilder, name: impl ToString, input: OracleId, value: u32, flags: super::Flags, -) -> Result -where - U: PackScalar + PackScalar + Pod, - F: TowerField, -{ +) -> Result { if value == 0 { let log_rows = builder.log_rows([input])?; return transparent::constant(builder, name, log_rows, BinaryField1b::ZERO); @@ -85,17 +73,13 @@ where Ok(result) } -pub fn add( - builder: &mut ConstraintSystemBuilder, +pub fn add( + builder: &mut ConstraintSystemBuilder, name: impl ToString, xin: OracleId, yin: OracleId, flags: super::Flags, -) -> Result -where - U: PackScalar + PackScalar + Pod, - F: TowerField, -{ +) -> Result { builder.push_namespace(name); let log_rows = builder.log_rows([xin, yin])?; let cout = builder.add_committed("cout", log_rows, BinaryField1b::TOWER_LEVEL); @@ -151,17 +135,13 @@ where Ok(zout) } -pub fn sub( - builder: &mut ConstraintSystemBuilder, +pub fn sub( + builder: &mut ConstraintSystemBuilder, name: impl ToString, zin: OracleId, yin: OracleId, flags: super::Flags, -) -> Result -where - U: PackScalar + PackScalar + Pod, - F: TowerField, -{ +) -> Result { builder.push_namespace(name); let log_rows = builder.log_rows([zin, yin])?; let cout = builder.add_committed("cout", log_rows, BinaryField1b::TOWER_LEVEL); @@ -218,16 +198,12 @@ where Ok(xout) } -pub fn half( - builder: &mut ConstraintSystemBuilder, +pub fn half( + builder: &mut ConstraintSystemBuilder, name: impl ToString, input: OracleId, flags: super::Flags, -) -> Result -where - U: PackScalar + PackScalar + Pod, - F: TowerField, -{ +) -> Result { if matches!(flags, super::Flags::Checked) { // Assert that the number is even let lsb = select_bit(builder, "lsb", input, 0)?; @@ -236,23 +212,24 @@ where shr(builder, name, input, 1) } -pub fn shl( - builder: &mut ConstraintSystemBuilder, +pub fn shl( + builder: &mut ConstraintSystemBuilder, name: impl ToString, input: OracleId, offset: usize, -) -> Result -where - U: PackScalar + PackScalar + Pod, - F: TowerField, -{ +) -> Result { if offset == 0 { return Ok(input); } let shifted = builder.add_shifted(name, input, offset, 5, ShiftVariant::LogicalLeft)?; if let Some(witness) = builder.witness() { - (witness.new_column(shifted).as_mut_slice::(), witness.get(input)?.as_slice::()) + ( + witness + .new_column::(shifted) + .as_mut_slice::(), + witness.get::(input)?.as_slice::(), + ) .into_par_iter() .for_each(|(shifted, input)| *shifted = *input << offset); } @@ -260,23 +237,24 @@ where Ok(shifted) } -pub fn shr( - builder: &mut ConstraintSystemBuilder, +pub fn shr( + builder: &mut ConstraintSystemBuilder, name: impl ToString, input: OracleId, offset: usize, -) -> Result -where - U: PackScalar + PackScalar + Pod, - F: TowerField, -{ +) -> Result { if offset == 0 { return Ok(input); } let shifted = builder.add_shifted(name, input, offset, 5, ShiftVariant::LogicalRight)?; if let Some(witness) = builder.witness() { - (witness.new_column(shifted).as_mut_slice::(), witness.get(input)?.as_slice::()) + ( + witness + .new_column::(shifted) + .as_mut_slice::(), + witness.get::(input)?.as_slice::(), + ) .into_par_iter() .for_each(|(shifted, input)| *shifted = *input >> offset); } @@ -284,16 +262,12 @@ where Ok(shifted) } -pub fn select_bit( - builder: &mut ConstraintSystemBuilder, +pub fn select_bit( + builder: &mut ConstraintSystemBuilder, name: impl ToString, input: OracleId, index: usize, -) -> Result -where - U: PackScalar + PackScalar + Pod, - F: TowerField, -{ +) -> Result { let log_rows = builder.log_rows([input])?; anyhow::ensure!(log_rows >= 5, "Polynomial must have n_vars >= 5. Got {log_rows}"); anyhow::ensure!(index < 32, "Only index values between 0 and 32 are allowed. Got {index}"); @@ -304,7 +278,7 @@ where if let Some(witness) = builder.witness() { let mut bits = witness.new_column::(bits); let bits = bits.packed(); - let input = witness.get(input)?.as_slice::(); + let input = witness.get::(input)?.as_slice::(); input.iter().enumerate().for_each(|(i, &val)| { let value = match (val >> index) & 1 { 0 => BinaryField1b::ZERO, @@ -317,16 +291,12 @@ where Ok(bits) } -pub fn constant( - builder: &mut ConstraintSystemBuilder, +pub fn constant( + builder: &mut ConstraintSystemBuilder, name: impl ToString, log_count: usize, value: u32, -) -> Result -where - U: PackScalar + PackScalar + PackScalar + Pod, - F: TowerField + ExtensionField, -{ +) -> Result { builder.push_namespace(name); // This would not need to be committed if we had `builder.add_unpacked(..)` let output = builder.add_committed("output", log_count + 5, BinaryField1b::TOWER_LEVEL); @@ -361,17 +331,14 @@ where #[cfg(test)] mod tests { use binius_core::constraint_system::validate::validate_witness; - use binius_field::{arch::OptimalUnderlier, BinaryField128b, BinaryField1b, TowerField}; + use binius_field::{BinaryField1b, TowerField}; use crate::{arithmetic, builder::ConstraintSystemBuilder, unconstrained::unconstrained}; - type U = OptimalUnderlier; - type F = BinaryField128b; - #[test] fn test_mul_const() { let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let a = builder.add_committed("a", 5, BinaryField1b::TOWER_LEVEL); if let Some(witness) = builder.witness() { @@ -391,13 +358,29 @@ mod tests { validate_witness(&constraint_system, &boundaries, &witness).unwrap(); } + #[test] + fn test_add() { + let allocator = bumpalo::Bump::new(); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); + let log_size = 14; + let a = unconstrained::(&mut builder, "a", log_size).unwrap(); + let b = unconstrained::(&mut builder, "b", log_size).unwrap(); + let _c = arithmetic::u32::add(&mut builder, "u32add", a, b, arithmetic::Flags::Unchecked) + .unwrap(); + + let witness = builder.take_witness().unwrap(); + let constraint_system = builder.build().unwrap(); + let boundaries = vec![]; + validate_witness(&constraint_system, &boundaries, &witness).unwrap(); + } + #[test] fn test_sub() { let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); - let a = unconstrained::(&mut builder, "a", 7).unwrap(); - let b = unconstrained::(&mut builder, "a", 7).unwrap(); + let a = unconstrained::(&mut builder, "a", 7).unwrap(); + let b = unconstrained::(&mut builder, "a", 7).unwrap(); let _c = arithmetic::u32::sub(&mut builder, "c", a, b, arithmetic::Flags::Unchecked).unwrap(); diff --git a/crates/circuits/src/bitwise.rs b/crates/circuits/src/bitwise.rs index 33e2d3a11..5a7113541 100644 --- a/crates/circuits/src/bitwise.rs +++ b/crates/circuits/src/bitwise.rs @@ -1,25 +1,18 @@ // Copyright 2024-2025 Irreducible Inc. use binius_core::oracle::OracleId; -use binius_field::{ - as_packed_field::PackScalar, underlier::UnderlierType, BinaryField1b, TowerField, -}; +use binius_field::{BinaryField1b, Field, TowerField}; use binius_macros::arith_expr; use binius_maybe_rayon::prelude::*; -use bytemuck::Pod; use crate::builder::ConstraintSystemBuilder; -pub fn and( - builder: &mut ConstraintSystemBuilder, +pub fn and( + builder: &mut ConstraintSystemBuilder, name: impl ToString, xin: OracleId, yin: OracleId, -) -> Result -where - U: UnderlierType + Pod + PackScalar + PackScalar, - F: TowerField, -{ +) -> Result { builder.push_namespace(name); let log_rows = builder.log_rows([xin, yin])?; let zout = builder.add_committed("zout", log_rows, BinaryField1b::TOWER_LEVEL); @@ -45,19 +38,16 @@ where Ok(zout) } -pub fn xor( - builder: &mut ConstraintSystemBuilder, +pub fn xor( + builder: &mut ConstraintSystemBuilder, name: impl ToString, xin: OracleId, yin: OracleId, -) -> Result -where - U: UnderlierType + Pod + PackScalar + PackScalar, - F: TowerField, -{ +) -> Result { builder.push_namespace(name); let log_rows = builder.log_rows([xin, yin])?; - let zout = builder.add_linear_combination("zout", log_rows, [(xin, F::ONE), (yin, F::ONE)])?; + let zout = + builder.add_linear_combination("zout", log_rows, [(xin, Field::ONE), (yin, Field::ONE)])?; if let Some(witness) = builder.witness() { ( witness.get::(xin)?.as_slice::(), @@ -75,16 +65,12 @@ where Ok(zout) } -pub fn or( - builder: &mut ConstraintSystemBuilder, +pub fn or( + builder: &mut ConstraintSystemBuilder, name: impl ToString, xin: OracleId, yin: OracleId, -) -> Result -where - U: UnderlierType + Pod + PackScalar + PackScalar, - F: TowerField, -{ +) -> Result { builder.push_namespace(name); let log_rows = builder.log_rows([xin, yin])?; let zout = builder.add_committed("zout", log_rows, BinaryField1b::TOWER_LEVEL); @@ -109,3 +95,28 @@ where builder.pop_namespace(); Ok(zout) } + +#[cfg(test)] +mod tests { + use binius_core::constraint_system::validate::validate_witness; + use binius_field::BinaryField1b; + + use crate::{builder::ConstraintSystemBuilder, unconstrained::unconstrained}; + + #[test] + fn test_bitwise() { + let allocator = bumpalo::Bump::new(); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); + let log_size = 6; + let a = unconstrained::(&mut builder, "a", log_size).unwrap(); + let b = unconstrained::(&mut builder, "b", log_size).unwrap(); + let _and = super::and(&mut builder, "and", a, b).unwrap(); + let _xor = super::xor(&mut builder, "xor", a, b).unwrap(); + let _or = super::or(&mut builder, "or", a, b).unwrap(); + + let witness = builder.take_witness().unwrap(); + let constraint_system = builder.build().unwrap(); + let boundaries = vec![]; + validate_witness(&constraint_system, &boundaries, &witness).unwrap(); + } +} diff --git a/crates/circuits/src/builder/constraint_system.rs b/crates/circuits/src/builder/constraint_system.rs index ba8a2c865..e4010cafa 100644 --- a/crates/circuits/src/builder/constraint_system.rs +++ b/crates/circuits/src/builder/constraint_system.rs @@ -16,35 +16,28 @@ use binius_core::{ transparent::step_down::StepDown, witness::MultilinearExtensionIndex, }; -use binius_field::{ - as_packed_field::PackScalar, underlier::UnderlierType, BinaryField1b, TowerField, -}; +use binius_field::{as_packed_field::PackScalar, BinaryField1b}; use binius_math::ArithExpr; use binius_utils::bail; -use crate::builder::witness; +use crate::builder::{ + types::{F, U}, + witness, +}; #[derive(Default)] -pub struct ConstraintSystemBuilder<'arena, U, F> -where - U: UnderlierType + PackScalar, - F: TowerField, -{ +pub struct ConstraintSystemBuilder<'arena> { oracles: Rc>>, constraints: ConstraintSetBuilder, non_zero_oracle_ids: Vec, flushes: Vec, step_down_dedup: HashMap<(usize, usize), OracleId>, - witness: Option>, + witness: Option>, next_channel_id: ChannelId, namespace_path: Vec, } -impl<'arena, U, F> ConstraintSystemBuilder<'arena, U, F> -where - U: UnderlierType + PackScalar, - F: TowerField, -{ +impl<'arena> ConstraintSystemBuilder<'arena> { pub fn new() -> Self { Self::default() } @@ -79,7 +72,7 @@ where }) } - pub fn witness(&mut self) -> Option<&mut witness::Builder<'arena, U, F>> { + pub fn witness(&mut self) -> Option<&mut witness::Builder<'arena>> { self.witness.as_mut() } @@ -354,7 +347,7 @@ where /// /// let log_size = 14; /// - /// let mut builder = ConstraintSystemBuilder::::new(); + /// let mut builder = ConstraintSystemBuilder::new(); /// builder.push_namespace("a"); /// let x = builder.add_committed("x", log_size, BinaryField1b::TOWER_LEVEL); /// builder.push_namespace("b"); diff --git a/crates/circuits/src/builder/mod.rs b/crates/circuits/src/builder/mod.rs index c0b59130b..9c706fe3f 100644 --- a/crates/circuits/src/builder/mod.rs +++ b/crates/circuits/src/builder/mod.rs @@ -1,6 +1,7 @@ // Copyright 2024-2025 Irreducible Inc. pub mod constraint_system; +pub mod types; pub mod witness; pub use constraint_system::ConstraintSystemBuilder; diff --git a/crates/circuits/src/builder/types.rs b/crates/circuits/src/builder/types.rs new file mode 100644 index 000000000..ae33e7e4f --- /dev/null +++ b/crates/circuits/src/builder/types.rs @@ -0,0 +1,4 @@ +// Copyright 2025 Irreducible Inc. + +pub type F = binius_field::BinaryField128b; +pub type U = binius_field::arch::OptimalUnderlier; diff --git a/crates/circuits/src/builder/witness.rs b/crates/circuits/src/builder/witness.rs index 2209a36f4..77e83868d 100644 --- a/crates/circuits/src/builder/witness.rs +++ b/crates/circuits/src/builder/witness.rs @@ -10,35 +10,33 @@ use binius_core::{ use binius_field::{ as_packed_field::{PackScalar, PackedType}, underlier::WithUnderlier, - ExtensionField, Field, PackedField, TowerField, + ExtensionField, PackedField, TowerField, }; use binius_math::MultilinearExtension; use binius_utils::bail; use bytemuck::{must_cast_slice, must_cast_slice_mut, Pod}; -pub struct Builder<'arena, U: PackScalar, FW: TowerField> { +use super::types::{F, U}; + +pub struct Builder<'arena> { bump: &'arena bumpalo::Bump, - oracles: Rc>>, + oracles: Rc>>, #[allow(clippy::type_complexity)] - entries: Rc>>>>, + entries: Rc>>>>, } -struct WitnessBuilderEntry<'arena, U: PackScalar, FW: Field> { - witness: Result>, binius_math::Error>, +struct WitnessBuilderEntry<'arena> { + witness: Result>, binius_math::Error>, tower_level: usize, data: &'arena [U], } -impl<'arena, U, FW> Builder<'arena, U, FW> -where - U: PackScalar, - FW: TowerField, -{ +impl<'arena> Builder<'arena> { pub fn new( allocator: &'arena bumpalo::Bump, - oracles: Rc>>, + oracles: Rc>>, ) -> Self { Self { bump: allocator, @@ -47,10 +45,10 @@ where } } - pub fn new_column(&self, id: OracleId) -> EntryBuilder<'arena, U, FW, FS> + pub fn new_column(&self, id: OracleId) -> EntryBuilder<'arena, FS> where U: PackScalar, - FW: ExtensionField, + F: ExtensionField, { let oracles = self.oracles.borrow(); let log_rows = oracles.n_vars(id); @@ -69,10 +67,10 @@ where &self, id: OracleId, default: FS, - ) -> EntryBuilder<'arena, U, FW, FS> + ) -> EntryBuilder<'arena, FS> where U: PackScalar, - FW: ExtensionField, + F: ExtensionField, { let oracles = self.oracles.borrow(); let log_rows = oracles.n_vars(id); @@ -88,10 +86,11 @@ where } } - pub fn get(&self, id: OracleId) -> Result, Error> + pub fn get(&self, id: OracleId) -> Result, Error> where + FS: TowerField, U: PackScalar, - FW: ExtensionField, + F: ExtensionField, { let entries = self.entries.borrow(); let oracles = self.oracles.borrow(); @@ -122,11 +121,11 @@ where pub fn set( &self, id: OracleId, - entry: WitnessEntry<'arena, U, FS>, + entry: WitnessEntry<'arena, FS>, ) -> Result<(), Error> where U: PackScalar, - FW: ExtensionField, + F: ExtensionField, { let oracles = self.oracles.borrow(); if !oracles.is_valid_oracle_id(id) { @@ -145,7 +144,7 @@ where Ok(()) } - pub fn build(self) -> Result, Error> { + pub fn build(self) -> Result, Error> { let mut result = MultilinearExtensionIndex::new(); let entries = Rc::into_inner(self.entries) .ok_or_else(|| anyhow!("Failed to build. There are still entries refs. Make sure there are no pending column insertions."))? @@ -160,26 +159,37 @@ where } #[derive(Debug, Clone, Copy)] -pub struct WitnessEntry<'arena, U: PackScalar, FS: TowerField> { +pub struct WitnessEntry<'arena, FS: TowerField> +where + U: PackScalar, +{ data: &'arena [U], log_rows: usize, _marker: PhantomData, } -impl<'arena, U: PackScalar, FS: TowerField> WitnessEntry<'arena, U, FS> { +impl<'arena, FS: TowerField> WitnessEntry<'arena, FS> +where + U: PackScalar, +{ #[inline] pub fn packed(&self) -> &'arena [PackedType] { WithUnderlier::from_underliers_ref(self.data) } - pub const fn repacked(&self) -> WitnessEntry<'arena, U, FW> + #[inline] + pub const fn as_slice(&self) -> &'arena [T] { + must_cast_slice(self.data) + } + + pub const fn repacked(&self) -> WitnessEntry<'arena, FE> where - FW: TowerField + ExtensionField, - U: PackScalar, + FE: TowerField + ExtensionField, + U: PackScalar, { WitnessEntry { data: self.data, - log_rows: self.log_rows - >::LOG_DEGREE, + log_rows: self.log_rows - >::LOG_DEGREE, _marker: PhantomData, } } @@ -189,38 +199,36 @@ impl<'arena, U: PackScalar, FS: TowerField> WitnessEntry<'arena, U, FS> { } } -impl<'arena, U: PackScalar + Pod, FS: TowerField> WitnessEntry<'arena, U, FS> { - #[inline] - pub const fn as_slice(&self) -> &'arena [T] { - must_cast_slice(self.data) - } -} - -pub struct EntryBuilder<'arena, U, FW, FS> +pub struct EntryBuilder<'arena, FS> where - U: PackScalar + PackScalar, FS: TowerField, - FW: TowerField + ExtensionField, + U: PackScalar, + F: ExtensionField, { _marker: PhantomData, #[allow(clippy::type_complexity)] - entries: Rc>>>>, + entries: Rc>>>>, id: OracleId, log_rows: usize, data: Option<&'arena mut [U]>, } -impl EntryBuilder<'_, U, FW, FS> +impl EntryBuilder<'_, FS> where - U: PackScalar + PackScalar, FS: TowerField, - FW: TowerField + ExtensionField, + U: PackScalar, + F: ExtensionField, { #[inline] pub fn packed(&mut self) -> &mut [PackedType] { PackedType::::from_underliers_ref_mut(self.underliers()) } + #[inline] + pub fn as_mut_slice(&mut self) -> &mut [T] { + must_cast_slice_mut(self.underliers()) + } + #[inline] fn underliers(&mut self) -> &mut [U] { self.data @@ -229,23 +237,11 @@ where } } -impl EntryBuilder<'_, U, FW, FS> -where - U: PackScalar + PackScalar + Pod, - FS: TowerField, - FW: TowerField + ExtensionField, -{ - #[inline] - pub fn as_mut_slice(&mut self) -> &mut [T] { - must_cast_slice_mut(self.underliers()) - } -} - -impl Drop for EntryBuilder<'_, U, FW, FS> +impl Drop for EntryBuilder<'_, FS> where - U: PackScalar + PackScalar, FS: TowerField, - FW: TowerField + ExtensionField, + U: PackScalar, + F: ExtensionField, { fn drop(&mut self) { let data = Option::take(&mut self.data).expect("data is always Some until this point"); diff --git a/crates/circuits/src/collatz.rs b/crates/circuits/src/collatz.rs index df1f2d052..04f3b3c11 100644 --- a/crates/circuits/src/collatz.rs +++ b/crates/circuits/src/collatz.rs @@ -10,7 +10,14 @@ use binius_field::{ use binius_macros::arith_expr; use bytemuck::Pod; -use crate::{arithmetic, builder::ConstraintSystemBuilder, transparent}; +use crate::{ + arithmetic, + builder::{ + types::{F, U}, + ConstraintSystemBuilder, + }, + transparent, +}; pub type Advice = (usize, usize); @@ -37,9 +44,9 @@ impl Collatz { (self.evens.len(), self.odds.len()) } - pub fn build( + pub fn build( self, - builder: &mut ConstraintSystemBuilder, + builder: &mut ConstraintSystemBuilder, advice: Advice, ) -> Result>, anyhow::Error> where @@ -58,16 +65,12 @@ impl Collatz { Ok(boundaries) } - fn even( + fn even( &self, - builder: &mut ConstraintSystemBuilder, + builder: &mut ConstraintSystemBuilder, channel: ChannelId, count: usize, - ) -> Result<(), anyhow::Error> - where - U: PackScalar + PackScalar + PackScalar + Pod, - F: TowerField + ExtensionField, - { + ) -> Result<(), anyhow::Error> { let log_1b_rows = 5 + binius_utils::checked_arithmetics::log2_ceil_usize(count); let even = builder.add_committed("even", log_1b_rows, BinaryField1b::TOWER_LEVEL); if let Some(witness) = builder.witness() { @@ -90,16 +93,12 @@ impl Collatz { Ok(()) } - fn odd( + fn odd( &self, - builder: &mut ConstraintSystemBuilder, + builder: &mut ConstraintSystemBuilder, channel: ChannelId, count: usize, - ) -> Result<(), anyhow::Error> - where - U: PackScalar + PackScalar + PackScalar + Pod, - F: TowerField + ExtensionField, - { + ) -> Result<(), anyhow::Error> { let log_32b_rows = binius_utils::checked_arithmetics::log2_ceil_usize(count); let log_1b_rows = 5 + log_32b_rows; @@ -136,10 +135,7 @@ impl Collatz { Ok(()) } - fn get_boundaries(&self, channel_id: usize) -> Vec> - where - F: TowerField + From, - { + fn get_boundaries(&self, channel_id: usize) -> Vec> { vec![ Boundary { channel_id, @@ -179,15 +175,11 @@ pub fn collatz_orbit(x0: u32) -> Vec { res } -pub fn ensure_odd( - builder: &mut ConstraintSystemBuilder, +pub fn ensure_odd( + builder: &mut ConstraintSystemBuilder, input: OracleId, count: usize, -) -> Result<(), anyhow::Error> -where - U: PackScalar + PackScalar + Pod, - F: TowerField, -{ +) -> Result<(), anyhow::Error> { let log_32b_rows = builder.log_rows([input])? - 5; let lsb = arithmetic::u32::select_bit(builder, "lsb", input, 0)?; let selector = transparent::step_down(builder, "count", log_32b_rows, count)?; @@ -202,17 +194,13 @@ where #[cfg(test)] mod tests { use binius_core::constraint_system::validate::validate_witness; - use binius_field::{arch::OptimalUnderlier, BinaryField128b}; use crate::{builder::ConstraintSystemBuilder, collatz::Collatz}; #[test] fn test_collatz() { let allocator = bumpalo::Bump::new(); - let mut builder = - ConstraintSystemBuilder::::new_with_witness( - &allocator, - ); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let x0 = 9999999; diff --git a/crates/circuits/src/groestl.rs b/crates/circuits/src/groestl.rs index d1ab36d39..598e6de16 100644 --- a/crates/circuits/src/groestl.rs +++ b/crates/circuits/src/groestl.rs @@ -378,3 +378,28 @@ fn s_box(x: AESTowerField8b) -> AESTowerField8b { let idx = u8::from(x) as usize; AESTowerField8b::from(S_BOX[idx]) } + +#[cfg(test)] +mod tests { + use binius_core::constraint_system::validate::validate_witness; + use binius_field::{arch::OptimalUnderlier, AESTowerField16b}; + + use super::groestl_p_permutation; + use crate::builder::ConstraintSystemBuilder; + + #[test] + fn test_groestl() { + let allocator = bumpalo::Bump::new(); + let mut builder = + ConstraintSystemBuilder::::new_with_witness( + &allocator, + ); + let log_size = 9; + let _state_out = groestl_p_permutation(&mut builder, log_size).unwrap(); + + let witness = builder.take_witness().unwrap(); + let constraint_system = builder.build().unwrap(); + let boundaries = vec![]; + validate_witness(&constraint_system, &boundaries, &witness).unwrap(); + } +} diff --git a/crates/circuits/src/keccakf.rs b/crates/circuits/src/keccakf.rs index 638a89e4f..fee7025f6 100644 --- a/crates/circuits/src/keccakf.rs +++ b/crates/circuits/src/keccakf.rs @@ -8,14 +8,19 @@ use binius_core::{ transparent::multilinear_extension::MultilinearExtensionTransparent, }; use binius_field::{ - as_packed_field::{PackScalar, PackedType}, - underlier::{UnderlierType, WithUnderlier}, - BinaryField1b, BinaryField64b, ExtensionField, PackedField, TowerField, + as_packed_field::PackedType, underlier::WithUnderlier, BinaryField1b, BinaryField64b, Field, + PackedField, TowerField, }; use binius_macros::arith_expr; use bytemuck::{pod_collect_to_vec, Pod}; -use crate::{builder::ConstraintSystemBuilder, transparent::step_down}; +use crate::{ + builder::{ + types::{F, U}, + ConstraintSystemBuilder, + }, + transparent::step_down, +}; #[derive(Default, Clone, Copy)] pub struct KeccakfState(pub [u64; STATE_SIZE]); @@ -25,15 +30,11 @@ pub struct KeccakfOracles { pub output: [OracleId; STATE_SIZE], } -pub fn keccakf( - builder: &mut ConstraintSystemBuilder, +pub fn keccakf( + builder: &mut ConstraintSystemBuilder, input_witness: &Option>, log_size: usize, -) -> Result -where - U: UnderlierType + Pod + PackScalar + PackScalar + PackScalar, - F: TowerField + ExtensionField, -{ +) -> Result { let internal_log_size = log_size + LOG_BIT_ROWS_PER_PERMUTATION; let round_consts_single: [OracleId; ROUNDS_PER_STATE_ROW] = array::try_from_fn(|round_within_row| { @@ -124,7 +125,7 @@ where builder.add_projected( "output", packed_state_out[xy], - vec![F::ONE; LOG_STATE_ROWS_PER_PERMUTATION], + vec![Field::ONE; LOG_STATE_ROWS_PER_PERMUTATION], ProjectionVariant::FirstVars, ) })?; @@ -135,7 +136,7 @@ where "c", internal_log_size, array::from_fn::<_, 5, _>(|offset| { - (state[round_within_row][x + 5 * offset], F::ONE) + (state[round_within_row][x + 5 * offset], Field::ONE) }), ) }) @@ -159,8 +160,8 @@ where "d", internal_log_size, [ - (c[round_within_row][(x + 4) % 5], F::ONE), - (c_shift[round_within_row][(x + 1) % 5], F::ONE), + (c[round_within_row][(x + 4) % 5], Field::ONE), + (c_shift[round_within_row][(x + 1) % 5], Field::ONE), ], ) }) @@ -174,8 +175,8 @@ where format!("a_theta[{xy}]"), internal_log_size, [ - (state[round_within_row][xy], F::ONE), - (d[round_within_row][x], F::ONE), + (state[round_within_row][xy], Field::ONE), + (d[round_within_row][x], Field::ONE), ], ) }) @@ -504,3 +505,31 @@ const KECCAKF_RC: [u64; ROUNDS_PER_PERMUTATION] = [ 0x0000000080000001, 0x8000000080008008, ]; + +#[cfg(test)] +mod tests { + use binius_core::constraint_system::validate::validate_witness; + use rand::{rngs::StdRng, Rng, SeedableRng}; + + use super::KeccakfState; + use crate::builder::ConstraintSystemBuilder; + + #[test] + fn test_keccakf() { + let allocator = bumpalo::Bump::new(); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); + let log_size = 5; + + let mut rng = StdRng::seed_from_u64(0); + let input_states = vec![KeccakfState(rng.gen())]; + let _state_out = super::keccakf(&mut builder, &Some(input_states), log_size); + + let witness = builder.take_witness().unwrap(); + + let constraint_system = builder.build().unwrap(); + + let boundaries = vec![]; + + validate_witness(&constraint_system, &boundaries, &witness).unwrap(); + } +} diff --git a/crates/circuits/src/lasso/batch.rs b/crates/circuits/src/lasso/batch.rs index 8fb6d0697..52c98e297 100644 --- a/crates/circuits/src/lasso/batch.rs +++ b/crates/circuits/src/lasso/batch.rs @@ -4,12 +4,15 @@ use anyhow::Ok; use binius_core::oracle::OracleId; use binius_field::{ as_packed_field::{PackScalar, PackedType}, - BinaryField1b, ExtensionField, PackedFieldIndexable, TowerField, + ExtensionField, PackedFieldIndexable, TowerField, }; use itertools::Itertools; use super::lasso::lasso; -use crate::builder::ConstraintSystemBuilder; +use crate::builder::{ + types::{F, U}, + ConstraintSystemBuilder, +}; pub struct LookupBatch { lookup_us: Vec>, u_to_t_mappings: Vec>, @@ -48,19 +51,16 @@ impl LookupBatch { self.lookup_col_lens.push(lookup_u_col_len); } - pub fn execute( - mut self, - builder: &mut ConstraintSystemBuilder, - ) -> Result<(), anyhow::Error> + pub fn execute(mut self, builder: &mut ConstraintSystemBuilder) -> Result<(), anyhow::Error> where - U: PackScalar + PackScalar + PackScalar, - PackedType: PackedFieldIndexable, FC: TowerField, - F: ExtensionField + TowerField, + U: PackScalar, + F: ExtensionField, + PackedType: PackedFieldIndexable, { let channel = builder.add_channel(); - lasso::<_, _, FC>( + lasso::( builder, "batched lasso", &self.lookup_col_lens, diff --git a/crates/circuits/src/lasso/big_integer_ops/byte_sliced_add.rs b/crates/circuits/src/lasso/big_integer_ops/byte_sliced_add.rs index 43349e09c..0e274f402 100644 --- a/crates/circuits/src/lasso/big_integer_ops/byte_sliced_add.rs +++ b/crates/circuits/src/lasso/big_integer_ops/byte_sliced_add.rs @@ -3,14 +3,7 @@ use alloy_primitives::U512; use anyhow::Result; use binius_core::oracle::OracleId; -use binius_field::{ - as_packed_field::{PackScalar, PackedType}, - tower_levels::TowerLevel, - underlier::UnderlierType, - BinaryField, BinaryField16b, BinaryField1b, BinaryField32b, BinaryField8b, ExtensionField, - PackedFieldIndexable, TowerField, -}; -use bytemuck::Pod; +use binius_field::{tower_levels::TowerLevel, BinaryField1b, BinaryField8b}; use crate::{ builder::ConstraintSystemBuilder, @@ -19,32 +12,16 @@ use crate::{ type B1 = BinaryField1b; type B8 = BinaryField8b; -type B16 = BinaryField16b; -type B32 = BinaryField32b; -pub fn byte_sliced_add>( - builder: &mut ConstraintSystemBuilder, +pub fn byte_sliced_add>( + builder: &mut ConstraintSystemBuilder, name: impl ToString + Clone, x_in: &Level::Data, y_in: &Level::Data, carry_in: OracleId, log_size: usize, lookup_batch_add: &mut LookupBatch, -) -> Result<(OracleId, Level::Data), anyhow::Error> -where - U: Pod - + UnderlierType - + PackScalar - + PackScalar - + PackScalar - + PackScalar - + PackScalar, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - F: TowerField + BinaryField + ExtensionField + ExtensionField + ExtensionField, - Level::Data: Sized, -{ +) -> Result<(OracleId, Level::Data), anyhow::Error> { if Level::WIDTH == 1 { let (carry_out, sum) = u8add(builder, lookup_batch_add, name, x_in[0], y_in[0], carry_in, log_size)?; @@ -58,7 +35,7 @@ where let (lower_half_x, upper_half_x) = Level::split(x_in); let (lower_half_y, upper_half_y) = Level::split(y_in); - let (internal_carry, lower_sum) = byte_sliced_add::<_, _, Level::Base>( + let (internal_carry, lower_sum) = byte_sliced_add::( builder, format!("lower sum {}b", Level::Base::WIDTH), lower_half_x, @@ -68,7 +45,7 @@ where lookup_batch_add, )?; - let (carry_out, upper_sum) = byte_sliced_add::<_, _, Level::Base>( + let (carry_out, upper_sum) = byte_sliced_add::( builder, format!("upper sum {}b", Level::Base::WIDTH), upper_half_x, diff --git a/crates/circuits/src/lasso/big_integer_ops/byte_sliced_add_carryfree.rs b/crates/circuits/src/lasso/big_integer_ops/byte_sliced_add_carryfree.rs index 881aa1b92..38e2703dd 100644 --- a/crates/circuits/src/lasso/big_integer_ops/byte_sliced_add_carryfree.rs +++ b/crates/circuits/src/lasso/big_integer_ops/byte_sliced_add_carryfree.rs @@ -3,14 +3,7 @@ use alloy_primitives::U512; use anyhow::Result; use binius_core::oracle::OracleId; -use binius_field::{ - as_packed_field::{PackScalar, PackedType}, - tower_levels::TowerLevel, - underlier::UnderlierType, - BinaryField, BinaryField16b, BinaryField1b, BinaryField32b, BinaryField8b, ExtensionField, - PackedFieldIndexable, TowerField, -}; -use bytemuck::Pod; +use binius_field::{tower_levels::TowerLevel, BinaryField1b, BinaryField8b}; use super::byte_sliced_add; use crate::{ @@ -20,12 +13,10 @@ use crate::{ type B1 = BinaryField1b; type B8 = BinaryField8b; -type B16 = BinaryField16b; -type B32 = BinaryField32b; #[allow(clippy::too_many_arguments)] -pub fn byte_sliced_add_carryfree>( - builder: &mut ConstraintSystemBuilder, +pub fn byte_sliced_add_carryfree>( + builder: &mut ConstraintSystemBuilder, name: impl ToString, x_in: &Level::Data, y_in: &Level::Data, @@ -33,21 +24,7 @@ pub fn byte_sliced_add_carryfree>( log_size: usize, lookup_batch_add: &mut LookupBatch, lookup_batch_add_carryfree: &mut LookupBatch, -) -> Result -where - U: Pod - + UnderlierType - + PackScalar - + PackScalar - + PackScalar - + PackScalar - + PackScalar, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - F: TowerField + BinaryField + ExtensionField + ExtensionField + ExtensionField, - Level::Data: Sized, -{ +) -> Result { if Level::WIDTH == 1 { let sum = u8add_carryfree( builder, @@ -68,7 +45,7 @@ where let (lower_half_x, upper_half_x) = Level::split(x_in); let (lower_half_y, upper_half_y) = Level::split(y_in); - let (internal_carry, lower_sum) = byte_sliced_add::<_, _, Level::Base>( + let (internal_carry, lower_sum) = byte_sliced_add::( builder, format!("lower sum {}b", Level::Base::WIDTH), lower_half_x, @@ -78,7 +55,7 @@ where lookup_batch_add, )?; - let upper_sum = byte_sliced_add_carryfree::<_, _, Level::Base>( + let upper_sum = byte_sliced_add_carryfree::( builder, format!("upper sum {}b", Level::Base::WIDTH), upper_half_x, diff --git a/crates/circuits/src/lasso/big_integer_ops/byte_sliced_double_conditional_increment.rs b/crates/circuits/src/lasso/big_integer_ops/byte_sliced_double_conditional_increment.rs index b14baa3e1..ed697b8eb 100644 --- a/crates/circuits/src/lasso/big_integer_ops/byte_sliced_double_conditional_increment.rs +++ b/crates/circuits/src/lasso/big_integer_ops/byte_sliced_double_conditional_increment.rs @@ -3,14 +3,7 @@ use alloy_primitives::U512; use anyhow::Result; use binius_core::oracle::OracleId; -use binius_field::{ - as_packed_field::{PackScalar, PackedType}, - tower_levels::TowerLevel, - underlier::UnderlierType, - BinaryField, BinaryField16b, BinaryField1b, BinaryField32b, BinaryField8b, ExtensionField, - PackedFieldIndexable, TowerField, -}; -use bytemuck::Pod; +use binius_field::{tower_levels::TowerLevel, BinaryField1b, BinaryField8b}; use crate::{ builder::ConstraintSystemBuilder, @@ -19,12 +12,10 @@ use crate::{ type B1 = BinaryField1b; type B8 = BinaryField8b; -type B16 = BinaryField16b; -type B32 = BinaryField32b; #[allow(clippy::too_many_arguments)] -pub fn byte_sliced_double_conditional_increment>( - builder: &mut ConstraintSystemBuilder, +pub fn byte_sliced_double_conditional_increment>( + builder: &mut ConstraintSystemBuilder, name: impl ToString, x_in: &Level::Data, first_carry_in: OracleId, @@ -32,21 +23,7 @@ pub fn byte_sliced_double_conditional_increment Result<(OracleId, Level::Data), anyhow::Error> -where - U: Pod - + UnderlierType - + PackScalar - + PackScalar - + PackScalar - + PackScalar - + PackScalar, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - F: TowerField + BinaryField + ExtensionField + ExtensionField + ExtensionField, - Level::Data: Sized, -{ +) -> Result<(OracleId, Level::Data), anyhow::Error> { if Level::WIDTH == 1 { let (carry_out, sum) = u8_double_conditional_increment( builder, @@ -66,7 +43,7 @@ where let (lower_half_x, upper_half_x) = Level::split(x_in); - let (internal_carry, lower_sum) = byte_sliced_double_conditional_increment::<_, _, Level::Base>( + let (internal_carry, lower_sum) = byte_sliced_double_conditional_increment::( builder, format!("lower sum {}b", Level::Base::WIDTH), lower_half_x, @@ -77,7 +54,7 @@ where lookup_batch_dci, )?; - let (carry_out, upper_sum) = byte_sliced_double_conditional_increment::<_, _, Level::Base>( + let (carry_out, upper_sum) = byte_sliced_double_conditional_increment::( builder, format!("upper sum {}b", Level::Base::WIDTH), upper_half_x, diff --git a/crates/circuits/src/lasso/big_integer_ops/byte_sliced_modular_mul.rs b/crates/circuits/src/lasso/big_integer_ops/byte_sliced_modular_mul.rs index ead52cccf..9677a3e4b 100644 --- a/crates/circuits/src/lasso/big_integer_ops/byte_sliced_modular_mul.rs +++ b/crates/circuits/src/lasso/big_integer_ops/byte_sliced_modular_mul.rs @@ -4,37 +4,27 @@ use alloy_primitives::U512; use anyhow::Result; use binius_core::{oracle::OracleId, transparent::constant::Constant}; use binius_field::{ - as_packed_field::{PackScalar, PackedType}, - tower_levels::TowerLevel, - underlier::{UnderlierType, WithUnderlier}, - BinaryField, BinaryField16b, BinaryField1b, BinaryField32b, BinaryField8b, ExtensionField, - PackedFieldIndexable, TowerField, + tower_levels::TowerLevel, underlier::WithUnderlier, BinaryField32b, BinaryField8b, TowerField, }; use binius_macros::arith_expr; -use bytemuck::Pod; use super::{byte_sliced_add_carryfree, byte_sliced_mul}; use crate::{ - builder::ConstraintSystemBuilder, + builder::{types::F, ConstraintSystemBuilder}, lasso::{ batch::LookupBatch, lookups::u8_arithmetic::{add_carryfree_lookup, add_lookup, dci_lookup, mul_lookup}, }, }; -type B1 = BinaryField1b; type B8 = BinaryField8b; -type B16 = BinaryField16b; -type B32 = BinaryField32b; #[allow(clippy::too_many_arguments)] pub fn byte_sliced_modular_mul< - U, - F, LevelIn: TowerLevel, LevelOut: TowerLevel, >( - builder: &mut ConstraintSystemBuilder, + builder: &mut ConstraintSystemBuilder, name: impl ToString, mult_a: &LevelIn::Data, mult_b: &LevelIn::Data, @@ -42,21 +32,7 @@ pub fn byte_sliced_modular_mul< log_size: usize, zero_byte_oracle: OracleId, zero_carry_oracle: OracleId, -) -> Result -where - U: Pod - + UnderlierType - + PackScalar - + PackScalar - + PackScalar - + PackScalar - + PackScalar, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - F: TowerField + BinaryField + ExtensionField + ExtensionField + ExtensionField, - ::Underlier: From, -{ +) -> Result { builder.push_namespace(name); let lookup_t_mul = mul_lookup(builder, "mul table")?; @@ -88,12 +64,14 @@ where "modulus", Constant::new( log_size, - F::from_underlier(>::into(modulus_input[byte_idx])), + ::from_underlier(::Underlier, + >>::into(modulus_input[byte_idx])), ), )?; } - let ab = byte_sliced_mul::<_, _, LevelIn, LevelOut>( + let ab = byte_sliced_mul::( builder, "ab", mult_a, @@ -166,7 +144,7 @@ where } } - let qm = byte_sliced_mul::<_, _, LevelIn, LevelOut>( + let qm = byte_sliced_mul::( builder, "qm", "ient, @@ -183,7 +161,7 @@ where repeating_zero[byte_idx] = zero_byte_oracle; } - let qm_plus_r = byte_sliced_add_carryfree::<_, _, LevelOut>( + let qm_plus_r = byte_sliced_add_carryfree::( builder, "hi*lo", &qm, @@ -194,12 +172,12 @@ where &mut lookup_batch_add_carryfree, )?; - lookup_batch_mul.execute::<_, _, BinaryField32b>(builder)?; - lookup_batch_add.execute::<_, _, BinaryField32b>(builder)?; - lookup_batch_add_carryfree.execute::<_, _, BinaryField32b>(builder)?; + lookup_batch_mul.execute::(builder)?; + lookup_batch_add.execute::(builder)?; + lookup_batch_add_carryfree.execute::(builder)?; if LevelIn::WIDTH != 1 { - lookup_batch_dci.execute::<_, _, BinaryField32b>(builder)?; + lookup_batch_dci.execute::(builder)?; } let consistency = arith_expr!([x, y] = x - y); diff --git a/crates/circuits/src/lasso/big_integer_ops/byte_sliced_mul.rs b/crates/circuits/src/lasso/big_integer_ops/byte_sliced_mul.rs index 03742d347..2236045fb 100644 --- a/crates/circuits/src/lasso/big_integer_ops/byte_sliced_mul.rs +++ b/crates/circuits/src/lasso/big_integer_ops/byte_sliced_mul.rs @@ -3,14 +3,7 @@ use alloy_primitives::U512; use anyhow::Result; use binius_core::oracle::OracleId; -use binius_field::{ - as_packed_field::{PackScalar, PackedType}, - tower_levels::TowerLevel, - underlier::UnderlierType, - BinaryField, BinaryField16b, BinaryField1b, BinaryField32b, BinaryField8b, ExtensionField, - PackedFieldIndexable, TowerField, -}; -use bytemuck::Pod; +use binius_field::{tower_levels::TowerLevel, BinaryField8b}; use super::{byte_sliced_add, byte_sliced_double_conditional_increment}; use crate::{ @@ -18,19 +11,14 @@ use crate::{ lasso::{batch::LookupBatch, u8mul::u8mul_bytesliced}, }; -type B1 = BinaryField1b; type B8 = BinaryField8b; -type B16 = BinaryField16b; -type B32 = BinaryField32b; #[allow(clippy::too_many_arguments)] pub fn byte_sliced_mul< - U, - F, LevelIn: TowerLevel, LevelOut: TowerLevel, >( - builder: &mut ConstraintSystemBuilder, + builder: &mut ConstraintSystemBuilder, name: impl ToString, mult_a: &LevelIn::Data, mult_b: &LevelIn::Data, @@ -39,20 +27,7 @@ pub fn byte_sliced_mul< lookup_batch_mul: &mut LookupBatch, lookup_batch_add: &mut LookupBatch, lookup_batch_dci: &mut LookupBatch, -) -> Result -where - U: Pod - + UnderlierType - + PackScalar - + PackScalar - + PackScalar - + PackScalar - + PackScalar, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - F: TowerField + BinaryField + ExtensionField + ExtensionField + ExtensionField, -{ +) -> Result { if LevelIn::WIDTH == 1 { let result_of_u8mul = u8mul_bytesliced( builder, @@ -77,7 +52,7 @@ where let (mult_a_low, mult_a_high) = LevelIn::split(mult_a); let (mult_b_low, mult_b_high) = LevelIn::split(mult_b); - let a_lo_b_lo = byte_sliced_mul::<_, _, LevelIn::Base, LevelOut::Base>( + let a_lo_b_lo = byte_sliced_mul::( builder, format!("lo*lo {}b", LevelIn::Base::WIDTH), mult_a_low, @@ -88,7 +63,7 @@ where lookup_batch_add, lookup_batch_dci, )?; - let a_lo_b_hi = byte_sliced_mul::<_, _, LevelIn::Base, LevelOut::Base>( + let a_lo_b_hi = byte_sliced_mul::( builder, format!("lo*hi {}b", LevelIn::Base::WIDTH), mult_a_low, @@ -99,7 +74,7 @@ where lookup_batch_add, lookup_batch_dci, )?; - let a_hi_b_lo = byte_sliced_mul::<_, _, LevelIn::Base, LevelOut::Base>( + let a_hi_b_lo = byte_sliced_mul::( builder, format!("hi*lo {}b", LevelIn::Base::WIDTH), mult_a_high, @@ -110,7 +85,7 @@ where lookup_batch_add, lookup_batch_dci, )?; - let a_hi_b_hi = byte_sliced_mul::<_, _, LevelIn::Base, LevelOut::Base>( + let a_hi_b_hi = byte_sliced_mul::( builder, format!("hi*hi {}b", LevelIn::Base::WIDTH), mult_a_high, @@ -122,7 +97,7 @@ where lookup_batch_dci, )?; - let (karatsuba_carry_for_high_chunk, karatsuba_term) = byte_sliced_add::<_, _, LevelIn>( + let (karatsuba_carry_for_high_chunk, karatsuba_term) = byte_sliced_add::( builder, format!("karastsuba addition {}b", LevelIn::WIDTH), &a_lo_b_hi, @@ -135,7 +110,7 @@ where let (a_lo_b_lo_lower_half, a_lo_b_lo_upper_half) = LevelIn::split(&a_lo_b_lo); let (a_hi_b_hi_lower_half, a_hi_b_hi_upper_half) = LevelIn::split(&a_hi_b_hi); - let (additional_carry_for_high_chunk, final_middle_chunk) = byte_sliced_add::<_, _, LevelIn>( + let (additional_carry_for_high_chunk, final_middle_chunk) = byte_sliced_add::( builder, format!("post kartsuba middle term addition {}b", LevelIn::WIDTH), &karatsuba_term, @@ -145,7 +120,7 @@ where lookup_batch_add, )?; - let (_, final_high_chunk) = byte_sliced_double_conditional_increment::<_, _, LevelIn::Base>( + let (_, final_high_chunk) = byte_sliced_double_conditional_increment::( builder, format!("high chunk DCI {}b", LevelIn::Base::WIDTH), a_hi_b_hi_upper_half, diff --git a/crates/circuits/src/lasso/big_integer_ops/byte_sliced_test_utils.rs b/crates/circuits/src/lasso/big_integer_ops/byte_sliced_test_utils.rs index 1ceaae046..2767cc9ea 100644 --- a/crates/circuits/src/lasso/big_integer_ops/byte_sliced_test_utils.rs +++ b/crates/circuits/src/lasso/big_integer_ops/byte_sliced_test_utils.rs @@ -5,8 +5,7 @@ use std::{array, fmt::Debug}; use alloy_primitives::U512; use binius_core::{constraint_system::validate::validate_witness, oracle::OracleId}; use binius_field::{ - arch::OptimalUnderlier, tower_levels::TowerLevel, BinaryField128b, BinaryField1b, - BinaryField32b, BinaryField8b, Field, TowerField, + tower_levels::TowerLevel, BinaryField1b, BinaryField32b, BinaryField8b, Field, TowerField, }; use rand::{rngs::ThreadRng, thread_rng, Rng}; @@ -36,24 +35,20 @@ pub fn test_bytesliced_add() where TL: TowerLevel, { - type U = OptimalUnderlier; - type F = BinaryField128b; let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let log_size = 14; - let x_in = array::from_fn(|_| { - unconstrained::<_, _, BinaryField8b>(&mut builder, "x", log_size).unwrap() - }); - let y_in = array::from_fn(|_| { - unconstrained::<_, _, BinaryField8b>(&mut builder, "y", log_size).unwrap() - }); - let c_in = unconstrained::<_, _, BinaryField1b>(&mut builder, "cin first", log_size).unwrap(); + let x_in = + array::from_fn(|_| unconstrained::(&mut builder, "x", log_size).unwrap()); + let y_in = + array::from_fn(|_| unconstrained::(&mut builder, "y", log_size).unwrap()); + let c_in = unconstrained::(&mut builder, "cin first", log_size).unwrap(); let lookup_t_add = add_lookup(&mut builder, "add table").unwrap(); let mut lookup_batch_add = LookupBatch::new([lookup_t_add]); - let _sum_and_cout = byte_sliced_add::<_, _, TL>( + let _sum_and_cout = byte_sliced_add::( &mut builder, "lasso_bytesliced_add", &x_in, @@ -64,7 +59,7 @@ where ) .unwrap(); - lookup_batch_add.execute::<_, _, B32>(&mut builder).unwrap(); + lookup_batch_add.execute::(&mut builder).unwrap(); let witness = builder.take_witness().unwrap(); let constraint_system = builder.build().unwrap(); @@ -76,10 +71,8 @@ pub fn test_bytesliced_add_carryfree() where TL: TowerLevel, { - type U = OptimalUnderlier; - type F = BinaryField128b; let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let log_size = 14; let x_in = array::from_fn(|_| builder.add_committed("x", log_size, BinaryField8b::TOWER_LEVEL)); let y_in = array::from_fn(|_| builder.add_committed("y", log_size, BinaryField8b::TOWER_LEVEL)); @@ -130,7 +123,7 @@ where let mut lookup_batch_add = LookupBatch::new([lookup_t_add]); let mut lookup_batch_add_carryfree = LookupBatch::new([lookup_t_add_carryfree]); - let _sum_and_cout = byte_sliced_add_carryfree::<_, _, TL>( + let _sum_and_cout = byte_sliced_add_carryfree::( &mut builder, "lasso_bytesliced_add_carryfree", &x_in, @@ -142,9 +135,9 @@ where ) .unwrap(); - lookup_batch_add.execute::<_, _, B32>(&mut builder).unwrap(); + lookup_batch_add.execute::(&mut builder).unwrap(); lookup_batch_add_carryfree - .execute::<_, _, B32>(&mut builder) + .execute::(&mut builder) .unwrap(); let witness = builder.take_witness().unwrap(); @@ -157,22 +150,16 @@ pub fn test_bytesliced_double_conditional_increment() where TL: TowerLevel, { - type U = OptimalUnderlier; - type F = BinaryField128b; - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let log_size = 14; - let x_in = array::from_fn(|_| { - unconstrained::<_, _, BinaryField8b>(&mut builder, "x", log_size).unwrap() - }); + let x_in = + array::from_fn(|_| unconstrained::(&mut builder, "x", log_size).unwrap()); - let first_c_in = - unconstrained::<_, _, BinaryField1b>(&mut builder, "cin first", log_size).unwrap(); + let first_c_in = unconstrained::(&mut builder, "cin first", log_size).unwrap(); - let second_c_in = - unconstrained::<_, _, BinaryField1b>(&mut builder, "cin second", log_size).unwrap(); + let second_c_in = unconstrained::(&mut builder, "cin second", log_size).unwrap(); let zero_oracle_carry = transparent::constant(&mut builder, "zero carry", log_size, BinaryField1b::ZERO).unwrap(); @@ -180,7 +167,7 @@ where let mut lookup_batch_dci = LookupBatch::new([lookup_t_dci]); - let _sum_and_cout = byte_sliced_double_conditional_increment::<_, _, TL>( + let _sum_and_cout = byte_sliced_double_conditional_increment::( &mut builder, "lasso_bytesliced_DCI", &x_in, @@ -192,7 +179,7 @@ where ) .unwrap(); - lookup_batch_dci.execute::<_, _, B32>(&mut builder).unwrap(); + lookup_batch_dci.execute::(&mut builder).unwrap(); let witness = builder.take_witness().unwrap(); let constraint_system = builder.build().unwrap(); @@ -205,19 +192,14 @@ where TL: TowerLevel, TL::Base: TowerLevel, { - type U = OptimalUnderlier; - type F = BinaryField128b; - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let log_size = 14; - let mult_a = array::from_fn(|_| { - unconstrained::<_, _, BinaryField8b>(&mut builder, "a", log_size).unwrap() - }); - let mult_b = array::from_fn(|_| { - unconstrained::<_, _, BinaryField8b>(&mut builder, "b", log_size).unwrap() - }); + let mult_a = + array::from_fn(|_| unconstrained::(&mut builder, "a", log_size).unwrap()); + let mult_b = + array::from_fn(|_| unconstrained::(&mut builder, "b", log_size).unwrap()); let zero_oracle_carry = transparent::constant(&mut builder, "zero carry", log_size, BinaryField1b::ZERO).unwrap(); @@ -230,7 +212,7 @@ where let mut lookup_batch_add = LookupBatch::new([lookup_t_add]); let mut lookup_batch_dci = LookupBatch::new([lookup_t_dci]); - let _sum_and_cout = byte_sliced_mul::<_, _, TL::Base, TL>( + let _sum_and_cout = byte_sliced_mul::( &mut builder, "lasso_bytesliced_mul", &mult_a, @@ -255,11 +237,8 @@ where TL::Base: TowerLevel, >::Data: Debug, { - type U = OptimalUnderlier; - type F = BinaryField128b; - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let log_size = 14; let mut rng = thread_rng(); @@ -304,7 +283,7 @@ where let zero_oracle_carry = transparent::constant(&mut builder, "zero carry", log_size, BinaryField1b::ZERO).unwrap(); - let _modded_product = byte_sliced_modular_mul::<_, _, TL::Base, TL>( + let _modded_product = byte_sliced_modular_mul::( &mut builder, "lasso_bytesliced_mul", &mult_a, diff --git a/crates/circuits/src/lasso/big_integer_ops/mod.rs b/crates/circuits/src/lasso/big_integer_ops/mod.rs index 3046af9be..4b41b524f 100644 --- a/crates/circuits/src/lasso/big_integer_ops/mod.rs +++ b/crates/circuits/src/lasso/big_integer_ops/mod.rs @@ -12,3 +12,56 @@ pub use byte_sliced_add_carryfree::byte_sliced_add_carryfree; pub use byte_sliced_double_conditional_increment::byte_sliced_double_conditional_increment; pub use byte_sliced_modular_mul::byte_sliced_modular_mul; pub use byte_sliced_mul::byte_sliced_mul; + +#[cfg(test)] +mod tests { + use binius_field::tower_levels::{ + TowerLevel1, TowerLevel16, TowerLevel2, TowerLevel4, TowerLevel8, + }; + + use super::byte_sliced_test_utils::{ + test_bytesliced_add, test_bytesliced_add_carryfree, + test_bytesliced_double_conditional_increment, test_bytesliced_modular_mul, + test_bytesliced_mul, + }; + + #[test] + fn test_lasso_add_bytesliced() { + test_bytesliced_add::<1, TowerLevel1>(); + test_bytesliced_add::<2, TowerLevel2>(); + test_bytesliced_add::<4, TowerLevel4>(); + test_bytesliced_add::<8, TowerLevel8>(); + } + + #[test] + fn test_lasso_mul_bytesliced() { + test_bytesliced_mul::<1, TowerLevel2>(); + test_bytesliced_mul::<2, TowerLevel4>(); + test_bytesliced_mul::<4, TowerLevel8>(); + test_bytesliced_mul::<8, TowerLevel16>(); + } + + #[test] + fn test_lasso_modular_mul_bytesliced() { + test_bytesliced_modular_mul::<1, TowerLevel2>(); + test_bytesliced_modular_mul::<2, TowerLevel4>(); + test_bytesliced_modular_mul::<4, TowerLevel8>(); + test_bytesliced_modular_mul::<8, TowerLevel16>(); + } + + #[test] + fn test_lasso_bytesliced_double_conditional_increment() { + test_bytesliced_double_conditional_increment::<1, TowerLevel1>(); + test_bytesliced_double_conditional_increment::<2, TowerLevel2>(); + test_bytesliced_double_conditional_increment::<4, TowerLevel4>(); + test_bytesliced_double_conditional_increment::<8, TowerLevel8>(); + } + + #[test] + fn test_lasso_bytesliced_add_carryfree() { + test_bytesliced_add_carryfree::<1, TowerLevel1>(); + test_bytesliced_add_carryfree::<2, TowerLevel2>(); + test_bytesliced_add_carryfree::<4, TowerLevel4>(); + test_bytesliced_add_carryfree::<8, TowerLevel8>(); + } +} diff --git a/crates/circuits/src/lasso/lasso.rs b/crates/circuits/src/lasso/lasso.rs index 1ecde278a..d93e2d484 100644 --- a/crates/circuits/src/lasso/lasso.rs +++ b/crates/circuits/src/lasso/lasso.rs @@ -4,15 +4,20 @@ use anyhow::{ensure, Error, Result}; use binius_core::{constraint_system::channel::ChannelId, oracle::OracleId}; use binius_field::{ as_packed_field::{PackScalar, PackedType}, - underlier::UnderlierType, - BinaryField1b, ExtensionField, PackedFieldIndexable, TowerField, + ExtensionField, Field, PackedFieldIndexable, TowerField, }; use itertools::{izip, Itertools}; -use crate::{builder::ConstraintSystemBuilder, transparent}; +use crate::{ + builder::{ + types::{F, U}, + ConstraintSystemBuilder, + }, + transparent, +}; -pub fn lasso( - builder: &mut ConstraintSystemBuilder, +pub fn lasso( + builder: &mut ConstraintSystemBuilder, name: impl ToString, n_lookups: &[usize], u_to_t_mappings: &[impl AsRef<[usize]>], @@ -21,10 +26,10 @@ pub fn lasso( channel: ChannelId, ) -> Result<()> where - U: UnderlierType + PackScalar + PackScalar + PackScalar, - F: TowerField + ExtensionField + From, - PackedType: PackedFieldIndexable, FC: TowerField, + U: PackScalar, + F: ExtensionField + From, + PackedType: PackedFieldIndexable, { if n_lookups.len() != lookups_u.len() { Err(anyhow::Error::msg("n_vars and lookups_u must be of the same length"))?; @@ -55,7 +60,7 @@ where } let t_log_rows = builder.log_rows(lookup_t.as_ref().iter().copied())?; - let lookup_o = transparent::constant(builder, "lookup_o", t_log_rows, F::ONE)?; + let lookup_o = transparent::constant(builder, "lookup_o", t_log_rows, Field::ONE)?; let lookup_f = builder.add_committed("lookup_f", t_log_rows, FC::TOWER_LEVEL); let lookups_r = u_log_rows .iter() diff --git a/crates/circuits/src/lasso/lookups/u8_arithmetic.rs b/crates/circuits/src/lasso/lookups/u8_arithmetic.rs index 5fdbe0c71..e1d25d834 100644 --- a/crates/circuits/src/lasso/lookups/u8_arithmetic.rs +++ b/crates/circuits/src/lasso/lookups/u8_arithmetic.rs @@ -2,34 +2,19 @@ use anyhow::Result; use binius_core::oracle::OracleId; -use binius_field::{ - as_packed_field::{PackScalar, PackedType}, - underlier::UnderlierType, - BinaryField, BinaryField16b, BinaryField32b, BinaryField8b, ExtensionField, - PackedFieldIndexable, TowerField, -}; -use bytemuck::Pod; +use binius_field::{BinaryField32b, TowerField}; use crate::builder::ConstraintSystemBuilder; -type B8 = BinaryField8b; -type B16 = BinaryField16b; type B32 = BinaryField32b; const T_LOG_SIZE_MUL: usize = 16; const T_LOG_SIZE_ADD: usize = 17; const T_LOG_SIZE_DCI: usize = 10; -pub fn mul_lookup( - builder: &mut ConstraintSystemBuilder, +pub fn mul_lookup( + builder: &mut ConstraintSystemBuilder, name: impl ToString + Clone, -) -> Result -where - U: Pod + UnderlierType + PackScalar + PackScalar + PackScalar + PackScalar, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - F: TowerField + BinaryField + ExtensionField + ExtensionField + ExtensionField, -{ +) -> Result { builder.push_namespace(name); let lookup_t = builder.add_committed("lookup_t", T_LOG_SIZE_MUL, B32::TOWER_LEVEL); @@ -53,17 +38,10 @@ where Ok(lookup_t) } -pub fn add_lookup( - builder: &mut ConstraintSystemBuilder, +pub fn add_lookup( + builder: &mut ConstraintSystemBuilder, name: impl ToString + Clone, -) -> Result -where - U: Pod + UnderlierType + PackScalar + PackScalar + PackScalar + PackScalar, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - F: TowerField + BinaryField + ExtensionField + ExtensionField + ExtensionField, -{ +) -> Result { builder.push_namespace(name); let lookup_t = builder.add_committed("lookup_t", T_LOG_SIZE_ADD, B32::TOWER_LEVEL); @@ -95,17 +73,10 @@ where Ok(lookup_t) } -pub fn add_carryfree_lookup( - builder: &mut ConstraintSystemBuilder, +pub fn add_carryfree_lookup( + builder: &mut ConstraintSystemBuilder, name: impl ToString + Clone, -) -> Result -where - U: Pod + UnderlierType + PackScalar + PackScalar + PackScalar + PackScalar, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - F: TowerField + BinaryField + ExtensionField + ExtensionField + ExtensionField, -{ +) -> Result { builder.push_namespace(name); let lookup_t = builder.add_committed("lookup_t", T_LOG_SIZE_ADD, B32::TOWER_LEVEL); @@ -139,17 +110,10 @@ where Ok(lookup_t) } -pub fn dci_lookup( - builder: &mut ConstraintSystemBuilder, +pub fn dci_lookup( + builder: &mut ConstraintSystemBuilder, name: impl ToString + Clone, -) -> Result -where - U: Pod + UnderlierType + PackScalar + PackScalar + PackScalar + PackScalar, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - F: TowerField + BinaryField + ExtensionField + ExtensionField + ExtensionField, -{ +) -> Result { builder.push_namespace(name); let lookup_t = builder.add_committed("lookup_t", T_LOG_SIZE_DCI, B32::TOWER_LEVEL); @@ -182,3 +146,154 @@ where builder.pop_namespace(); Ok(lookup_t) } + +#[cfg(test)] +mod tests { + use binius_core::constraint_system::validate::validate_witness; + use binius_field::{BinaryField1b, BinaryField32b, BinaryField8b}; + + use crate::{ + builder::ConstraintSystemBuilder, + lasso::{self, batch::LookupBatch}, + unconstrained::unconstrained, + }; + + #[test] + fn test_lasso_u8add_carryfree_rejects_carry() { + // TODO: Make this test 100% certain to pass instead of 2^14 bits of security from randomness + let allocator = bumpalo::Bump::new(); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); + let log_size = 14; + let x_in = unconstrained::(&mut builder, "x", log_size).unwrap(); + let y_in = unconstrained::(&mut builder, "y", log_size).unwrap(); + let c_in = unconstrained::(&mut builder, "c", log_size).unwrap(); + + let lookup_t = super::add_carryfree_lookup(&mut builder, "add cf table").unwrap(); + let mut lookup_batch = LookupBatch::new([lookup_t]); + let _sum_and_cout = lasso::u8add_carryfree( + &mut builder, + &mut lookup_batch, + "lasso_u8add", + x_in, + y_in, + c_in, + log_size, + ) + .unwrap(); + + lookup_batch + .execute::(&mut builder) + .unwrap(); + + let witness = builder.take_witness().unwrap(); + let constraint_system = builder.build().unwrap(); + let boundaries = vec![]; + validate_witness(&constraint_system, &boundaries, &witness) + .expect_err("Rejected overflowing add"); + } + + #[test] + fn test_lasso_u8mul() { + let allocator = bumpalo::Bump::new(); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); + let log_size = 10; + + let mult_a = unconstrained::(&mut builder, "mult_a", log_size).unwrap(); + let mult_b = unconstrained::(&mut builder, "mult_b", log_size).unwrap(); + + let mul_lookup_table = super::mul_lookup(&mut builder, "mul table").unwrap(); + + let mut lookup_batch = LookupBatch::new([mul_lookup_table]); + + let _product = lasso::u8mul( + &mut builder, + &mut lookup_batch, + "lasso_u8mul", + mult_a, + mult_b, + 1 << log_size, + ) + .unwrap(); + + lookup_batch + .execute::(&mut builder) + .unwrap(); + + let witness = builder.take_witness().unwrap(); + let constraint_system = builder.build().unwrap(); + let boundaries = vec![]; + validate_witness(&constraint_system, &boundaries, &witness).unwrap(); + } + + #[test] + fn test_lasso_batched_u8mul() { + let allocator = bumpalo::Bump::new(); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); + let log_size = 10; + let mul_lookup_table = super::mul_lookup(&mut builder, "mul table").unwrap(); + + let mut lookup_batch = LookupBatch::new([mul_lookup_table]); + + for _ in 0..10 { + let mult_a = unconstrained::(&mut builder, "mult_a", log_size).unwrap(); + let mult_b = unconstrained::(&mut builder, "mult_b", log_size).unwrap(); + + let _product = lasso::u8mul( + &mut builder, + &mut lookup_batch, + "lasso_u8mul", + mult_a, + mult_b, + 1 << log_size, + ) + .unwrap(); + } + + lookup_batch + .execute::(&mut builder) + .unwrap(); + + let witness = builder.take_witness().unwrap(); + let constraint_system = builder.build().unwrap(); + let boundaries = vec![]; + validate_witness(&constraint_system, &boundaries, &witness).unwrap(); + } + + #[test] + fn test_lasso_batched_u8mul_rejects() { + let allocator = bumpalo::Bump::new(); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); + let log_size = 10; + + // We try to feed in the add table instead + let mul_lookup_table = super::add_lookup(&mut builder, "mul table").unwrap(); + + let mut lookup_batch = LookupBatch::new([mul_lookup_table]); + + // TODO?: Make this test fail 100% of the time, even though its almost impossible with rng + for _ in 0..10 { + let mult_a = unconstrained::(&mut builder, "mult_a", log_size).unwrap(); + let mult_b = unconstrained::(&mut builder, "mult_b", log_size).unwrap(); + + let _product = lasso::u8mul( + &mut builder, + &mut lookup_batch, + "lasso_u8mul", + mult_a, + mult_b, + 1 << log_size, + ) + .unwrap(); + } + + lookup_batch + .execute::(&mut builder) + .unwrap(); + + let witness = builder.take_witness().unwrap(); + let constraint_system = builder.build().unwrap(); + let boundaries = vec![]; + validate_witness(&constraint_system, &boundaries, &witness) + .expect_err("Channels should be unbalanced"); + } +} diff --git a/crates/circuits/src/lasso/sha256.rs b/crates/circuits/src/lasso/sha256.rs index c8677acb6..cd51271bc 100644 --- a/crates/circuits/src/lasso/sha256.rs +++ b/crates/circuits/src/lasso/sha256.rs @@ -1,21 +1,19 @@ // Copyright 2024-2025 Irreducible Inc. -use std::marker::PhantomData; - use anyhow::Result; use binius_core::oracle::OracleId; use binius_field::{ - as_packed_field::{PackScalar, PackedType}, - underlier::UnderlierType, - BinaryField16b, BinaryField1b, BinaryField32b, BinaryField4b, BinaryField8b, ExtensionField, + as_packed_field::PackedType, BinaryField16b, BinaryField1b, BinaryField32b, BinaryField4b, PackedFieldIndexable, TowerField, }; -use bytemuck::Pod; use itertools::izip; use super::{lasso::lasso, u32add::SeveralU32add}; use crate::{ - builder::ConstraintSystemBuilder, + builder::{ + types::{F, U}, + ConstraintSystemBuilder, + }, pack::pack, sha256::{rotate_and_xor, u32const_repeating, RotateRightType, INIT, ROUND_CONSTS_K}, }; @@ -24,36 +22,19 @@ pub const CH_MAJ_T_LOG_SIZE: usize = 12; type B1 = BinaryField1b; type B4 = BinaryField4b; -type B8 = BinaryField8b; type B16 = BinaryField16b; type B32 = BinaryField32b; -struct SeveralBitwise { +struct SeveralBitwise { n_lookups: Vec, lookup_t: OracleId, lookups_u: Vec<[OracleId; 1]>, u_to_t_mappings: Vec>, f: fn(u32, u32, u32) -> u32, - _phantom: PhantomData<(U, F)>, } -impl SeveralBitwise -where - U: UnderlierType - + Pod - + PackScalar - + PackScalar - + PackScalar - + PackScalar - + PackScalar, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - F: TowerField + ExtensionField + ExtensionField + ExtensionField, -{ - pub fn new( - builder: &mut ConstraintSystemBuilder, - f: fn(u32, u32, u32) -> u32, - ) -> Result { +impl SeveralBitwise { + pub fn new(builder: &mut ConstraintSystemBuilder, f: fn(u32, u32, u32) -> u32) -> Result { let lookup_t = builder.add_committed("bitwise lookup_t", CH_MAJ_T_LOG_SIZE, B16::TOWER_LEVEL); @@ -80,13 +61,12 @@ where lookups_u: Vec::new(), u_to_t_mappings: Vec::new(), f, - _phantom: PhantomData, }) } pub fn calculate( &mut self, - builder: &mut ConstraintSystemBuilder, + builder: &mut ConstraintSystemBuilder, name: impl ToString, params: [OracleId; 3], ) -> Result { @@ -94,9 +74,9 @@ where let log_size = builder.log_rows(params)?; - let xin_packed = pack::(xin, builder, "xin_packed")?; - let yin_packed = pack::(yin, builder, "yin_packed")?; - let zin_packed = pack::(zin, builder, "zin_packed")?; + let xin_packed = pack::(xin, builder, "xin_packed")?; + let yin_packed = pack::(yin, builder, "yin_packed")?; + let zin_packed = pack::(zin, builder, "zin_packed")?; let res = builder.add_committed(name, log_size, B1::TOWER_LEVEL); @@ -160,12 +140,12 @@ where pub fn finalize( self, - builder: &mut ConstraintSystemBuilder, + builder: &mut ConstraintSystemBuilder, name: impl ToString, ) -> Result<()> { let channel = builder.add_channel(); - lasso::<_, _, B32>( + lasso::( builder, name, &self.n_lookups, @@ -177,29 +157,11 @@ where } } -pub fn sha256( - builder: &mut ConstraintSystemBuilder, +pub fn sha256( + builder: &mut ConstraintSystemBuilder, input: [OracleId; 16], log_size: usize, -) -> Result<[OracleId; 8], anyhow::Error> -where - U: UnderlierType - + Pod - + PackScalar - + PackScalar - + PackScalar - + PackScalar - + PackScalar - + PackScalar, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - F: TowerField - + ExtensionField - + ExtensionField - + ExtensionField - + ExtensionField, -{ +) -> Result<[OracleId; 8], anyhow::Error> { let n_vars = log_size; let mut several_u32_add = SeveralU32add::new(builder)?; @@ -309,3 +271,66 @@ where Ok(output) } + +#[cfg(test)] +mod tests { + use binius_core::{constraint_system::validate::validate_witness, oracle::OracleId}; + use binius_field::{as_packed_field::PackedType, BinaryField1b, BinaryField8b, TowerField}; + use sha2::{compress256, digest::generic_array::GenericArray}; + + use crate::{ + builder::{types::U, ConstraintSystemBuilder}, + unconstrained::unconstrained, + }; + + #[test] + fn test_sha256_lasso() { + let allocator = bumpalo::Bump::new(); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); + let log_size = PackedType::::LOG_WIDTH + BinaryField8b::TOWER_LEVEL; + let input: [OracleId; 16] = std::array::from_fn(|i| { + unconstrained::(&mut builder, i, log_size).unwrap() + }); + let state_output = super::sha256(&mut builder, input, log_size).unwrap(); + + let witness = builder.witness().unwrap(); + + let input_witneses: [_; 16] = std::array::from_fn(|i| { + witness + .get::(input[i]) + .unwrap() + .as_slice::() + }); + + let output_witneses: [_; 8] = std::array::from_fn(|i| { + witness + .get::(state_output[i]) + .unwrap() + .as_slice::() + }); + + let mut generic_array_input = GenericArray::::default(); + + let n_compressions = input_witneses[0].len(); + + for j in 0..n_compressions { + for i in 0..16 { + for z in 0..4 { + generic_array_input[i * 4 + z] = input_witneses[i][j].to_be_bytes()[z]; + } + } + + let mut output = crate::sha256::INIT; + compress256(&mut output, &[generic_array_input]); + + for i in 0..8 { + assert_eq!(output[i], output_witneses[i][j]); + } + } + + let witness = builder.take_witness().unwrap(); + let constraint_system = builder.build().unwrap(); + let boundaries = vec![]; + validate_witness(&constraint_system, &boundaries, &witness).unwrap(); + } +} diff --git a/crates/circuits/src/lasso/u32add.rs b/crates/circuits/src/lasso/u32add.rs index dca11d724..271cb07ce 100644 --- a/crates/circuits/src/lasso/u32add.rs +++ b/crates/circuits/src/lasso/u32add.rs @@ -7,14 +7,19 @@ use binius_core::oracle::{OracleId, ShiftVariant}; use binius_field::{ as_packed_field::{PackScalar, PackedType}, packed::set_packed_slice, - underlier::{UnderlierType, U1}, + underlier::U1, BinaryField1b, BinaryField32b, BinaryField8b, ExtensionField, PackedFieldIndexable, TowerField, }; -use bytemuck::Pod; use itertools::izip; use super::lasso::lasso; -use crate::{builder::ConstraintSystemBuilder, pack::pack}; +use crate::{ + builder::{ + types::{F, U}, + ConstraintSystemBuilder, + }, + pack::pack, +}; const ADD_T_LOG_SIZE: usize = 17; @@ -22,32 +27,18 @@ type B1 = BinaryField1b; type B8 = BinaryField8b; type B32 = BinaryField32b; -pub fn u32add( - builder: &mut ConstraintSystemBuilder, +pub fn u32add( + builder: &mut ConstraintSystemBuilder, name: impl ToString + Clone, xin: OracleId, yin: OracleId, ) -> Result where - U: UnderlierType - + Pod - + PackScalar - + PackScalar - + PackScalar - + PackScalar - + PackScalar - + PackScalar, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - B8: ExtensionField + ExtensionField, - F: TowerField - + ExtensionField - + ExtensionField - + ExtensionField - + ExtensionField, FInput: TowerField, FOutput: TowerField, - B32: TowerField, + U: PackScalar + PackScalar, + B8: ExtensionField + ExtensionField, + F: ExtensionField + ExtensionField, { let mut several = SeveralU32add::new(builder)?; let sum = several.u32add::(builder, name.clone(), xin, yin)?; @@ -55,7 +46,7 @@ where Ok(sum) } -pub struct SeveralU32add { +pub struct SeveralU32add { n_lookups: Vec, lookup_t: OracleId, lookups_u: Vec<[OracleId; 1]>, @@ -64,19 +55,8 @@ pub struct SeveralU32add { _phantom: PhantomData<(U, F)>, } -impl SeveralU32add -where - U: UnderlierType - + Pod - + PackScalar - + PackScalar - + PackScalar - + PackScalar, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - F: TowerField + ExtensionField + ExtensionField, -{ - pub fn new(builder: &mut ConstraintSystemBuilder) -> Result { +impl SeveralU32add { + pub fn new(builder: &mut ConstraintSystemBuilder) -> Result { let lookup_t = builder.add_committed("lookup_t", ADD_T_LOG_SIZE, B32::TOWER_LEVEL); if let Some(witness) = builder.witness() { @@ -111,15 +91,15 @@ where pub fn u32add( &mut self, - builder: &mut ConstraintSystemBuilder, + builder: &mut ConstraintSystemBuilder, name: impl ToString, xin: OracleId, yin: OracleId, ) -> Result where - U: PackScalar + PackScalar, FInput: TowerField, FOutput: TowerField, + U: PackScalar + PackScalar, F: ExtensionField + ExtensionField, B8: ExtensionField + ExtensionField, { @@ -143,8 +123,8 @@ where let cin = builder.add_shifted("cin", cout, 1, 2, ShiftVariant::LogicalLeft)?; - let xin_u8 = pack::<_, _, FInput, B8>(xin, builder, "repacked xin")?; - let yin_u8 = pack::<_, _, FInput, B8>(yin, builder, "repacked yin")?; + let xin_u8 = pack::(xin, builder, "repacked xin")?; + let yin_u8 = pack::(yin, builder, "repacked yin")?; let lookup_u = builder.add_linear_combination( "lookup_u", @@ -231,12 +211,12 @@ where pub fn finalize( mut self, - builder: &mut ConstraintSystemBuilder, + builder: &mut ConstraintSystemBuilder, name: impl ToString, ) -> Result<()> { let channel = builder.add_channel(); self.finalized = true; - lasso::<_, _, B32>( + lasso::( builder, name, &self.n_lookups, @@ -248,8 +228,70 @@ where } } -impl Drop for SeveralU32add { +impl Drop for SeveralU32add { fn drop(&mut self) { assert!(self.finalized) } } + +#[cfg(test)] +mod tests { + use binius_core::constraint_system::validate::validate_witness; + use binius_field::{BinaryField1b, BinaryField8b}; + + use super::SeveralU32add; + use crate::{builder::ConstraintSystemBuilder, unconstrained::unconstrained}; + + #[test] + fn test_several_lasso_u32add() { + let allocator = bumpalo::Bump::new(); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); + + let mut several_u32_add = SeveralU32add::new(&mut builder).unwrap(); + + for log_size in [11, 12, 13] { + // BinaryField8b is used here because we utilize an 8x8x1→8 table + let add_a_u8 = unconstrained::(&mut builder, "add_a", log_size).unwrap(); + let add_b_u8 = unconstrained::(&mut builder, "add_b", log_size).unwrap(); + let _sum = several_u32_add + .u32add::( + &mut builder, + "lasso_u32add", + add_a_u8, + add_b_u8, + ) + .unwrap(); + } + + several_u32_add + .finalize(&mut builder, "lasso_u32add") + .unwrap(); + + let witness = builder.take_witness().unwrap(); + let constraint_system = builder.build().unwrap(); + let boundaries = vec![]; + validate_witness(&constraint_system, &boundaries, &witness).unwrap(); + } + + #[test] + fn test_lasso_u32add() { + let allocator = bumpalo::Bump::new(); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); + let log_size = 14; + + let add_a = unconstrained::(&mut builder, "add_a", log_size).unwrap(); + let add_b = unconstrained::(&mut builder, "add_b", log_size).unwrap(); + let _sum = super::u32add::( + &mut builder, + "lasso_u32add", + add_a, + add_b, + ) + .unwrap(); + + let witness = builder.take_witness().unwrap(); + let constraint_system = builder.build().unwrap(); + let boundaries = vec![]; + validate_witness(&constraint_system, &boundaries, &witness).unwrap(); + } +} diff --git a/crates/circuits/src/lasso/u8_double_conditional_increment.rs b/crates/circuits/src/lasso/u8_double_conditional_increment.rs index 907ac418c..1f10b443a 100644 --- a/crates/circuits/src/lasso/u8_double_conditional_increment.rs +++ b/crates/circuits/src/lasso/u8_double_conditional_increment.rs @@ -2,44 +2,24 @@ use anyhow::Result; use binius_core::oracle::OracleId; -use binius_field::{ - as_packed_field::{PackScalar, PackedType}, - underlier::UnderlierType, - BinaryField, BinaryField16b, BinaryField1b, BinaryField32b, BinaryField8b, ExtensionField, - PackedFieldIndexable, TowerField, -}; -use bytemuck::Pod; +use binius_field::{BinaryField1b, BinaryField32b, BinaryField8b, TowerField}; use super::batch::LookupBatch; -use crate::builder::ConstraintSystemBuilder; +use crate::builder::{types::F, ConstraintSystemBuilder}; type B1 = BinaryField1b; type B8 = BinaryField8b; -type B16 = BinaryField16b; type B32 = BinaryField32b; -pub fn u8_double_conditional_increment( - builder: &mut ConstraintSystemBuilder, +pub fn u8_double_conditional_increment( + builder: &mut ConstraintSystemBuilder, lookup_batch: &mut LookupBatch, name: impl ToString + Clone, x_in: OracleId, first_carry_in: OracleId, second_carry_in: OracleId, log_size: usize, -) -> Result<(OracleId, OracleId), anyhow::Error> -where - U: Pod - + UnderlierType - + PackScalar - + PackScalar - + PackScalar - + PackScalar - + PackScalar, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - F: TowerField + BinaryField + ExtensionField + ExtensionField + ExtensionField, -{ +) -> Result<(OracleId, OracleId), anyhow::Error> { builder.push_namespace(name); let sum = builder.add_committed("sum", log_size, B8::TOWER_LEVEL); diff --git a/crates/circuits/src/lasso/u8add.rs b/crates/circuits/src/lasso/u8add.rs index fe42995dc..7e3dd0e58 100644 --- a/crates/circuits/src/lasso/u8add.rs +++ b/crates/circuits/src/lasso/u8add.rs @@ -4,44 +4,24 @@ use std::vec; use anyhow::Result; use binius_core::oracle::OracleId; -use binius_field::{ - as_packed_field::{PackScalar, PackedType}, - underlier::UnderlierType, - BinaryField, BinaryField16b, BinaryField1b, BinaryField32b, BinaryField8b, ExtensionField, - PackedFieldIndexable, TowerField, -}; -use bytemuck::Pod; +use binius_field::{BinaryField1b, BinaryField32b, BinaryField8b, TowerField}; use super::batch::LookupBatch; -use crate::builder::ConstraintSystemBuilder; +use crate::builder::{types::F, ConstraintSystemBuilder}; type B1 = BinaryField1b; type B8 = BinaryField8b; -type B16 = BinaryField16b; type B32 = BinaryField32b; -pub fn u8add( - builder: &mut ConstraintSystemBuilder, +pub fn u8add( + builder: &mut ConstraintSystemBuilder, lookup_batch: &mut LookupBatch, name: impl ToString + Clone, x_in: OracleId, y_in: OracleId, carry_in: OracleId, log_size: usize, -) -> Result<(OracleId, OracleId), anyhow::Error> -where - U: Pod - + UnderlierType - + PackScalar - + PackScalar - + PackScalar - + PackScalar - + PackScalar, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - F: TowerField + BinaryField + ExtensionField + ExtensionField + ExtensionField, -{ +) -> Result<(OracleId, OracleId), anyhow::Error> { builder.push_namespace(name); let sum = builder.add_committed("sum", log_size, B8::TOWER_LEVEL); diff --git a/crates/circuits/src/lasso/u8add_carryfree.rs b/crates/circuits/src/lasso/u8add_carryfree.rs index 45bebbd85..fd1959592 100644 --- a/crates/circuits/src/lasso/u8add_carryfree.rs +++ b/crates/circuits/src/lasso/u8add_carryfree.rs @@ -2,44 +2,24 @@ use anyhow::Result; use binius_core::oracle::OracleId; -use binius_field::{ - as_packed_field::{PackScalar, PackedType}, - underlier::UnderlierType, - BinaryField, BinaryField16b, BinaryField1b, BinaryField32b, BinaryField8b, ExtensionField, - PackedFieldIndexable, TowerField, -}; -use bytemuck::Pod; +use binius_field::{BinaryField1b, BinaryField32b, BinaryField8b, TowerField}; use super::batch::LookupBatch; -use crate::builder::ConstraintSystemBuilder; +use crate::builder::{types::F, ConstraintSystemBuilder}; type B1 = BinaryField1b; type B8 = BinaryField8b; -type B16 = BinaryField16b; type B32 = BinaryField32b; -pub fn u8add_carryfree( - builder: &mut ConstraintSystemBuilder, +pub fn u8add_carryfree( + builder: &mut ConstraintSystemBuilder, lookup_batch: &mut LookupBatch, name: impl ToString + Clone, x_in: OracleId, y_in: OracleId, carry_in: OracleId, log_size: usize, -) -> Result -where - U: Pod - + UnderlierType - + PackScalar - + PackScalar - + PackScalar - + PackScalar - + PackScalar, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - F: TowerField + BinaryField + ExtensionField + ExtensionField + ExtensionField, -{ +) -> Result { builder.push_namespace(name); let sum = builder.add_committed("sum", log_size, B8::TOWER_LEVEL); diff --git a/crates/circuits/src/lasso/u8mul.rs b/crates/circuits/src/lasso/u8mul.rs index 589bf7a4e..d0c93ecd1 100644 --- a/crates/circuits/src/lasso/u8mul.rs +++ b/crates/circuits/src/lasso/u8mul.rs @@ -2,37 +2,24 @@ use anyhow::{ensure, Result}; use binius_core::oracle::OracleId; -use binius_field::{ - as_packed_field::{PackScalar, PackedType}, - underlier::UnderlierType, - BinaryField, BinaryField16b, BinaryField32b, BinaryField8b, ExtensionField, - PackedFieldIndexable, TowerField, -}; -use bytemuck::Pod; +use binius_field::{BinaryField16b, BinaryField32b, BinaryField8b, TowerField}; use itertools::izip; use super::batch::LookupBatch; -use crate::builder::ConstraintSystemBuilder; +use crate::builder::{types::F, ConstraintSystemBuilder}; type B8 = BinaryField8b; type B16 = BinaryField16b; type B32 = BinaryField32b; -pub fn u8mul_bytesliced( - builder: &mut ConstraintSystemBuilder, +pub fn u8mul_bytesliced( + builder: &mut ConstraintSystemBuilder, lookup_batch: &mut LookupBatch, name: impl ToString + Clone, mult_a: OracleId, mult_b: OracleId, n_multiplications: usize, -) -> Result<[OracleId; 2], anyhow::Error> -where - U: Pod + UnderlierType + PackScalar + PackScalar + PackScalar + PackScalar, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - F: TowerField + BinaryField + ExtensionField + ExtensionField + ExtensionField, -{ +) -> Result<[OracleId; 2], anyhow::Error> { builder.push_namespace(name); let log_rows = builder.log_rows([mult_a, mult_b])?; let product = builder.add_committed_multiple("product", log_rows, B8::TOWER_LEVEL); @@ -92,21 +79,14 @@ where Ok(product) } -pub fn u8mul( - builder: &mut ConstraintSystemBuilder, +pub fn u8mul( + builder: &mut ConstraintSystemBuilder, lookup_batch: &mut LookupBatch, name: impl ToString + Clone, mult_a: OracleId, mult_b: OracleId, n_multiplications: usize, -) -> Result -where - U: Pod + UnderlierType + PackScalar + PackScalar + PackScalar + PackScalar, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - PackedType: PackedFieldIndexable, - F: TowerField + BinaryField + ExtensionField + ExtensionField + ExtensionField, -{ +) -> Result { builder.push_namespace(name.clone()); let product_bytesliced = diff --git a/crates/circuits/src/lib.rs b/crates/circuits/src/lib.rs index 4be63828e..06c0de27a 100644 --- a/crates/circuits/src/lib.rs +++ b/crates/circuits/src/lib.rs @@ -13,7 +13,6 @@ pub mod arithmetic; pub mod bitwise; pub mod builder; pub mod collatz; -pub mod groestl; pub mod keccakf; pub mod lasso; mod pack; @@ -26,504 +25,32 @@ pub mod vision; #[cfg(test)] mod tests { - use std::array; - use binius_core::{ constraint_system::{ self, channel::{Boundary, FlushDirection}, - validate::validate_witness, }, fiat_shamir::HasherChallenger, - oracle::OracleId, tower::CanonicalTowerFamily, }; use binius_field::{ - arch::OptimalUnderlier, - as_packed_field::PackedType, - tower_levels::{TowerLevel1, TowerLevel16, TowerLevel2, TowerLevel4, TowerLevel8}, - underlier::WithUnderlier, - AESTowerField16b, BinaryField128b, BinaryField1b, BinaryField32b, BinaryField64b, - BinaryField8b, Field, TowerField, + as_packed_field::PackedType, underlier::WithUnderlier, BinaryField8b, Field, }; use binius_hal::make_portable_backend; use binius_hash::compress::Groestl256ByteCompression; use binius_math::DefaultEvaluationDomainFactory; use groestl_crypto::Groestl256; - use rand::{rngs::StdRng, Rng, SeedableRng}; - use sha2::{compress256, digest::generic_array::GenericArray}; - - use crate::{ - arithmetic, bitwise, - builder::ConstraintSystemBuilder, - groestl::groestl_p_permutation, - keccakf::{keccakf, KeccakfState}, - lasso::{ - self, - batch::LookupBatch, - big_integer_ops::byte_sliced_test_utils::{ - test_bytesliced_add, test_bytesliced_add_carryfree, - test_bytesliced_double_conditional_increment, test_bytesliced_modular_mul, - test_bytesliced_mul, - }, - lookups, - u32add::SeveralU32add, - }, - plain_lookup, - sha256::sha256, - u32fib::u32fib, - unconstrained::unconstrained, - vision::vision_permutation, - }; - - type U = OptimalUnderlier; - type F = BinaryField128b; - - #[test] - fn test_lasso_u8add_carryfree_rejects_carry() { - // TODO: Make this test 100% certain to pass instead of 2^14 bits of security from randomness - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); - let log_size = 14; - let x_in = unconstrained::<_, _, BinaryField8b>(&mut builder, "x", log_size).unwrap(); - let y_in = unconstrained::<_, _, BinaryField8b>(&mut builder, "y", log_size).unwrap(); - let c_in = unconstrained::<_, _, BinaryField1b>(&mut builder, "c", log_size).unwrap(); - - let lookup_t = - lookups::u8_arithmetic::add_carryfree_lookup(&mut builder, "add cf table").unwrap(); - let mut lookup_batch = LookupBatch::new([lookup_t]); - let _sum_and_cout = lasso::u8add_carryfree( - &mut builder, - &mut lookup_batch, - "lasso_u8add", - x_in, - y_in, - c_in, - log_size, - ) - .unwrap(); - - lookup_batch - .execute::<_, _, BinaryField32b>(&mut builder) - .unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness) - .expect_err("Rejected overflowing add"); - } - - #[test] - fn test_lasso_add_bytesliced() { - test_bytesliced_add::<1, TowerLevel1>(); - test_bytesliced_add::<2, TowerLevel2>(); - test_bytesliced_add::<4, TowerLevel4>(); - test_bytesliced_add::<8, TowerLevel8>(); - } - - #[test] - fn test_lasso_mul_bytesliced() { - test_bytesliced_mul::<1, TowerLevel2>(); - test_bytesliced_mul::<2, TowerLevel4>(); - test_bytesliced_mul::<4, TowerLevel8>(); - test_bytesliced_mul::<8, TowerLevel16>(); - } - - #[test] - fn test_lasso_modular_mul_bytesliced() { - test_bytesliced_modular_mul::<1, TowerLevel2>(); - test_bytesliced_modular_mul::<2, TowerLevel4>(); - test_bytesliced_modular_mul::<4, TowerLevel8>(); - test_bytesliced_modular_mul::<8, TowerLevel16>(); - } - - #[test] - fn test_lasso_bytesliced_double_conditional_increment() { - test_bytesliced_double_conditional_increment::<1, TowerLevel1>(); - test_bytesliced_double_conditional_increment::<2, TowerLevel2>(); - test_bytesliced_double_conditional_increment::<4, TowerLevel4>(); - test_bytesliced_double_conditional_increment::<8, TowerLevel8>(); - } - - #[test] - fn test_lasso_bytesliced_add_carryfree() { - test_bytesliced_add_carryfree::<1, TowerLevel1>(); - test_bytesliced_add_carryfree::<2, TowerLevel2>(); - test_bytesliced_add_carryfree::<4, TowerLevel4>(); - test_bytesliced_add_carryfree::<8, TowerLevel8>(); - } - - #[test] - fn test_lasso_u8mul() { - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); - let log_size = 10; - - let mult_a = - unconstrained::<_, _, BinaryField8b>(&mut builder, "mult_a", log_size).unwrap(); - let mult_b = - unconstrained::<_, _, BinaryField8b>(&mut builder, "mult_b", log_size).unwrap(); - - let mul_lookup_table = - lookups::u8_arithmetic::mul_lookup(&mut builder, "mul table").unwrap(); - - let mut lookup_batch = LookupBatch::new([mul_lookup_table]); - - let _product = lasso::u8mul( - &mut builder, - &mut lookup_batch, - "lasso_u8mul", - mult_a, - mult_b, - 1 << log_size, - ) - .unwrap(); - - lookup_batch - .execute::<_, _, BinaryField32b>(&mut builder) - .unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); - } - - #[test] - fn test_lasso_batched_u8mul() { - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); - let log_size = 10; - let mul_lookup_table = - lookups::u8_arithmetic::mul_lookup(&mut builder, "mul table").unwrap(); - - let mut lookup_batch = LookupBatch::new([mul_lookup_table]); - - for _ in 0..10 { - let mult_a = - unconstrained::<_, _, BinaryField8b>(&mut builder, "mult_a", log_size).unwrap(); - let mult_b = - unconstrained::<_, _, BinaryField8b>(&mut builder, "mult_b", log_size).unwrap(); - - let _product = lasso::u8mul( - &mut builder, - &mut lookup_batch, - "lasso_u8mul", - mult_a, - mult_b, - 1 << log_size, - ) - .unwrap(); - } - - lookup_batch - .execute::<_, _, BinaryField32b>(&mut builder) - .unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); - } - - #[test] - fn test_lasso_batched_u8mul_rejects() { - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); - let log_size = 10; - - // We try to feed in the add table instead - let mul_lookup_table = - lookups::u8_arithmetic::add_lookup(&mut builder, "mul table").unwrap(); - - let mut lookup_batch = LookupBatch::new([mul_lookup_table]); - - // TODO?: Make this test fail 100% of the time, even though its almost impossible with rng - for _ in 0..10 { - let mult_a = - unconstrained::<_, _, BinaryField8b>(&mut builder, "mult_a", log_size).unwrap(); - let mult_b = - unconstrained::<_, _, BinaryField8b>(&mut builder, "mult_b", log_size).unwrap(); - - let _product = lasso::u8mul( - &mut builder, - &mut lookup_batch, - "lasso_u8mul", - mult_a, - mult_b, - 1 << log_size, - ) - .unwrap(); - } - - lookup_batch - .execute::<_, _, BinaryField32b>(&mut builder) - .unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness) - .expect_err("Channels should be unbalanced"); - } - - #[test] - fn test_several_lasso_u32add() { - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); - - let mut several_u32_add = SeveralU32add::new(&mut builder).unwrap(); - - for log_size in [11, 12, 13] { - // BinaryField8b is used here because we utilize an 8x8x1→8 table - let add_a_u8 = - unconstrained::<_, _, BinaryField8b>(&mut builder, "add_a", log_size).unwrap(); - let add_b_u8 = - unconstrained::<_, _, BinaryField8b>(&mut builder, "add_b", log_size).unwrap(); - let _sum = several_u32_add - .u32add::( - &mut builder, - "lasso_u32add", - add_a_u8, - add_b_u8, - ) - .unwrap(); - } - - several_u32_add - .finalize(&mut builder, "lasso_u32add") - .unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); - } - - #[test] - fn test_lasso_u32add() { - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); - let log_size = 14; - - let add_a = unconstrained::<_, _, BinaryField1b>(&mut builder, "add_a", log_size).unwrap(); - let add_b = unconstrained::<_, _, BinaryField1b>(&mut builder, "add_b", log_size).unwrap(); - let _sum = lasso::u32add::<_, _, BinaryField1b, BinaryField1b>( - &mut builder, - "lasso_u32add", - add_a, - add_b, - ) - .unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); - } - - #[test] - fn test_u32add() { - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); - let log_size = 14; - let a = unconstrained::<_, _, BinaryField1b>(&mut builder, "a", log_size).unwrap(); - let b = unconstrained::<_, _, BinaryField1b>(&mut builder, "b", log_size).unwrap(); - let _c = arithmetic::u32::add(&mut builder, "u32add", a, b, arithmetic::Flags::Unchecked) - .unwrap(); - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); - } - - #[test] - fn test_u32fib() { - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); - let log_size_1b = 14; - let _ = u32fib(&mut builder, "u32fib", log_size_1b).unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); - } - - #[test] - fn test_bitwise() { - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); - let log_size = 6; - let a = unconstrained::<_, _, BinaryField1b>(&mut builder, "a", log_size).unwrap(); - let b = unconstrained::<_, _, BinaryField1b>(&mut builder, "b", log_size).unwrap(); - let _and = bitwise::and(&mut builder, "and", a, b).unwrap(); - let _xor = bitwise::xor(&mut builder, "xor", a, b).unwrap(); - let _or = bitwise::or(&mut builder, "or", a, b).unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); - } - - #[test] - fn test_keccakf() { - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); - let log_size = 5; - - let mut rng = StdRng::seed_from_u64(0); - let input_states = vec![KeccakfState(rng.gen())]; - let _state_out = keccakf(&mut builder, &Some(input_states), log_size); - - let witness = builder.take_witness().unwrap(); - - let constraint_system = builder.build().unwrap(); - - let boundaries = vec![]; - - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); - } - - #[test] - fn test_sha256() { - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); - let log_size = PackedType::::LOG_WIDTH; - let input: [OracleId; 16] = array::from_fn(|i| { - unconstrained::<_, _, BinaryField1b>(&mut builder, i, log_size).unwrap() - }); - let state_output = sha256(&mut builder, input, log_size).unwrap(); - - let witness = builder.witness().unwrap(); - - let input_witneses: [_; 16] = - array::from_fn(|i| witness.get(input[i]).unwrap().as_slice::()); - - let output_witneses: [_; 8] = - array::from_fn(|i| witness.get(state_output[i]).unwrap().as_slice::()); - - let mut generic_array_input = GenericArray::::default(); - - let n_compressions = input_witneses[0].len(); - - for j in 0..n_compressions { - for i in 0..16 { - for z in 0..4 { - generic_array_input[i * 4 + z] = input_witneses[i][j].to_be_bytes()[z]; - } - } - - let mut output = crate::sha256::INIT; - compress256(&mut output, &[generic_array_input]); - - for i in 0..8 { - assert_eq!(output[i], output_witneses[i][j]); - } - } - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); - } - - #[test] - fn test_sha256_lasso() { - let allocator = bumpalo::Bump::new(); - let mut builder = - ConstraintSystemBuilder::::new_with_witness(&allocator); - let log_size = PackedType::::LOG_WIDTH + BinaryField8b::TOWER_LEVEL; - let input: [OracleId; 16] = array::from_fn(|i| { - unconstrained::<_, _, BinaryField1b>(&mut builder, i, log_size).unwrap() - }); - let state_output = lasso::sha256(&mut builder, input, log_size).unwrap(); - - let witness = builder.witness().unwrap(); - - let input_witneses: [_; 16] = array::from_fn(|i| { - witness - .get::(input[i]) - .unwrap() - .as_slice::() - }); - - let output_witneses: [_; 8] = array::from_fn(|i| { - witness - .get::(state_output[i]) - .unwrap() - .as_slice::() - }); - - let mut generic_array_input = GenericArray::::default(); - - let n_compressions = input_witneses[0].len(); - - for j in 0..n_compressions { - for i in 0..16 { - for z in 0..4 { - generic_array_input[i * 4 + z] = input_witneses[i][j].to_be_bytes()[z]; - } - } - - let mut output = crate::sha256::INIT; - compress256(&mut output, &[generic_array_input]); - - for i in 0..8 { - assert_eq!(output[i], output_witneses[i][j]); - } - } - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); - } - - #[test] - fn test_groestl() { - let allocator = bumpalo::Bump::new(); - let mut builder = - ConstraintSystemBuilder::::new_with_witness( - &allocator, - ); - let log_size = 9; - let _state_out = groestl_p_permutation(&mut builder, log_size).unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); - } - - #[test] - fn test_vision32b() { - let allocator = bumpalo::Bump::new(); - let mut builder = - ConstraintSystemBuilder::::new_with_witness( - &allocator, - ); - let log_size = 8; - let state_in: [OracleId; 24] = array::from_fn(|i| { - unconstrained::<_, _, BinaryField32b>(&mut builder, format!("p_in[{i}]"), log_size) - .unwrap() - }); - let _state_out = vision_permutation(&mut builder, log_size, state_in).unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); - } + use crate::builder::{ + types::{F, U}, + ConstraintSystemBuilder, + }; #[test] fn test_boundaries() { // Proving Collatz Orbits let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let log_size = PackedType::::LOG_WIDTH + 2; @@ -633,74 +160,4 @@ mod tests { >(&constraint_system, 1, 10, &boundaries, proof) .unwrap(); } - - #[test] - fn test_plain_u8_mul_lookup() { - const MAX_LOG_MULTIPLICITY: usize = 18; - let log_lookup_count = 19; - - let log_inv_rate = 1; - let security_bits = 20; - - let proof = { - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); - - let boundary = plain_lookup::test_plain_lookup::test_u8_mul_lookup::< - _, - _, - MAX_LOG_MULTIPLICITY, - >(&mut builder, log_lookup_count) - .unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - // validating witness with `validate_witness` is too slow for large transparents like the `table` - - let domain_factory = DefaultEvaluationDomainFactory::default(); - let backend = make_portable_backend(); - - constraint_system::prove::< - U, - CanonicalTowerFamily, - _, - Groestl256, - Groestl256ByteCompression, - HasherChallenger, - _, - >( - &constraint_system, - log_inv_rate, - security_bits, - &[boundary], - witness, - &domain_factory, - &backend, - ) - .unwrap() - }; - - // verify - { - let mut builder = ConstraintSystemBuilder::::new(); - - let boundary = plain_lookup::test_plain_lookup::test_u8_mul_lookup::< - _, - _, - MAX_LOG_MULTIPLICITY, - >(&mut builder, log_lookup_count) - .unwrap(); - - let constraint_system = builder.build().unwrap(); - - constraint_system::verify::< - U, - CanonicalTowerFamily, - Groestl256, - Groestl256ByteCompression, - HasherChallenger, - >(&constraint_system, log_inv_rate, security_bits, &[boundary], proof) - .unwrap(); - } - } } diff --git a/crates/circuits/src/pack.rs b/crates/circuits/src/pack.rs index 5d3c8f79e..290a44729 100644 --- a/crates/circuits/src/pack.rs +++ b/crates/circuits/src/pack.rs @@ -2,22 +2,23 @@ use anyhow::Result; use binius_core::oracle::OracleId; -use binius_field::{ - as_packed_field::PackScalar, underlier::UnderlierType, ExtensionField, TowerField, -}; +use binius_field::{as_packed_field::PackScalar, ExtensionField, TowerField}; -use crate::builder::ConstraintSystemBuilder; +use crate::builder::{ + types::{F, U}, + ConstraintSystemBuilder, +}; -pub fn pack( +pub fn pack( oracle_id: OracleId, - builder: &mut ConstraintSystemBuilder, + builder: &mut ConstraintSystemBuilder, name: impl ToString, ) -> Result where - F: TowerField + ExtensionField + ExtensionField, + F: ExtensionField + ExtensionField, FInput: TowerField, FOutput: TowerField + ExtensionField, - U: UnderlierType + PackScalar + PackScalar + PackScalar, + U: PackScalar + PackScalar, { if FInput::TOWER_LEVEL == FOutput::TOWER_LEVEL { return Ok(oracle_id); diff --git a/crates/circuits/src/plain_lookup.rs b/crates/circuits/src/plain_lookup.rs index 0c41cc6ef..8ec4b38b5 100644 --- a/crates/circuits/src/plain_lookup.rs +++ b/crates/circuits/src/plain_lookup.rs @@ -11,7 +11,10 @@ use binius_field::{ use bytemuck::Pod; use itertools::izip; -use crate::builder::ConstraintSystemBuilder; +use crate::builder::{ + types::{F, U}, + ConstraintSystemBuilder, +}; /// Checks values in `lookup_values` to be in `table`. /// @@ -48,8 +51,8 @@ use crate::builder::ConstraintSystemBuilder; /// To rectify this we put `balancer_value` in a boundary value and push this boundary value to the channel with a multiplicity that will balance the channel. /// This boundary value is returned from the gadget. /// -pub fn plain_lookup( - builder: &mut ConstraintSystemBuilder, +pub fn plain_lookup( + builder: &mut ConstraintSystemBuilder, table: OracleId, table_count: usize, balancer_value: FS, @@ -57,8 +60,8 @@ pub fn plain_lookup( lookup_values_count: usize, ) -> Result, anyhow::Error> where - U: PackScalar + PackScalar + PackScalar + Pod, - F: TowerField + ExtensionField, + U: PackScalar + Pod, + F: ExtensionField, FS: TowerField + Pod, { let n_vars = builder.log_rows([table])?; @@ -81,14 +84,13 @@ where )?); } - let components: [OracleId; LOG_MAX_MULTIPLICITY] = - get_components::<_, _, FS, LOG_MAX_MULTIPLICITY>( - builder, - table, - table_count, - balancer_value, - multiplicities, - )?; + let components: [OracleId; LOG_MAX_MULTIPLICITY] = get_components::( + builder, + table, + table_count, + balancer_value, + multiplicities, + )?; components .into_iter() @@ -117,16 +119,16 @@ where } // the `i`'th returned component holds values that are the product of the `table` values and the bits had by taking the `i`'th bit across the multiplicities. -fn get_components( - builder: &mut ConstraintSystemBuilder, +fn get_components( + builder: &mut ConstraintSystemBuilder, table: OracleId, table_count: usize, balancer_value: FS, multiplicities: Option>, ) -> Result<[OracleId; LOG_MAX_MULTIPLICITY], anyhow::Error> where - U: PackScalar + PackScalar + PackScalar + Pod, - F: TowerField + ExtensionField, + U: PackScalar, + F: ExtensionField, FS: TowerField + Pod, { let n_vars = builder.log_rows([table])?; @@ -242,14 +244,10 @@ pub mod test_plain_lookup { }); } - pub fn test_u8_mul_lookup( - builder: &mut ConstraintSystemBuilder, + pub fn test_u8_mul_lookup( + builder: &mut ConstraintSystemBuilder, log_lookup_count: usize, - ) -> Result, anyhow::Error> - where - U: PackScalar + PackScalar + PackScalar + Pod, - F: TowerField + ExtensionField, - { + ) -> Result, anyhow::Error> { let table_values = generate_u8_mul_table(); let table = transparent::make_transparent( builder, @@ -272,7 +270,7 @@ pub mod test_plain_lookup { generate_random_u8_mul_claims(&mut mut_slice[0..lookup_values_count]); } - let boundary = plain_lookup::( + let boundary = plain_lookup::( builder, table, table_count, @@ -365,3 +363,83 @@ mod count_multiplicity_tests { assert_eq!(result, vec![1, 2, 3]); } } + +#[cfg(test)] +mod tests { + use binius_core::{fiat_shamir::HasherChallenger, tower::CanonicalTowerFamily}; + use binius_hal::make_portable_backend; + use binius_hash::compress::Groestl256ByteCompression; + use binius_math::DefaultEvaluationDomainFactory; + use groestl_crypto::Groestl256; + + use super::test_plain_lookup; + use crate::builder::ConstraintSystemBuilder; + + #[test] + fn test_plain_u8_mul_lookup() { + const MAX_LOG_MULTIPLICITY: usize = 18; + let log_lookup_count = 19; + + let log_inv_rate = 1; + let security_bits = 20; + + let proof = { + let allocator = bumpalo::Bump::new(); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); + + let boundary = test_plain_lookup::test_u8_mul_lookup::( + &mut builder, + log_lookup_count, + ) + .unwrap(); + + let witness = builder.take_witness().unwrap(); + let constraint_system = builder.build().unwrap(); + // validating witness with `validate_witness` is too slow for large transparents like the `table` + + let domain_factory = DefaultEvaluationDomainFactory::default(); + let backend = make_portable_backend(); + + binius_core::constraint_system::prove::< + crate::builder::types::U, + CanonicalTowerFamily, + _, + Groestl256, + Groestl256ByteCompression, + HasherChallenger, + _, + >( + &constraint_system, + log_inv_rate, + security_bits, + &[boundary], + witness, + &domain_factory, + &backend, + ) + .unwrap() + }; + + // verify + { + let mut builder = ConstraintSystemBuilder::new(); + + let boundary = test_plain_lookup::test_u8_mul_lookup::( + &mut builder, + log_lookup_count, + ) + .unwrap(); + + let constraint_system = builder.build().unwrap(); + + binius_core::constraint_system::verify::< + crate::builder::types::U, + CanonicalTowerFamily, + Groestl256, + Groestl256ByteCompression, + HasherChallenger, + >(&constraint_system, log_inv_rate, security_bits, &[boundary], proof) + .unwrap(); + } + } +} diff --git a/crates/circuits/src/sha256.rs b/crates/circuits/src/sha256.rs index d2e8fe6e9..d28502d15 100644 --- a/crates/circuits/src/sha256.rs +++ b/crates/circuits/src/sha256.rs @@ -5,16 +5,21 @@ use binius_core::{ transparent::multilinear_extension::MultilinearExtensionTransparent, }; use binius_field::{ - as_packed_field::{PackScalar, PackedType}, - underlier::{UnderlierType, WithUnderlier}, - BinaryField1b, PackedField, TowerField, + as_packed_field::PackedType, underlier::WithUnderlier, BinaryField1b, Field, PackedField, + TowerField, }; use binius_macros::arith_expr; use binius_utils::checked_arithmetics::checked_log_2; use bytemuck::{pod_collect_to_vec, Pod}; use itertools::izip; -use crate::{arithmetic, builder::ConstraintSystemBuilder}; +use crate::{ + arithmetic, + builder::{ + types::{F, U}, + ConstraintSystemBuilder, + }, +}; const LOG_U32_BITS: usize = checked_log_2(32); @@ -41,15 +46,11 @@ pub enum RotateRightType { Logical, } -pub fn rotate_and_xor( +pub fn rotate_and_xor( log_size: usize, - builder: &mut ConstraintSystemBuilder, + builder: &mut ConstraintSystemBuilder, r: &[(OracleId, usize, RotateRightType)], -) -> Result -where - F: TowerField, - U: UnderlierType + Pod + PackScalar + PackScalar, -{ +) -> Result { let shifted_oracle_ids = r .iter() .map(|(oracle_id, shift, t)| { @@ -76,7 +77,7 @@ where let result_oracle_id = builder.add_linear_combination( format!("linear combination of {:?}", shifted_oracle_ids), log_size, - shifted_oracle_ids.iter().map(|s| (*s, F::ONE)), + shifted_oracle_ids.iter().map(|s| (*s, Field::ONE)), )?; if let Some(witness) = builder.witness() { @@ -116,16 +117,12 @@ where .collect() } -pub fn u32const_repeating( +pub fn u32const_repeating( log_size: usize, - builder: &mut ConstraintSystemBuilder, + builder: &mut ConstraintSystemBuilder, x: u32, name: &str, -) -> Result -where - F: TowerField, - U: UnderlierType + Pod + PackScalar + PackScalar, -{ +) -> Result { let brodcasted = vec![x; 1 << (PackedType::::LOG_WIDTH.saturating_sub(LOG_U32_BITS))]; let transparent_id = builder.add_transparent( @@ -152,15 +149,11 @@ where Ok(repeating_id) } -pub fn sha256( - builder: &mut ConstraintSystemBuilder, +pub fn sha256( + builder: &mut ConstraintSystemBuilder, input: [OracleId; 16], log_size: usize, -) -> Result<[OracleId; 8], anyhow::Error> -where - U: UnderlierType + Pod + PackScalar + PackScalar, - F: TowerField, -{ +) -> Result<[OracleId; 8], anyhow::Error> { if log_size < >::LOG_WIDTH { Err(anyhow::Error::msg("log_size too small"))? } @@ -319,3 +312,66 @@ where Ok(output) } + +#[cfg(test)] +mod tests { + use binius_core::{constraint_system::validate::validate_witness, oracle::OracleId}; + use binius_field::{as_packed_field::PackedType, BinaryField1b}; + use sha2::{compress256, digest::generic_array::GenericArray}; + + use crate::{ + builder::{types::U, ConstraintSystemBuilder}, + unconstrained::unconstrained, + }; + + #[test] + fn test_sha256() { + let allocator = bumpalo::Bump::new(); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); + let log_size = PackedType::::LOG_WIDTH; + let input: [OracleId; 16] = std::array::from_fn(|i| { + unconstrained::(&mut builder, i, log_size).unwrap() + }); + let state_output = super::sha256(&mut builder, input, log_size).unwrap(); + + let witness = builder.witness().unwrap(); + + let input_witneses: [_; 16] = std::array::from_fn(|i| { + witness + .get::(input[i]) + .unwrap() + .as_slice::() + }); + + let output_witneses: [_; 8] = std::array::from_fn(|i| { + witness + .get::(state_output[i]) + .unwrap() + .as_slice::() + }); + + let mut generic_array_input = GenericArray::::default(); + + let n_compressions = input_witneses[0].len(); + + for j in 0..n_compressions { + for i in 0..16 { + for z in 0..4 { + generic_array_input[i * 4 + z] = input_witneses[i][j].to_be_bytes()[z]; + } + } + + let mut output = crate::sha256::INIT; + compress256(&mut output, &[generic_array_input]); + + for i in 0..8 { + assert_eq!(output[i], output_witneses[i][j]); + } + } + + let witness = builder.take_witness().unwrap(); + let constraint_system = builder.build().unwrap(); + let boundaries = vec![]; + validate_witness(&constraint_system, &boundaries, &witness).unwrap(); + } +} diff --git a/crates/circuits/src/transparent.rs b/crates/circuits/src/transparent.rs index 16efed3ac..11e36c223 100644 --- a/crates/circuits/src/transparent.rs +++ b/crates/circuits/src/transparent.rs @@ -3,23 +3,20 @@ use binius_core::{oracle::OracleId, transparent}; use binius_field::{ as_packed_field::{PackScalar, PackedType}, - underlier::UnderlierType, BinaryField1b, ExtensionField, PackedField, TowerField, }; -use bytemuck::Pod; -use crate::builder::ConstraintSystemBuilder; +use crate::builder::{ + types::{F, U}, + ConstraintSystemBuilder, +}; -pub fn step_down( - builder: &mut ConstraintSystemBuilder, +pub fn step_down( + builder: &mut ConstraintSystemBuilder, name: impl ToString, log_size: usize, index: usize, -) -> Result -where - U: UnderlierType + PackScalar + PackScalar + Pod, - F: TowerField, -{ +) -> Result { let step_down = transparent::step_down::StepDown::new(log_size, index)?; let id = builder.add_transparent(name, step_down.clone())?; if let Some(witness) = builder.witness() { @@ -28,16 +25,12 @@ where Ok(id) } -pub fn step_up( - builder: &mut ConstraintSystemBuilder, +pub fn step_up( + builder: &mut ConstraintSystemBuilder, name: impl ToString, log_size: usize, index: usize, -) -> Result -where - U: UnderlierType + PackScalar + PackScalar + Pod, - F: TowerField, -{ +) -> Result { let step_up = transparent::step_up::StepUp::new(log_size, index)?; let id = builder.add_transparent(name, step_up.clone())?; if let Some(witness) = builder.witness() { @@ -46,14 +39,14 @@ where Ok(id) } -pub fn constant( - builder: &mut ConstraintSystemBuilder, +pub fn constant( + builder: &mut ConstraintSystemBuilder, name: impl ToString, log_size: usize, value: FS, ) -> Result where - U: UnderlierType + PackScalar + PackScalar, + U: PackScalar, F: TowerField + ExtensionField, FS: TowerField, { @@ -68,13 +61,13 @@ where Ok(id) } -pub fn make_transparent( - builder: &mut ConstraintSystemBuilder, +pub fn make_transparent( + builder: &mut ConstraintSystemBuilder, name: impl ToString, values: &[FS], ) -> Result where - U: PackScalar + PackScalar, + U: PackScalar, F: TowerField + ExtensionField, FS: TowerField, { diff --git a/crates/circuits/src/u32fib.rs b/crates/circuits/src/u32fib.rs index 1f4197af4..4c53c389d 100644 --- a/crates/circuits/src/u32fib.rs +++ b/crates/circuits/src/u32fib.rs @@ -1,26 +1,22 @@ // Copyright 2024-2025 Irreducible Inc. use binius_core::oracle::{OracleId, ShiftVariant}; -use binius_field::{ - as_packed_field::PackScalar, underlier::UnderlierType, BinaryField1b, BinaryField32b, - ExtensionField, TowerField, -}; +use binius_field::{BinaryField1b, BinaryField32b, TowerField}; use binius_macros::arith_expr; use binius_maybe_rayon::prelude::*; -use bytemuck::Pod; use rand::{thread_rng, Rng}; -use crate::{arithmetic, builder::ConstraintSystemBuilder, transparent::step_down}; +use crate::{ + arithmetic, + builder::{types::F, ConstraintSystemBuilder}, + transparent::step_down, +}; -pub fn u32fib( - builder: &mut ConstraintSystemBuilder, +pub fn u32fib( + builder: &mut ConstraintSystemBuilder, name: impl ToString, log_size: usize, -) -> Result -where - U: UnderlierType + Pod + PackScalar + PackScalar + PackScalar, - F: TowerField + ExtensionField, -{ +) -> Result { builder.push_namespace(name); let current = builder.add_committed("current", log_size, BinaryField1b::TOWER_LEVEL); let next = builder.add_shifted("next", current, 32, log_size, ShiftVariant::LogicalRight)?; @@ -75,3 +71,23 @@ where builder.pop_namespace(); Ok(current) } + +#[cfg(test)] +mod tests { + use binius_core::constraint_system::validate::validate_witness; + + use crate::builder::ConstraintSystemBuilder; + + #[test] + fn test_u32fib() { + let allocator = bumpalo::Bump::new(); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); + let log_size_1b = 14; + let _ = super::u32fib(&mut builder, "u32fib", log_size_1b).unwrap(); + + let witness = builder.take_witness().unwrap(); + let constraint_system = builder.build().unwrap(); + let boundaries = vec![]; + validate_witness(&constraint_system, &boundaries, &witness).unwrap(); + } +} diff --git a/crates/circuits/src/unconstrained.rs b/crates/circuits/src/unconstrained.rs index a798ec5f1..1e4da1b06 100644 --- a/crates/circuits/src/unconstrained.rs +++ b/crates/circuits/src/unconstrained.rs @@ -1,21 +1,22 @@ // Copyright 2024-2025 Irreducible Inc. use binius_core::oracle::OracleId; -use binius_field::{ - as_packed_field::PackScalar, underlier::UnderlierType, ExtensionField, TowerField, -}; +use binius_field::{as_packed_field::PackScalar, ExtensionField, TowerField}; use binius_maybe_rayon::prelude::*; use bytemuck::Pod; use rand::{thread_rng, Rng}; -use crate::builder::ConstraintSystemBuilder; +use crate::builder::{ + types::{F, U}, + ConstraintSystemBuilder, +}; -pub fn unconstrained( - builder: &mut ConstraintSystemBuilder, +pub fn unconstrained( + builder: &mut ConstraintSystemBuilder, name: impl ToString, log_size: usize, ) -> Result where - U: UnderlierType + Pod + PackScalar + PackScalar, + U: PackScalar + Pod, F: TowerField + ExtensionField, FS: TowerField, { diff --git a/crates/circuits/src/vision.rs b/crates/circuits/src/vision.rs index 5e8fe20a0..e8294f244 100644 --- a/crates/circuits/src/vision.rs +++ b/crates/circuits/src/vision.rs @@ -12,36 +12,22 @@ use std::array; use anyhow::Result; use binius_core::{oracle::OracleId, transparent::constant::Constant}; use binius_field::{ - as_packed_field::{PackScalar, PackedType}, - linear_transformation::Transformation, - make_aes_to_binary_packed_transformer, - packed::get_packed_slice, - underlier::UnderlierType, - BinaryField1b, BinaryField32b, BinaryField64b, ExtensionField, PackedAESBinaryField8x32b, - PackedBinaryField8x32b, PackedField, TowerField, + linear_transformation::Transformation, make_aes_to_binary_packed_transformer, + packed::get_packed_slice, BinaryField1b, BinaryField32b, ExtensionField, Field, + PackedAESBinaryField8x32b, PackedBinaryField8x32b, PackedField, TowerField, }; use binius_hash::{Vision32MDSTransform, INV_PACKED_TRANS_AES}; use binius_macros::arith_expr; use binius_math::ArithExpr; -use bytemuck::{must_cast_slice, Pod}; +use bytemuck::must_cast_slice; -use crate::builder::ConstraintSystemBuilder; +use crate::builder::{types::F, ConstraintSystemBuilder}; -pub fn vision_permutation( - builder: &mut ConstraintSystemBuilder, +pub fn vision_permutation( + builder: &mut ConstraintSystemBuilder, log_size: usize, p_in: [OracleId; STATE_SIZE], -) -> Result<[OracleId; STATE_SIZE]> -where - U: UnderlierType - + Pod - + PackScalar - + PackScalar - + PackScalar - + PackScalar, - F: TowerField + ExtensionField + ExtensionField, - PackedType: Pod, -{ +) -> Result<[OracleId; STATE_SIZE]> { // This only acts as a shorthand type B32 = BinaryField32b; @@ -77,7 +63,7 @@ where } let perm_out = (0..N_ROUNDS).try_fold(round_0_input, |state, round_i| { - vision_round::(builder, log_size, round_i, state) + vision_round(builder, log_size, round_i, state) })?; #[cfg(debug_assertions)] @@ -237,22 +223,13 @@ fn inv_constraint_expr() -> Result> { Ok(non_zero_case * zero_case) } -fn vision_round( - builder: &mut ConstraintSystemBuilder, +fn vision_round( + builder: &mut ConstraintSystemBuilder, log_size: usize, round_i: usize, perm_in: [OracleId; STATE_SIZE], ) -> Result<[OracleId; STATE_SIZE]> -where - U: UnderlierType - + Pod - + PackScalar - + PackScalar - + PackScalar - + PackScalar, - F: TowerField + ExtensionField + ExtensionField, - PackedType: Pod, -{ +where { builder.push_namespace(format!("round[{round_i}]")); let inv_0 = builder.add_committed_multiple::( "inv_evens", @@ -318,7 +295,10 @@ where .add_linear_combination( format!("round_out_evens_{}", row), log_size, - [(mds_out_0[row], F::ONE), (even_round_consts[row], F::ONE)], + [ + (mds_out_0[row], Field::ONE), + (even_round_consts[row], Field::ONE), + ], ) .unwrap() }); @@ -328,7 +308,10 @@ where .add_linear_combination( format!("round_out_odd_{}", row), log_size, - [(mds_out_1[row], F::ONE), (odd_round_consts[row], F::ONE)], + [ + (mds_out_1[row], Field::ONE), + (odd_round_consts[row], Field::ONE), + ], ) .unwrap() }); @@ -476,3 +459,28 @@ where Ok(perm_out) } + +#[cfg(test)] +mod tests { + use binius_core::{constraint_system::validate::validate_witness, oracle::OracleId}; + use binius_field::BinaryField32b; + + use super::vision_permutation; + use crate::{builder::ConstraintSystemBuilder, unconstrained::unconstrained}; + + #[test] + fn test_vision32b() { + let allocator = bumpalo::Bump::new(); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); + let log_size = 8; + let state_in: [OracleId; 24] = std::array::from_fn(|i| { + unconstrained::(&mut builder, format!("p_in[{i}]"), log_size).unwrap() + }); + let _state_out = vision_permutation(&mut builder, log_size, state_in).unwrap(); + + let witness = builder.take_witness().unwrap(); + let constraint_system = builder.build().unwrap(); + let boundaries = vec![]; + validate_witness(&constraint_system, &boundaries, &witness).unwrap(); + } +} diff --git a/examples/Cargo.toml b/examples/Cargo.toml index 74dea8f0c..07edfc0a3 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -24,10 +24,6 @@ rand.workspace = true tracing-profile.workspace = true tracing.workspace = true -[[example]] -name = "groestl_circuit" -path = "groestl_circuit.rs" - [[example]] name = "keccakf_circuit" path = "keccakf_circuit.rs" @@ -85,4 +81,3 @@ aes-tower = [] bail_panic = ["binius_utils/bail_panic"] fp-tower = [] rayon = ["binius_utils/rayon"] - diff --git a/examples/b32_mul.rs b/examples/b32_mul.rs index b0168f43a..5cdbbea53 100644 --- a/examples/b32_mul.rs +++ b/examples/b32_mul.rs @@ -1,9 +1,9 @@ // Copyright 2024-2025 Irreducible Inc. use anyhow::Result; -use binius_circuits::builder::ConstraintSystemBuilder; +use binius_circuits::builder::{types::U, ConstraintSystemBuilder}; use binius_core::{constraint_system, fiat_shamir::HasherChallenger, tower::CanonicalTowerFamily}; -use binius_field::{arch::OptimalUnderlier, BinaryField128b, BinaryField32b, TowerField}; +use binius_field::{BinaryField32b, TowerField}; use binius_hal::make_portable_backend; use binius_hash::compress::Groestl256ByteCompression; use binius_macros::arith_expr; @@ -26,7 +26,6 @@ struct Args { } fn main() -> Result<()> { - type U = OptimalUnderlier; const SECURITY_BITS: usize = 100; adjust_thread_pool() @@ -42,18 +41,18 @@ fn main() -> Result<()> { let log_n_muls = log2_ceil_usize(args.n_ops as usize); let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let trace_gen_scope = tracing::info_span!("generating trace").entered(); - let in_a = binius_circuits::unconstrained::unconstrained::<_, _, BinaryField32b>( + let in_a = binius_circuits::unconstrained::unconstrained::( &mut builder, "in_a", log_n_muls, ) .unwrap(); - let in_b = binius_circuits::unconstrained::unconstrained::<_, _, BinaryField32b>( + let in_b = binius_circuits::unconstrained::unconstrained::( &mut builder, "in_b", log_n_muls, diff --git a/examples/bitwise_ops.rs b/examples/bitwise_ops.rs index 3eaec2bb5..6b46e29ad 100644 --- a/examples/bitwise_ops.rs +++ b/examples/bitwise_ops.rs @@ -3,11 +3,9 @@ use std::{fmt::Display, str::FromStr}; use anyhow::Result; -use binius_circuits::builder::ConstraintSystemBuilder; +use binius_circuits::builder::{types::U, ConstraintSystemBuilder}; use binius_core::{constraint_system, fiat_shamir::HasherChallenger, tower::CanonicalTowerFamily}; -use binius_field::{ - arch::OptimalUnderlier, BinaryField128b, BinaryField1b, BinaryField32b, TowerField, -}; +use binius_field::{BinaryField1b, BinaryField32b, TowerField}; use binius_hal::make_portable_backend; use binius_hash::compress::Groestl256ByteCompression; use binius_macros::arith_expr; @@ -63,7 +61,6 @@ struct Args { } fn main() -> Result<()> { - type U = OptimalUnderlier; const SECURITY_BITS: usize = 100; adjust_thread_pool() @@ -80,16 +77,16 @@ fn main() -> Result<()> { log2_ceil_usize(args.n_u32_ops as usize) + BinaryField32b::TOWER_LEVEL; let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let trace_gen_scope = tracing::info_span!("generating trace").entered(); // Assuming our 32bit values have been committed as bits - let in_a = binius_circuits::unconstrained::unconstrained::<_, _, BinaryField1b>( + let in_a = binius_circuits::unconstrained::unconstrained::( &mut builder, "in_a", log_n_1b_operations, )?; - let in_b = binius_circuits::unconstrained::unconstrained::<_, _, BinaryField1b>( + let in_b = binius_circuits::unconstrained::unconstrained::( &mut builder, "in_b", log_n_1b_operations, diff --git a/examples/collatz.rs b/examples/collatz.rs index c30305ca5..f51ce7a06 100644 --- a/examples/collatz.rs +++ b/examples/collatz.rs @@ -2,7 +2,7 @@ use anyhow::Result; use binius_circuits::{ - builder::ConstraintSystemBuilder, + builder::{types::U, ConstraintSystemBuilder}, collatz::{Advice, Collatz}, }; use binius_core::{ @@ -10,7 +10,6 @@ use binius_core::{ fiat_shamir::HasherChallenger, tower::CanonicalTowerFamily, }; -use binius_field::{arch::OptimalUnderlier, BinaryField128b}; use binius_hal::make_portable_backend; use binius_hash::compress::Groestl256ByteCompression; use binius_math::DefaultEvaluationDomainFactory; @@ -29,9 +28,6 @@ struct Args { log_inv_rate: u32, } -type U = OptimalUnderlier; -type F = BinaryField128b; - const SECURITY_BITS: usize = 100; fn main() -> Result<()> { @@ -59,7 +55,7 @@ fn prove(x0: u32, log_inv_rate: usize) -> Result<(Advice, Proof), anyhow::Error> let advice = collatz.init_prover(); let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let boundaries = collatz.build(&mut builder, advice)?; @@ -95,7 +91,7 @@ fn prove(x0: u32, log_inv_rate: usize) -> Result<(Advice, Proof), anyhow::Error> fn verify(x0: u32, advice: Advice, proof: Proof, log_inv_rate: usize) -> Result<(), anyhow::Error> { let collatz = Collatz::new(x0); - let mut builder = ConstraintSystemBuilder::::new(); + let mut builder = ConstraintSystemBuilder::new(); let boundaries = collatz.build(&mut builder, advice)?; diff --git a/examples/groestl_circuit.rs b/examples/groestl_circuit.rs.disabled similarity index 100% rename from examples/groestl_circuit.rs rename to examples/groestl_circuit.rs.disabled diff --git a/examples/keccakf_circuit.rs b/examples/keccakf_circuit.rs index 1fa583e4a..68f12812d 100644 --- a/examples/keccakf_circuit.rs +++ b/examples/keccakf_circuit.rs @@ -5,9 +5,8 @@ use std::vec; use anyhow::Result; -use binius_circuits::builder::ConstraintSystemBuilder; +use binius_circuits::builder::{types::U, ConstraintSystemBuilder}; use binius_core::{constraint_system, fiat_shamir::HasherChallenger, tower::CanonicalTowerFamily}; -use binius_field::{arch::OptimalUnderlier, BinaryField128b}; use binius_hal::make_portable_backend; use binius_hash::compress::Groestl256ByteCompression; use binius_math::DefaultEvaluationDomainFactory; @@ -28,7 +27,6 @@ struct Args { } fn main() -> Result<()> { - type U = OptimalUnderlier; const SECURITY_BITS: usize = 100; adjust_thread_pool() @@ -45,7 +43,7 @@ fn main() -> Result<()> { let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let log_size = log_n_permutations; diff --git a/examples/modular_mul.rs b/examples/modular_mul.rs index b5dfb9e77..1649fa556 100644 --- a/examples/modular_mul.rs +++ b/examples/modular_mul.rs @@ -5,15 +5,14 @@ use std::array; use alloy_primitives::U512; use anyhow::Result; use binius_circuits::{ - builder::ConstraintSystemBuilder, + builder::{types::U, ConstraintSystemBuilder}, lasso::big_integer_ops::{byte_sliced_modular_mul, byte_sliced_test_utils::random_u512}, transparent, }; use binius_core::{constraint_system, fiat_shamir::HasherChallenger, tower::CanonicalTowerFamily}; use binius_field::{ - arch::OptimalUnderlier128b, tower_levels::{TowerLevel4, TowerLevel8}, - BinaryField128b, BinaryField1b, BinaryField8b, Field, TowerField, + BinaryField1b, BinaryField8b, Field, TowerField, }; use binius_hal::make_portable_backend; use binius_hash::compress::Groestl256ByteCompression; @@ -36,8 +35,6 @@ struct Args { } fn main() -> Result<()> { - type U = OptimalUnderlier128b; - type F = BinaryField128b; type B8 = BinaryField8b; const SECURITY_BITS: usize = 100; const WIDTH: usize = 4; @@ -53,7 +50,7 @@ fn main() -> Result<()> { println!("Verifying {} u32 modular multiplications", args.n_multiplications); let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let log_size = log2_ceil_usize(args.n_multiplications as usize); let mut rng = thread_rng(); @@ -98,7 +95,7 @@ fn main() -> Result<()> { let zero_oracle_carry = transparent::constant(&mut builder, "zero carry", log_size, BinaryField1b::ZERO).unwrap(); - let _modded_product = byte_sliced_modular_mul::<_, _, TowerLevel4, TowerLevel8>( + let _modded_product = byte_sliced_modular_mul::( &mut builder, "lasso_bytesliced_mul", &mult_a, diff --git a/examples/sha256_circuit.rs b/examples/sha256_circuit.rs index 2c5d162ac..3835c79fc 100644 --- a/examples/sha256_circuit.rs +++ b/examples/sha256_circuit.rs @@ -5,11 +5,14 @@ use std::array; use anyhow::Result; -use binius_circuits::{builder::ConstraintSystemBuilder, unconstrained::unconstrained}; +use binius_circuits::{ + builder::{types::U, ConstraintSystemBuilder}, + unconstrained::unconstrained, +}; use binius_core::{ constraint_system, fiat_shamir::HasherChallenger, oracle::OracleId, tower::CanonicalTowerFamily, }; -use binius_field::{arch::OptimalUnderlier, BinaryField128b, BinaryField1b}; +use binius_field::BinaryField1b; use binius_hal::make_portable_backend; use binius_hash::compress::Groestl256ByteCompression; use binius_math::DefaultEvaluationDomainFactory; @@ -32,7 +35,6 @@ struct Args { const COMPRESSION_LOG_LEN: usize = 5; fn main() -> Result<()> { - type U = OptimalUnderlier; const SECURITY_BITS: usize = 100; adjust_thread_pool() @@ -48,15 +50,11 @@ fn main() -> Result<()> { let log_n_compressions = log2_ceil_usize(args.n_compressions as usize); let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let trace_gen_scope = tracing::info_span!("generating trace").entered(); let input: [OracleId; 16] = array::try_from_fn(|i| { - unconstrained::<_, _, BinaryField1b>( - &mut builder, - i, - log_n_compressions + COMPRESSION_LOG_LEN, - ) + unconstrained::(&mut builder, i, log_n_compressions + COMPRESSION_LOG_LEN) })?; let _state_out = binius_circuits::sha256::sha256( diff --git a/examples/sha256_circuit_with_lookup.rs b/examples/sha256_circuit_with_lookup.rs index 6b7cd791f..f7f83435b 100644 --- a/examples/sha256_circuit_with_lookup.rs +++ b/examples/sha256_circuit_with_lookup.rs @@ -5,13 +5,14 @@ use std::array; use anyhow::Result; -use binius_circuits::{builder::ConstraintSystemBuilder, unconstrained::unconstrained}; +use binius_circuits::{ + builder::{types::U, ConstraintSystemBuilder}, + unconstrained::unconstrained, +}; use binius_core::{ constraint_system, fiat_shamir::HasherChallenger, oracle::OracleId, tower::CanonicalTowerFamily, }; -use binius_field::{ - arch::OptimalUnderlier, as_packed_field::PackedType, BinaryField128b, BinaryField1b, -}; +use binius_field::{arch::OptimalUnderlier, as_packed_field::PackedType, BinaryField1b}; use binius_hal::make_portable_backend; use binius_hash::compress::Groestl256ByteCompression; use binius_math::DefaultEvaluationDomainFactory; @@ -38,7 +39,6 @@ struct Args { const COMPRESSION_LOG_LEN: usize = 5; fn main() -> Result<()> { - type U = OptimalUnderlier; const SECURITY_BITS: usize = 100; adjust_thread_pool() @@ -54,15 +54,11 @@ fn main() -> Result<()> { let log_n_compressions = log2_ceil_usize(args.n_compressions as usize); let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let trace_gen_scope = tracing::info_span!("generating witness").entered(); let input: [OracleId; 16] = array::try_from_fn(|i| { - unconstrained::<_, _, BinaryField1b>( - &mut builder, - i, - log_n_compressions + COMPRESSION_LOG_LEN, - ) + unconstrained::(&mut builder, i, log_n_compressions + COMPRESSION_LOG_LEN) })?; let _state_out = binius_circuits::lasso::sha256( diff --git a/examples/u32_add.rs b/examples/u32_add.rs index 2feed9ef3..22c446ce9 100644 --- a/examples/u32_add.rs +++ b/examples/u32_add.rs @@ -1,9 +1,12 @@ // Copyright 2024-2025 Irreducible Inc. use anyhow::Result; -use binius_circuits::{arithmetic::Flags, builder::ConstraintSystemBuilder}; +use binius_circuits::{ + arithmetic::Flags, + builder::{types::U, ConstraintSystemBuilder}, +}; use binius_core::{constraint_system, fiat_shamir::HasherChallenger, tower::CanonicalTowerFamily}; -use binius_field::{arch::OptimalUnderlier, BinaryField128b, BinaryField1b}; +use binius_field::BinaryField1b; use binius_hal::make_portable_backend; use binius_hash::compress::Groestl256ByteCompression; use binius_math::DefaultEvaluationDomainFactory; @@ -24,7 +27,6 @@ struct Args { } fn main() -> Result<()> { - type U = OptimalUnderlier; const SECURITY_BITS: usize = 100; adjust_thread_pool() @@ -40,15 +42,15 @@ fn main() -> Result<()> { let log_n_additions = log2_ceil_usize(args.n_additions as usize); let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let trace_gen_scope = tracing::info_span!("generating trace").entered(); - let in_a = binius_circuits::unconstrained::unconstrained::<_, _, BinaryField1b>( + let in_a = binius_circuits::unconstrained::unconstrained::( &mut builder, "in_a", log_n_additions + 5, )?; - let in_b = binius_circuits::unconstrained::unconstrained::<_, _, BinaryField1b>( + let in_b = binius_circuits::unconstrained::unconstrained::( &mut builder, "in_b", log_n_additions + 5, diff --git a/examples/u32_mul.rs b/examples/u32_mul.rs index d7ecb1cff..8a8bfafc4 100644 --- a/examples/u32_mul.rs +++ b/examples/u32_mul.rs @@ -4,7 +4,7 @@ use std::array; use anyhow::Result; use binius_circuits::{ - builder::ConstraintSystemBuilder, + builder::{types::U, ConstraintSystemBuilder}, lasso::{ batch::LookupBatch, big_integer_ops::byte_sliced_mul, @@ -14,9 +14,8 @@ use binius_circuits::{ }; use binius_core::{constraint_system, fiat_shamir::HasherChallenger, tower::CanonicalTowerFamily}; use binius_field::{ - arch::OptimalUnderlier, tower_levels::{TowerLevel4, TowerLevel8}, - BinaryField128b, BinaryField1b, BinaryField32b, BinaryField8b, Field, + BinaryField1b, BinaryField32b, BinaryField8b, Field, }; use binius_hal::make_portable_backend; use binius_hash::compress::Groestl256ByteCompression; @@ -38,7 +37,6 @@ struct Args { } fn main() -> Result<()> { - type U = OptimalUnderlier; const SECURITY_BITS: usize = 100; adjust_thread_pool() @@ -54,12 +52,12 @@ fn main() -> Result<()> { let log_n_muls = log2_ceil_usize(args.n_muls as usize); let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let trace_gen_scope = tracing::info_span!("generating trace").entered(); // Assuming our input data is already transposed, i.e a length 4 array of B8's let in_a = array::from_fn(|i| { - binius_circuits::unconstrained::unconstrained::<_, _, BinaryField8b>( + binius_circuits::unconstrained::unconstrained::( &mut builder, format!("in_a_{}", i), log_n_muls, @@ -67,7 +65,7 @@ fn main() -> Result<()> { .unwrap() }); let in_b = array::from_fn(|i| { - binius_circuits::unconstrained::unconstrained::<_, _, BinaryField8b>( + binius_circuits::unconstrained::unconstrained::( &mut builder, format!("in_b_{}", i), log_n_muls, @@ -84,7 +82,7 @@ fn main() -> Result<()> { let mut lookup_batch_mul = LookupBatch::new([lookup_t_mul]); let mut lookup_batch_add = LookupBatch::new([lookup_t_add]); let mut lookup_batch_dci = LookupBatch::new([lookup_t_dci]); - let _mul_and_cout = byte_sliced_mul::<_, _, TowerLevel4, TowerLevel8>( + let _mul_and_cout = byte_sliced_mul::( &mut builder, "lasso_bytesliced_mul", &in_a, @@ -95,9 +93,9 @@ fn main() -> Result<()> { &mut lookup_batch_add, &mut lookup_batch_dci, )?; - lookup_batch_mul.execute::(&mut builder)?; - lookup_batch_add.execute::(&mut builder)?; - lookup_batch_dci.execute::(&mut builder)?; + lookup_batch_mul.execute::(&mut builder)?; + lookup_batch_add.execute::(&mut builder)?; + lookup_batch_dci.execute::(&mut builder)?; drop(trace_gen_scope); diff --git a/examples/u32add_with_lookup.rs b/examples/u32add_with_lookup.rs index 1eba4ab79..9bc06cfbd 100644 --- a/examples/u32add_with_lookup.rs +++ b/examples/u32add_with_lookup.rs @@ -1,11 +1,10 @@ // Copyright 2024-2025 Irreducible Inc. use anyhow::Result; -use binius_circuits::builder::ConstraintSystemBuilder; +use binius_circuits::builder::{types::U, ConstraintSystemBuilder}; use binius_core::{constraint_system, fiat_shamir::HasherChallenger, tower::CanonicalTowerFamily}; use binius_field::{ - arch::OptimalUnderlier, as_packed_field::PackedType, BinaryField128b, BinaryField1b, - BinaryField8b, + arch::OptimalUnderlier, as_packed_field::PackedType, BinaryField1b, BinaryField8b, }; use binius_hal::make_portable_backend; use binius_hash::compress::Groestl256ByteCompression; @@ -29,7 +28,6 @@ struct Args { } fn main() -> Result<()> { - type U = OptimalUnderlier; const SECURITY_BITS: usize = 100; adjust_thread_pool() @@ -45,20 +43,20 @@ fn main() -> Result<()> { let log_n_additions = log2_ceil_usize(args.n_additions as usize); let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let trace_gen_scope = tracing::info_span!("generating trace").entered(); - let in_a = binius_circuits::unconstrained::unconstrained::<_, _, BinaryField8b>( + let in_a = binius_circuits::unconstrained::unconstrained::( &mut builder, "in_a", log_n_additions + 2, )?; - let in_b = binius_circuits::unconstrained::unconstrained::<_, _, BinaryField8b>( + let in_b = binius_circuits::unconstrained::unconstrained::( &mut builder, "in_b", log_n_additions + 2, )?; - let _product = binius_circuits::lasso::u32add::<_, _, BinaryField8b, BinaryField8b>( + let _product = binius_circuits::lasso::u32add::( &mut builder, "out_c", in_a, diff --git a/examples/u8mul.rs b/examples/u8mul.rs index 90d706dc5..cc6e0be9f 100644 --- a/examples/u8mul.rs +++ b/examples/u8mul.rs @@ -2,11 +2,11 @@ use anyhow::Result; use binius_circuits::{ - builder::ConstraintSystemBuilder, + builder::{types::U, ConstraintSystemBuilder}, lasso::{batch::LookupBatch, lookups}, }; use binius_core::{constraint_system, fiat_shamir::HasherChallenger, tower::CanonicalTowerFamily}; -use binius_field::{arch::OptimalUnderlier, BinaryField128b, BinaryField32b, BinaryField8b}; +use binius_field::{BinaryField32b, BinaryField8b}; use binius_hal::make_portable_backend; use binius_hash::compress::Groestl256ByteCompression; use binius_math::DefaultEvaluationDomainFactory; @@ -27,7 +27,6 @@ struct Args { } fn main() -> Result<()> { - type U = OptimalUnderlier; const SECURITY_BITS: usize = 100; adjust_thread_pool() @@ -43,15 +42,15 @@ fn main() -> Result<()> { let log_n_multiplications = log2_ceil_usize(args.n_multiplications as usize); let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let trace_gen_scope = tracing::info_span!("generating trace").entered(); - let in_a = binius_circuits::unconstrained::unconstrained::<_, _, BinaryField8b>( + let in_a = binius_circuits::unconstrained::unconstrained::( &mut builder, "in_a", log_n_multiplications, )?; - let in_b = binius_circuits::unconstrained::unconstrained::<_, _, BinaryField8b>( + let in_b = binius_circuits::unconstrained::unconstrained::( &mut builder, "in_b", log_n_multiplications, @@ -70,7 +69,7 @@ fn main() -> Result<()> { args.n_multiplications as usize, )?; - lookup_batch.execute::<_, _, BinaryField32b>(&mut builder)?; + lookup_batch.execute::(&mut builder)?; drop(trace_gen_scope); let witness = builder diff --git a/examples/vision32b_circuit.rs b/examples/vision32b_circuit.rs index 61c74d762..b5afa6f58 100644 --- a/examples/vision32b_circuit.rs +++ b/examples/vision32b_circuit.rs @@ -10,11 +10,11 @@ use std::array; use anyhow::Result; -use binius_circuits::builder::ConstraintSystemBuilder; +use binius_circuits::builder::{types::U, ConstraintSystemBuilder}; use binius_core::{ constraint_system, fiat_shamir::HasherChallenger, oracle::OracleId, tower::CanonicalTowerFamily, }; -use binius_field::{arch::OptimalUnderlier, BinaryField128b, BinaryField32b, BinaryField8b}; +use binius_field::{BinaryField32b, BinaryField8b}; use binius_hal::make_portable_backend; use binius_hash::compress::Groestl256ByteCompression; use binius_math::IsomorphicEvaluationDomainFactory; @@ -35,7 +35,6 @@ struct Args { } fn main() -> Result<()> { - type U = OptimalUnderlier; const SECURITY_BITS: usize = 100; adjust_thread_pool() @@ -51,11 +50,11 @@ fn main() -> Result<()> { let log_n_permutations = log2_ceil_usize(args.n_permutations as usize); let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let trace_gen_scope = tracing::info_span!("generating trace").entered(); let state_in: [OracleId; 24] = array::from_fn(|i| { - binius_circuits::unconstrained::unconstrained::<_, _, BinaryField32b>( + binius_circuits::unconstrained::unconstrained::( &mut builder, format!("p_in_{i}"), log_n_permutations, From 28fdae579ac05ec182b343fe76223ba9905f2dce Mon Sep 17 00:00:00 2001 From: Tobias Bergkvist Date: Thu, 13 Feb 2025 19:23:52 +0100 Subject: [PATCH 25/50] [field] Simplify usage of PackedExtension, RepackedExtension by making each trait imply its bounds (#24) --- crates/core/src/constraint_system/prove.rs | 2 +- crates/core/src/piop/prove.rs | 8 ++--- crates/core/src/piop/tests.rs | 4 +-- crates/core/src/protocols/fri/prove.rs | 4 +-- .../protocols/gkr_gpa/gpa_sumcheck/prove.rs | 12 +++---- crates/core/src/protocols/gkr_gpa/prove.rs | 6 +--- crates/core/src/protocols/gkr_gpa/tests.rs | 14 +++------ .../src/protocols/sumcheck/prove/oracles.rs | 2 +- .../protocols/sumcheck/prove/prover_state.rs | 4 +-- .../sumcheck/prove/regular_sumcheck.rs | 10 +++--- .../protocols/sumcheck/prove/univariate.rs | 18 +++-------- .../src/protocols/sumcheck/prove/zerocheck.rs | 10 +++--- crates/core/src/protocols/sumcheck/tests.rs | 2 +- .../core/src/protocols/sumcheck/zerocheck.rs | 6 ++-- crates/core/src/reed_solomon/reed_solomon.rs | 8 ++--- crates/core/src/ring_switch/eq_ind.rs | 2 +- crates/core/src/tensor_algebra.rs | 7 +---- crates/field/benches/packed_extension_mul.rs | 1 - .../benches/packed_field_subfield_ops.rs | 10 ++---- crates/field/src/aes_field.rs | 8 ++--- crates/field/src/packed.rs | 9 ++---- crates/field/src/packed_extension.rs | 19 +++++------- crates/field/src/packed_extension_ops.rs | 31 ++++++------------- crates/field/src/transpose.rs | 4 +-- crates/hal/src/backend.rs | 6 ++-- crates/hal/src/cpu.rs | 4 +-- crates/hal/src/sumcheck_round_calculator.rs | 6 ++-- crates/hash/src/groestl/hasher.rs | 1 - crates/hash/src/vision.rs | 1 - crates/math/src/mle_adapters.rs | 8 ++--- crates/math/src/univariate.rs | 15 ++++----- crates/ntt/src/additive_ntt.rs | 20 ++++++------ crates/ntt/src/tests/ntt_tests.rs | 13 +++----- 33 files changed, 102 insertions(+), 173 deletions(-) diff --git a/crates/core/src/constraint_system/prove.rs b/crates/core/src/constraint_system/prove.rs index 2a8be85fa..6d7d3f83f 100644 --- a/crates/core/src/constraint_system/prove.rs +++ b/crates/core/src/constraint_system/prove.rs @@ -513,7 +513,7 @@ where P: PackedExtension + PackedExtension + PackedExtension, - F: TowerField + ExtensionField + ExtensionField, + F: TowerField, { let univariate_prover = sumcheck::prove::constraint_set_zerocheck_prover::<_, _, FBase, _, _>( diff --git a/crates/core/src/piop/prove.rs b/crates/core/src/piop/prove.rs index 49a24066a..27e480d38 100644 --- a/crates/core/src/piop/prove.rs +++ b/crates/core/src/piop/prove.rs @@ -1,7 +1,7 @@ // Copyright 2024-2025 Irreducible Inc. use binius_field::{ - packed::set_packed_slice, BinaryField, ExtensionField, Field, PackedExtension, PackedField, + packed::set_packed_slice, BinaryField, Field, PackedExtension, PackedField, PackedFieldIndexable, SerializeCanonical, TowerField, }; use binius_hal::ComputationBackend; @@ -101,7 +101,7 @@ pub fn commit( multilins: &[M], ) -> Result, Error> where - F: BinaryField + ExtensionField, + F: BinaryField, FEncode: BinaryField, P: PackedField + PackedExtension, M: MultilinearPoly

, @@ -166,7 +166,7 @@ pub fn prove Result<(), Error> where - F: TowerField + ExtensionField + ExtensionField, + F: TowerField, FDomain: Field, FEncode: BinaryField, P: PackedFieldIndexable @@ -251,7 +251,7 @@ fn prove_interleaved_fri_sumcheck, ) -> Result<(), Error> where - F: TowerField + ExtensionField, + F: TowerField, FEncode: BinaryField, P: PackedFieldIndexable + PackedExtension, MTScheme: MerkleTreeScheme, diff --git a/crates/core/src/piop/tests.rs b/crates/core/src/piop/tests.rs index 1c9377a82..88ee6caa5 100644 --- a/crates/core/src/piop/tests.rs +++ b/crates/core/src/piop/tests.rs @@ -3,7 +3,7 @@ use std::iter::repeat_with; use binius_field::{ - BinaryField, BinaryField16b, BinaryField8b, DeserializeCanonical, ExtensionField, Field, + BinaryField, BinaryField16b, BinaryField8b, DeserializeCanonical, Field, PackedBinaryField2x128b, PackedExtension, PackedField, PackedFieldIndexable, SerializeCanonical, TowerField, }; @@ -104,7 +104,7 @@ fn commit_prove_verify( merkle_prover: &impl MerkleTreeProver, log_inv_rate: usize, ) where - F: TowerField + ExtensionField + ExtensionField, + F: TowerField, FDomain: BinaryField, FEncode: BinaryField, P: PackedFieldIndexable diff --git a/crates/core/src/protocols/fri/prove.rs b/crates/core/src/protocols/fri/prove.rs index 287eef45f..2bb66fd6a 100644 --- a/crates/core/src/protocols/fri/prove.rs +++ b/crates/core/src/protocols/fri/prove.rs @@ -176,7 +176,7 @@ pub fn commit_interleaved( message: &[P], ) -> Result, Error> where - F: BinaryField + ExtensionField, + F: BinaryField, FA: BinaryField, P: PackedField + PackedExtension, PA: PackedField, @@ -211,7 +211,7 @@ pub fn commit_interleaved_with( message_writer: impl FnOnce(&mut [P]), ) -> Result, Error> where - F: BinaryField + ExtensionField, + F: BinaryField, FA: BinaryField, P: PackedField + PackedExtension, PA: PackedField, diff --git a/crates/core/src/protocols/gkr_gpa/gpa_sumcheck/prove.rs b/crates/core/src/protocols/gkr_gpa/gpa_sumcheck/prove.rs index 51736bf52..90bb03462 100644 --- a/crates/core/src/protocols/gkr_gpa/gpa_sumcheck/prove.rs +++ b/crates/core/src/protocols/gkr_gpa/gpa_sumcheck/prove.rs @@ -2,9 +2,7 @@ use std::ops::Range; -use binius_field::{ - util::eq, ExtensionField, Field, PackedExtension, PackedField, PackedFieldIndexable, -}; +use binius_field::{util::eq, Field, PackedExtension, PackedField, PackedFieldIndexable}; use binius_hal::{ComputationBackend, SumcheckEvaluator}; use binius_math::{ CompositionPolyOS, EvaluationDomainFactory, InterpolationDomain, MultilinearPoly, @@ -48,7 +46,7 @@ where impl<'a, F, FDomain, P, Composition, M, Backend> GPAProver<'a, FDomain, P, Composition, M, Backend> where - F: Field + ExtensionField, + F: Field, FDomain: Field, P: PackedFieldIndexable + PackedExtension @@ -193,7 +191,7 @@ where impl SumcheckProver for GPAProver<'_, FDomain, P, Composition, M, Backend> where - F: Field + ExtensionField, + F: Field, FDomain: Field, P: PackedFieldIndexable + PackedExtension @@ -290,7 +288,7 @@ where impl SumcheckEvaluator for GPAEvaluator<'_, P, FDomain, Composition> where - F: Field + ExtensionField, + F: Field, P: PackedField + PackedExtension + PackedExtension, FDomain: Field, Composition: CompositionPolyOS

, @@ -344,7 +342,7 @@ where impl SumcheckInterpolator for GPAEvaluator<'_, P, FDomain, Composition> where - F: Field + ExtensionField, + F: Field, P: PackedField + PackedExtension, FDomain: Field, Composition: CompositionPolyOS

, diff --git a/crates/core/src/protocols/gkr_gpa/prove.rs b/crates/core/src/protocols/gkr_gpa/prove.rs index c1ab27bcb..e48d0c272 100644 --- a/crates/core/src/protocols/gkr_gpa/prove.rs +++ b/crates/core/src/protocols/gkr_gpa/prove.rs @@ -1,8 +1,6 @@ // Copyright 2024-2025 Irreducible Inc. -use binius_field::{ - ExtensionField, Field, PackedExtension, PackedField, PackedFieldIndexable, TowerField, -}; +use binius_field::{Field, PackedExtension, PackedField, PackedFieldIndexable, TowerField}; use binius_hal::ComputationBackend; use binius_math::{ extrapolate_line_scalar, EvaluationDomainFactory, MLEDirectAdapter, MultilinearExtension, @@ -46,7 +44,6 @@ where + PackedExtension + PackedExtension, FDomain: Field, - P::Scalar: Field + ExtensionField, Challenger_: Challenger, Backend: ComputationBackend, { @@ -266,7 +263,6 @@ where where FDomain: Field, P: PackedExtension, - F: ExtensionField, { // test same layer let Some(first_prover) = provers.first() else { diff --git a/crates/core/src/protocols/gkr_gpa/tests.rs b/crates/core/src/protocols/gkr_gpa/tests.rs index 87fcefa27..263544606 100644 --- a/crates/core/src/protocols/gkr_gpa/tests.rs +++ b/crates/core/src/protocols/gkr_gpa/tests.rs @@ -7,8 +7,8 @@ use binius_field::{ as_packed_field::{PackScalar, PackedType}, packed::set_packed_slice, underlier::{UnderlierType, WithUnderlier}, - BinaryField128b, BinaryField32b, ExtensionField, Field, PackedExtension, PackedField, - PackedFieldIndexable, RepackedExtension, TowerField, + BinaryField128b, BinaryField32b, Field, PackedExtension, PackedField, PackedFieldIndexable, + RepackedExtension, TowerField, }; use binius_math::{IsomorphicEvaluationDomainFactory, MultilinearExtension}; use bytemuck::zeroed_vec; @@ -24,15 +24,11 @@ use crate::{ witness::MultilinearExtensionIndex, }; -fn generate_poly_helper( +fn generate_poly_helper, F: Field>( rng: &mut StdRng, n_vars: usize, n_multilinears: usize, -) -> Vec<(MultilinearExtension

, F)> -where - P: PackedField>, - F: Field, -{ +) -> Vec<(MultilinearExtension

, F)> { repeat_with(|| { let values = repeat_with(|| F::random(&mut *rng)) .take(1 << n_vars) @@ -119,7 +115,7 @@ fn run_prove_verify_batch_test() where U: UnderlierType + PackScalar, P: PackedExtension + RepackedExtension

+ PackedFieldIndexable, - F: TowerField + ExtensionField, + F: TowerField, FS: TowerField, { let rng = StdRng::seed_from_u64(0); diff --git a/crates/core/src/protocols/sumcheck/prove/oracles.rs b/crates/core/src/protocols/sumcheck/prove/oracles.rs index 8b34f4e57..c6f4b7980 100644 --- a/crates/core/src/protocols/sumcheck/prove/oracles.rs +++ b/crates/core/src/protocols/sumcheck/prove/oracles.rs @@ -54,7 +54,7 @@ where + PackedExtension + PackedExtension + PackedExtension, - F: TowerField + ExtensionField + ExtensionField, + F: TowerField, FBase: TowerField + ExtensionField + TryFrom, FDomain: Field, Backend: ComputationBackend, diff --git a/crates/core/src/protocols/sumcheck/prove/prover_state.rs b/crates/core/src/protocols/sumcheck/prove/prover_state.rs index a7e252fe9..185d32f84 100644 --- a/crates/core/src/protocols/sumcheck/prove/prover_state.rs +++ b/crates/core/src/protocols/sumcheck/prove/prover_state.rs @@ -5,7 +5,7 @@ use std::{ sync::atomic::{AtomicBool, Ordering}, }; -use binius_field::{util::powers, ExtensionField, Field, PackedExtension, PackedField}; +use binius_field::{util::powers, Field, PackedExtension, PackedField}; use binius_hal::{ComputationBackend, RoundEvals, SumcheckEvaluator, SumcheckMultilinear}; use binius_math::{ evaluate_univariate, CompositionPolyOS, MLEDirectAdapter, MultilinearPoly, MultilinearQuery, @@ -70,7 +70,7 @@ where impl<'a, FDomain, F, P, M, Backend> ProverState<'a, FDomain, P, M, Backend> where FDomain: Field, - F: Field + ExtensionField, + F: Field, P: PackedField + PackedExtension, M: MultilinearPoly

+ Send + Sync, Backend: ComputationBackend, diff --git a/crates/core/src/protocols/sumcheck/prove/regular_sumcheck.rs b/crates/core/src/protocols/sumcheck/prove/regular_sumcheck.rs index 8cae1a0a1..9ecb400a9 100644 --- a/crates/core/src/protocols/sumcheck/prove/regular_sumcheck.rs +++ b/crates/core/src/protocols/sumcheck/prove/regular_sumcheck.rs @@ -2,7 +2,7 @@ use std::{marker::PhantomData, ops::Range}; -use binius_field::{ExtensionField, Field, PackedExtension, PackedField}; +use binius_field::{Field, PackedExtension, PackedField}; use binius_hal::{ComputationBackend, SumcheckEvaluator}; use binius_math::{ CompositionPolyOS, EvaluationDomainFactory, InterpolationDomain, MultilinearPoly, @@ -82,7 +82,7 @@ where impl<'a, F, FDomain, P, Composition, M, Backend> RegularSumcheckProver<'a, FDomain, P, Composition, M, Backend> where - F: Field + ExtensionField, + F: Field, FDomain: Field, P: PackedField + PackedExtension + PackedExtension, Composition: CompositionPolyOS

, @@ -166,7 +166,7 @@ where impl SumcheckProver for RegularSumcheckProver<'_, FDomain, P, Composition, M, Backend> where - F: Field + ExtensionField, + F: Field, FDomain: Field, P: PackedField + PackedExtension + PackedExtension, Composition: CompositionPolyOS

, @@ -216,7 +216,7 @@ where impl SumcheckEvaluator for RegularSumcheckEvaluator<'_, P, FDomain, Composition> where - F: Field + ExtensionField, + F: Field, P: PackedField + PackedExtension + PackedExtension, FDomain: Field, Composition: CompositionPolyOS

, @@ -256,7 +256,7 @@ where impl SumcheckInterpolator for RegularSumcheckEvaluator<'_, P, FDomain, Composition> where - F: Field + ExtensionField, + F: Field, P: PackedField + PackedExtension, FDomain: Field, { diff --git a/crates/core/src/protocols/sumcheck/prove/univariate.rs b/crates/core/src/protocols/sumcheck/prove/univariate.rs index 7c373f96c..91995a1b7 100644 --- a/crates/core/src/protocols/sumcheck/prove/univariate.rs +++ b/crates/core/src/protocols/sumcheck/prove/univariate.rs @@ -95,7 +95,7 @@ pub fn univariatizing_reduction_prover<'a, F, FDomain, P, Backend>( backend: &'a Backend, ) -> Result, Error> where - F: TowerField + ExtensionField, + F: TowerField, FDomain: TowerField, P: PackedFieldIndexable + PackedExtension @@ -132,12 +132,7 @@ where } #[derive(Debug)] -struct ParFoldStates -where - FBase: Field + PackedField, - P: PackedField + PackedExtension, - P::Scalar: ExtensionField, -{ +struct ParFoldStates> { /// Evaluations of a multilinear subcube, embedded into P (see MultilinearPoly::subcube_evals). Scratch space. evals: Vec

, /// `evals` cast to base field and transposed to 2^skip_rounds * 2^log_batch row-major form. Scratch space. @@ -150,12 +145,7 @@ where round_evals: Vec>, } -impl ParFoldStates -where - FBase: Field, - P: PackedField + PackedExtension, - P::Scalar: ExtensionField, -{ +impl> ParFoldStates { fn new( n_multilinears: usize, skip_rounds: usize, @@ -335,7 +325,7 @@ pub fn zerocheck_univariate_evals where FDomain: TowerField, FBase: ExtensionField, - F: TowerField + ExtensionField + ExtensionField, + F: TowerField, P: PackedFieldIndexable + PackedExtension + PackedExtension, diff --git a/crates/core/src/protocols/sumcheck/prove/zerocheck.rs b/crates/core/src/protocols/sumcheck/prove/zerocheck.rs index bc6cbd852..866d7f77d 100644 --- a/crates/core/src/protocols/sumcheck/prove/zerocheck.rs +++ b/crates/core/src/protocols/sumcheck/prove/zerocheck.rs @@ -111,7 +111,7 @@ where impl<'a, 'm, F, FDomain, FBase, P, CompositionBase, Composition, M, Backend> UnivariateZerocheck<'a, 'm, FDomain, FBase, P, CompositionBase, Composition, M, Backend> where - F: Field + ExtensionField + ExtensionField, + F: Field, FDomain: Field, FBase: ExtensionField, P: PackedFieldIndexable @@ -245,7 +245,7 @@ impl<'a, 'm, F, FDomain, FBase, P, CompositionBase, Composition, M, Backend> UnivariateZerocheckProver<'a, F> for UnivariateZerocheck<'a, 'm, FDomain, FBase, P, CompositionBase, Composition, M, Backend> where - F: TowerField + ExtensionField + ExtensionField, + F: TowerField, FDomain: TowerField, FBase: ExtensionField, P: PackedFieldIndexable @@ -447,7 +447,7 @@ where impl<'a, F, FDomain, P, Composition, M, Backend> ZerocheckProver<'a, FDomain, P, Composition, M, Backend> where - F: Field + ExtensionField, + F: Field, FDomain: Field, P: PackedFieldIndexable + PackedExtension, Composition: CompositionPolyOS

, @@ -537,7 +537,7 @@ where impl SumcheckProver for ZerocheckProver<'_, FDomain, P, Composition, M, Backend> where - F: Field + ExtensionField, + F: Field, FDomain: Field, P: PackedFieldIndexable + PackedExtension, Composition: CompositionPolyOS

, @@ -766,7 +766,7 @@ where impl SumcheckInterpolator for ZerocheckLaterRoundEvaluator<'_, P, FDomain, Composition> where - F: Field + ExtensionField, + F: Field, P: PackedField + PackedExtension, FDomain: Field, { diff --git a/crates/core/src/protocols/sumcheck/tests.rs b/crates/core/src/protocols/sumcheck/tests.rs index 667f63162..e798de0fb 100644 --- a/crates/core/src/protocols/sumcheck/tests.rs +++ b/crates/core/src/protocols/sumcheck/tests.rs @@ -267,7 +267,7 @@ fn make_test_sumcheck<'a, F, FDomain, P, PExt, Backend>( impl SumcheckProver + 'a, ) where - F: Field + ExtensionField + ExtensionField, + F: Field, FDomain: Field, P: PackedField, PExt: PackedField diff --git a/crates/core/src/protocols/sumcheck/zerocheck.rs b/crates/core/src/protocols/sumcheck/zerocheck.rs index 4bce691c9..bdf4cd8e7 100644 --- a/crates/core/src/protocols/sumcheck/zerocheck.rs +++ b/crates/core/src/protocols/sumcheck/zerocheck.rs @@ -195,8 +195,8 @@ mod tests { use std::{iter, sync::Arc}; use binius_field::{ - BinaryField128b, BinaryField32b, BinaryField8b, ExtensionField, PackedBinaryField1x128b, - PackedExtension, PackedFieldIndexable, PackedSubfield, RepackedExtension, + BinaryField128b, BinaryField32b, BinaryField8b, PackedBinaryField1x128b, PackedExtension, + PackedFieldIndexable, PackedSubfield, RepackedExtension, }; use binius_hal::{make_portable_backend, ComputationBackend, ComputationBackendExt}; use binius_math::{ @@ -236,7 +236,7 @@ mod tests { Backend, > where - F: Field + ExtensionField, + F: Field, FDomain: Field, P: PackedFieldIndexable + PackedExtension + RepackedExtension

, Composition: CompositionPolyOS

, diff --git a/crates/core/src/reed_solomon/reed_solomon.rs b/crates/core/src/reed_solomon/reed_solomon.rs index 2b080050e..ade7bace2 100644 --- a/crates/core/src/reed_solomon/reed_solomon.rs +++ b/crates/core/src/reed_solomon/reed_solomon.rs @@ -160,15 +160,11 @@ where /// /// * If the `code` buffer does not have capacity for `len() << log_batch_size` field elements. #[instrument(skip_all, level = "debug")] - pub fn encode_ext_batch_inplace( + pub fn encode_ext_batch_inplace>( &self, code: &mut [PE], log_batch_size: usize, - ) -> Result<(), Error> - where - PE: RepackedExtension

, - PE::Scalar: ExtensionField<

::Scalar>, - { + ) -> Result<(), Error> { self.encode_batch_inplace(PE::cast_bases_mut(code), log_batch_size + PE::Scalar::LOG_DEGREE) } } diff --git a/crates/core/src/ring_switch/eq_ind.rs b/crates/core/src/ring_switch/eq_ind.rs index c6c448c0e..3505969fd 100644 --- a/crates/core/src/ring_switch/eq_ind.rs +++ b/crates/core/src/ring_switch/eq_ind.rs @@ -145,7 +145,7 @@ where impl MultivariatePoly for RingSwitchEqInd where FSub: TowerField, - F: TowerField + PackedField + ExtensionField + PackedExtension, + F: TowerField + PackedField + PackedExtension, { fn n_vars(&self) -> usize { self.z_vals.len() diff --git a/crates/core/src/tensor_algebra.rs b/crates/core/src/tensor_algebra.rs index 0832803af..cb071f259 100644 --- a/crates/core/src/tensor_algebra.rs +++ b/crates/core/src/tensor_algebra.rs @@ -123,12 +123,7 @@ where } } -impl TensorAlgebra -where - F: Field, - FE: ExtensionField + PackedExtension, - FE::Scalar: ExtensionField, -{ +impl + PackedExtension> TensorAlgebra { /// Multiply by an element from the vertical subring. /// /// Internally, this performs a transpose, vertical scaling, then transpose sequence. If diff --git a/crates/field/benches/packed_extension_mul.rs b/crates/field/benches/packed_extension_mul.rs index a119dc9a7..836b8474f 100644 --- a/crates/field/benches/packed_extension_mul.rs +++ b/crates/field/benches/packed_extension_mul.rs @@ -16,7 +16,6 @@ fn benchmark_packed_extension_mul( label: &str, ) where F: Field, - BinaryField128b: ExtensionField, PackedBinaryField2x128b: PackedExtension, { let mut rng = thread_rng(); diff --git a/crates/field/benches/packed_field_subfield_ops.rs b/crates/field/benches/packed_field_subfield_ops.rs index cfd2dff39..bc2a36539 100644 --- a/crates/field/benches/packed_field_subfield_ops.rs +++ b/crates/field/benches/packed_field_subfield_ops.rs @@ -5,8 +5,8 @@ use std::array; use binius_field::{ packed::mul_by_subfield_scalar, underlier::{UnderlierType, WithUnderlier}, - BinaryField1b, BinaryField32b, BinaryField4b, BinaryField64b, BinaryField8b, ExtensionField, - Field, PackedBinaryField16x8b, PackedBinaryField1x128b, PackedBinaryField2x128b, + BinaryField1b, BinaryField32b, BinaryField4b, BinaryField64b, BinaryField8b, Field, + PackedBinaryField16x8b, PackedBinaryField1x128b, PackedBinaryField2x128b, PackedBinaryField32x8b, PackedBinaryField4x128b, PackedBinaryField4x32b, PackedBinaryField64x8b, PackedBinaryField8x32b, PackedBinaryField8x64b, PackedExtension, }; @@ -17,11 +17,7 @@ use rand::thread_rng; const BATCH_SIZE: usize = 32; -fn bench_mul_subfield(group: &mut BenchmarkGroup<'_, WallTime>) -where - PE: PackedExtension>, - F: Field, -{ +fn bench_mul_subfield, F: Field>(group: &mut BenchmarkGroup<'_, WallTime>) { let mut rng = thread_rng(); let packed: [PE; BATCH_SIZE] = array::from_fn(|_| PE::random(&mut rng)); let scalars: [F; BATCH_SIZE] = array::from_fn(|_| F::random(&mut rng)); diff --git a/crates/field/src/aes_field.rs b/crates/field/src/aes_field.rs index f37104509..f2d480bff 100644 --- a/crates/field/src/aes_field.rs +++ b/crates/field/src/aes_field.rs @@ -164,8 +164,8 @@ impl Transformation for SubfieldTransformer>, - OEP: PackedExtension>, + IEP: PackedExtension, + OEP: PackedExtension, T: Transformation, PackedSubfield>, { fn transform(&self, input: &IEP) -> OEP { @@ -178,9 +178,7 @@ where pub fn make_aes_to_binary_packed_transformer() -> impl Transformation where IP: PackedExtension, - IP::Scalar: ExtensionField, OP: PackedExtension, - OP::Scalar: ExtensionField, PackedSubfield: PackedTransformationFactory>, { @@ -197,9 +195,7 @@ where pub fn make_binary_to_aes_packed_transformer() -> impl Transformation where IP: PackedExtension, - IP::Scalar: ExtensionField, OP: PackedExtension, - OP::Scalar: ExtensionField, PackedSubfield: PackedTransformationFactory>, { diff --git a/crates/field/src/packed.rs b/crates/field/src/packed.rs index 471121b01..b47a58212 100644 --- a/crates/field/src/packed.rs +++ b/crates/field/src/packed.rs @@ -20,8 +20,7 @@ use super::{ Error, }; use crate::{ - arithmetic_traits::InvertOrZero, underlier::WithUnderlier, BinaryField, ExtensionField, Field, - PackedExtension, + arithmetic_traits::InvertOrZero, underlier::WithUnderlier, BinaryField, Field, PackedExtension, }; /// A packed field represents a vector of underlying field elements. @@ -371,11 +370,7 @@ pub const fn len_packed_slice(packed: &[P]) -> usize { } /// Multiply packed field element by a subfield scalar. -pub fn mul_by_subfield_scalar(val: P, multiplier: FS) -> P -where - P: PackedExtension>, - FS: Field, -{ +pub fn mul_by_subfield_scalar, FS: Field>(val: P, multiplier: FS) -> P { use crate::underlier::UnderlierType; // This is a workaround not to make the multiplication slower in certain cases. diff --git a/crates/field/src/packed_extension.rs b/crates/field/src/packed_extension.rs index 3ecd7b662..9e03ac3a2 100644 --- a/crates/field/src/packed_extension.rs +++ b/crates/field/src/packed_extension.rs @@ -59,7 +59,7 @@ where /// /// fn cast_then_iter<'a, F, PE>(packed: &'a PE) -> impl Iterator + 'a /// where -/// PE: PackedExtension>, +/// PE: PackedExtension, /// F: Field, /// { /// PE::cast_base_ref(packed).into_iter() @@ -71,9 +71,8 @@ where /// In order for the above relation to be guaranteed, the memory representation of /// `PackedExtensionField` element must be the same as a slice of the underlying `PackedField` /// element. -pub trait PackedExtension: PackedField -where - Self::Scalar: ExtensionField, +pub trait PackedExtension: + PackedField> + WithUnderlier> { type PackedSubfield: PackedField; @@ -187,9 +186,7 @@ where /// This trait is a shorthand for the case `PackedExtension` which is a /// quite common case in our codebase. pub trait RepackedExtension: - PackedExtension -where - Self::Scalar: ExtensionField, + PackedField> + PackedExtension { } @@ -202,10 +199,8 @@ where /// This trait adds shortcut methods for the case `PackedExtension` which is a /// quite common case in our codebase. -pub trait PackedExtensionIndexable: PackedExtension -where - Self::Scalar: ExtensionField, - Self::PackedSubfield: PackedFieldIndexable, +pub trait PackedExtensionIndexable: + PackedExtension + PackedField> { fn unpack_base_scalars(packed: &[Self]) -> &[F] { Self::PackedSubfield::unpack_scalars(Self::cast_bases(packed)) @@ -219,7 +214,7 @@ where impl PackedExtensionIndexable for PT where F: Field, - PT: PackedExtension, PackedSubfield: PackedFieldIndexable>, + PT: PackedExtension, { } diff --git a/crates/field/src/packed_extension_ops.rs b/crates/field/src/packed_extension_ops.rs index 6d2ede809..53dc4e0b6 100644 --- a/crates/field/src/packed_extension_ops.rs +++ b/crates/field/src/packed_extension_ops.rs @@ -6,21 +6,17 @@ use binius_maybe_rayon::prelude::{ use crate::{Error, ExtensionField, Field, PackedExtension, PackedField}; -pub fn ext_base_mul(lhs: &mut [PE], rhs: &[PE::PackedSubfield]) -> Result<(), Error> -where - PE: PackedExtension, - PE::Scalar: ExtensionField, - F: Field, -{ +pub fn ext_base_mul, F: Field>( + lhs: &mut [PE], + rhs: &[PE::PackedSubfield], +) -> Result<(), Error> { ext_base_op(lhs, rhs, |_, lhs, broadcasted_rhs| PE::cast_ext(lhs.cast_base() * broadcasted_rhs)) } -pub fn ext_base_mul_par(lhs: &mut [PE], rhs: &[PE::PackedSubfield]) -> Result<(), Error> -where - PE: PackedExtension, - PE::Scalar: ExtensionField, - F: Field, -{ +pub fn ext_base_mul_par, F: Field>( + lhs: &mut [PE], + rhs: &[PE::PackedSubfield], +) -> Result<(), Error> { ext_base_op_par(lhs, rhs, |_, lhs, broadcasted_rhs| { PE::cast_ext(lhs.cast_base() * broadcasted_rhs) }) @@ -29,15 +25,10 @@ where /// # Safety /// /// Width of PackedSubfield is >= the width of the field implementing PackedExtension. -pub unsafe fn get_packed_subfields_at_pe_idx( +pub unsafe fn get_packed_subfields_at_pe_idx, F: Field>( packed_subfields: &[PE::PackedSubfield], i: usize, -) -> PE::PackedSubfield -where - PE: PackedExtension, - PE::Scalar: ExtensionField, - F: Field, -{ +) -> PE::PackedSubfield { let bottom_most_scalar_idx = i * PE::WIDTH; let bottom_most_scalar_idx_in_subfield_arr = bottom_most_scalar_idx / PE::PackedSubfield::WIDTH; let bottom_most_scalar_idx_within_packed_subfield = @@ -67,7 +58,6 @@ pub fn ext_base_op( ) -> Result<(), Error> where PE: PackedExtension, - PE::Scalar: ExtensionField, F: Field, Func: Fn(usize, PE, PE::PackedSubfield) -> PE, { @@ -93,7 +83,6 @@ pub fn ext_base_op_par( ) -> Result<(), Error> where PE: PackedExtension, - PE::Scalar: ExtensionField, F: Field, Func: Fn(usize, PE, PE::PackedSubfield) -> PE + std::marker::Sync, { diff --git a/crates/field/src/transpose.rs b/crates/field/src/transpose.rs index efabbfc60..b3838b57a 100644 --- a/crates/field/src/transpose.rs +++ b/crates/field/src/transpose.rs @@ -2,7 +2,7 @@ use binius_utils::checked_arithmetics::log2_strict_usize; -use super::{packed::PackedField, ExtensionField, PackedFieldIndexable, RepackedExtension}; +use super::{packed::PackedField, Field, PackedFieldIndexable, RepackedExtension}; /// Error thrown when a transpose operation fails. #[derive(Clone, thiserror::Error, Debug)] @@ -76,7 +76,7 @@ pub fn square_transpose(log_n: usize, elems: &mut [P]) -> Result pub fn transpose_scalars(src: &[PE], dst: &mut [P]) -> Result<(), Error> where P: PackedField, - FE: ExtensionField, + FE: Field, PE: PackedFieldIndexable + RepackedExtension

, { let len = src.len(); diff --git a/crates/hal/src/backend.rs b/crates/hal/src/backend.rs index 02a443e7e..b91dc9b35 100644 --- a/crates/hal/src/backend.rs +++ b/crates/hal/src/backend.rs @@ -5,7 +5,7 @@ use std::{ ops::{Deref, DerefMut}, }; -use binius_field::{ExtensionField, Field, PackedExtension, PackedField}; +use binius_field::{Field, PackedExtension, PackedField}; use binius_math::{ CompositionPolyOS, MultilinearExtension, MultilinearPoly, MultilinearQuery, MultilinearQueryRef, }; @@ -53,7 +53,7 @@ pub trait ComputationBackend: Send + Sync + Debug { ) -> Result>, Error> where FDomain: Field, - P: PackedField> + PackedExtension, + P: PackedExtension, M: MultilinearPoly

+ Send + Sync, Evaluator: SumcheckEvaluator + Sync, Composition: CompositionPolyOS

; @@ -95,7 +95,7 @@ where ) -> Result>, Error> where FDomain: Field, - P: PackedField> + PackedExtension, + P: PackedExtension, M: MultilinearPoly

+ Send + Sync, Evaluator: SumcheckEvaluator + Sync, Composition: CompositionPolyOS

, diff --git a/crates/hal/src/cpu.rs b/crates/hal/src/cpu.rs index b9f079f75..ef19db96b 100644 --- a/crates/hal/src/cpu.rs +++ b/crates/hal/src/cpu.rs @@ -2,7 +2,7 @@ use std::fmt::Debug; -use binius_field::{ExtensionField, Field, PackedExtension, PackedField}; +use binius_field::{Field, PackedExtension, PackedField}; use binius_math::{ eq_ind_partial_eval, CompositionPolyOS, MultilinearExtension, MultilinearPoly, MultilinearQueryRef, @@ -47,7 +47,7 @@ impl ComputationBackend for CpuBackend { ) -> Result>, Error> where FDomain: Field, - P: PackedField> + PackedExtension, + P: PackedExtension, M: MultilinearPoly

+ Send + Sync, Evaluator: SumcheckEvaluator + Sync, Composition: CompositionPolyOS

, diff --git a/crates/hal/src/sumcheck_round_calculator.rs b/crates/hal/src/sumcheck_round_calculator.rs index 8fe276662..8e79713cf 100644 --- a/crates/hal/src/sumcheck_round_calculator.rs +++ b/crates/hal/src/sumcheck_round_calculator.rs @@ -6,7 +6,7 @@ use std::iter; -use binius_field::{ExtensionField, Field, PackedExtension, PackedField, PackedSubfield}; +use binius_field::{Field, PackedExtension, PackedField, PackedSubfield}; use binius_math::{ deinterleave, extrapolate_lines, CompositionPolyOS, MultilinearPoly, MultilinearQuery, MultilinearQueryRef, @@ -54,7 +54,7 @@ pub(crate) fn calculate_round_evals( ) -> Result>, Error> where FDomain: Field, - F: Field + ExtensionField, + F: Field, P: PackedField + PackedExtension, M: MultilinearPoly

+ Send + Sync, Evaluator: SumcheckEvaluator + Sync, @@ -82,7 +82,7 @@ fn calculate_round_evals_with_access Result>, Error> where FDomain: Field, - F: ExtensionField, + F: Field, P: PackedField + PackedExtension, Evaluator: SumcheckEvaluator + Sync, Access: SumcheckMultilinearAccess

+ Sync, diff --git a/crates/hash/src/groestl/hasher.rs b/crates/hash/src/groestl/hasher.rs index ae8f3144c..d6ee087ce 100644 --- a/crates/hash/src/groestl/hasher.rs +++ b/crates/hash/src/groestl/hasher.rs @@ -153,7 +153,6 @@ impl Hasher

for Groestl256 where F: BinaryField + From + Into, P: PackedExtension, - P::Scalar: ExtensionField, OptimalUnderlier256b: PackScalar + Divisible, Self: UpdateOverSlice, { diff --git a/crates/hash/src/vision.rs b/crates/hash/src/vision.rs index 45940192a..750db41b8 100644 --- a/crates/hash/src/vision.rs +++ b/crates/hash/src/vision.rs @@ -221,7 +221,6 @@ where U: PackScalar + Divisible, F: BinaryField + From + Into, P: PackedExtension, - P::Scalar: ExtensionField, PackedAESBinaryField8x32b: WithUnderlier, { type Digest = PackedType; diff --git a/crates/math/src/mle_adapters.rs b/crates/math/src/mle_adapters.rs index 1a8d8fec8..47c227cd1 100644 --- a/crates/math/src/mle_adapters.rs +++ b/crates/math/src/mle_adapters.rs @@ -273,11 +273,9 @@ where P: PackedField, Data: Deref + Send + Sync + Debug + 'a, { - pub fn specialize_arc_dyn(self) -> Arc + Send + Sync + 'a> - where - PE: PackedField + RepackedExtension

, - PE::Scalar: ExtensionField, - { + pub fn specialize_arc_dyn>( + self, + ) -> Arc + Send + Sync + 'a> { self.specialize().upcast_arc_dyn() } } diff --git a/crates/math/src/univariate.rs b/crates/math/src/univariate.rs index 7e12f4d19..d2339d213 100644 --- a/crates/math/src/univariate.rs +++ b/crates/math/src/univariate.rs @@ -223,10 +223,11 @@ impl InterpolationDomain { self.evaluation_domain.with_infinity() } - pub fn extrapolate(&self, values: &[PE], x: PE::Scalar) -> Result - where - PE: PackedExtension>, - { + pub fn extrapolate>( + &self, + values: &[PE], + x: PE::Scalar, + ) -> Result { self.evaluation_domain.extrapolate(values, x) } @@ -243,11 +244,7 @@ impl InterpolationDomain { /// Extrapolates lines through a pair of packed fields at a single point from a subfield. #[inline] -pub fn extrapolate_line(x0: P, x1: P, z: FS) -> P -where - P: PackedExtension>, - FS: Field, -{ +pub fn extrapolate_line, FS: Field>(x0: P, x1: P, z: FS) -> P { x0 + mul_by_subfield_scalar(x1 - x0, z) } diff --git a/crates/ntt/src/additive_ntt.rs b/crates/ntt/src/additive_ntt.rs index a4828d5fa..7b5e0bcc4 100644 --- a/crates/ntt/src/additive_ntt.rs +++ b/crates/ntt/src/additive_ntt.rs @@ -45,19 +45,19 @@ pub trait AdditiveNTT { log_batch_size: usize, ) -> Result<(), Error>; - fn forward_transform_ext(&self, data: &mut [PE], coset: u32) -> Result<(), Error> - where - PE: RepackedExtension

, - PE::Scalar: ExtensionField, - { + fn forward_transform_ext>( + &self, + data: &mut [PE], + coset: u32, + ) -> Result<(), Error> { self.forward_transform(PE::cast_bases_mut(data), coset, PE::Scalar::LOG_DEGREE) } - fn inverse_transform_ext(&self, data: &mut [PE], coset: u32) -> Result<(), Error> - where - PE: RepackedExtension

, - PE::Scalar: ExtensionField, - { + fn inverse_transform_ext>( + &self, + data: &mut [PE], + coset: u32, + ) -> Result<(), Error> { self.inverse_transform(PE::cast_bases_mut(data), coset, PE::Scalar::LOG_DEGREE) } } diff --git a/crates/ntt/src/tests/ntt_tests.rs b/crates/ntt/src/tests/ntt_tests.rs index 7ab2d2e39..ccdf8b7d8 100644 --- a/crates/ntt/src/tests/ntt_tests.rs +++ b/crates/ntt/src/tests/ntt_tests.rs @@ -10,8 +10,8 @@ use binius_field::{ packed_8::PackedBinaryField1x8b, }, underlier::{NumCast, WithUnderlier}, - AESTowerField8b, BinaryField, BinaryField8b, ExtensionField, PackedBinaryField16x32b, - PackedBinaryField8x32b, PackedField, RepackedExtension, + AESTowerField8b, BinaryField, BinaryField8b, PackedBinaryField16x32b, PackedBinaryField8x32b, + PackedField, RepackedExtension, }; use rand::{rngs::StdRng, SeedableRng}; @@ -144,16 +144,12 @@ fn tests_field_512_bits() { check_roundtrip_all_ntts::(12, 6, 4, 0); } -fn check_packed_extension_roundtrip_with_reference( +fn check_packed_extension_roundtrip_with_reference>( reference_ntt: &impl AdditiveNTT

, ntt: &impl AdditiveNTT

, data: &mut [PE], cosets: Range, -) where - P: PackedField, - PE: RepackedExtension

, - PE::Scalar: ExtensionField, -{ +) { let data_copy = data.to_vec(); let mut data_copy_2 = data.to_vec(); @@ -182,7 +178,6 @@ fn check_packed_extension_roundtrip_all_ntts( ) where P: PackedField, PE: RepackedExtension

+ WithUnderlier>, - PE::Scalar: ExtensionField, { let simple_ntt = SingleThreadedNTT::::new(log_domain_size) .unwrap() From a451f4cf277938ee968c34c2daf5945465a72d7c Mon Sep 17 00:00:00 2001 From: Tobias Bergkvist Date: Fri, 14 Feb 2025 11:31:53 +0100 Subject: [PATCH 26/50] [macros] Remove unused IterOracles, IterPolys derive proc macros (#25) --- crates/macros/src/lib.rs | 156 --------------------------------------- 1 file changed, 156 deletions(-) diff --git a/crates/macros/src/lib.rs b/crates/macros/src/lib.rs index f72cef222..2b8daf9f7 100644 --- a/crates/macros/src/lib.rs +++ b/crates/macros/src/lib.rs @@ -5,8 +5,6 @@ mod arith_circuit_poly; mod arith_expr; mod composition_poly; -use std::collections::BTreeSet; - use proc_macro::TokenStream; use quote::{quote, ToTokens}; use syn::{parse_macro_input, parse_quote, spanned::Spanned, Data, DeriveInput, Fields}; @@ -300,157 +298,3 @@ fn field_names(fields: Fields, positional_prefix: Option<&str>) -> Vec vec![], } } - -/// Implements `pub fn iter_oracles(&self) -> impl Iterator`. -/// -/// Detects and includes fields with type `OracleId`, `[OracleId; N]` -/// -/// ``` -/// use binius_macros::IterOracles; -/// type OracleId = usize; -/// type BatchId = usize; -/// -/// #[derive(IterOracles)] -/// struct Oracle { -/// x: OracleId, -/// y: [OracleId; 5], -/// z: [OracleId; 5*2], -/// ignored_field1: usize, -/// ignored_field2: BatchId, -/// ignored_field3: [[OracleId; 5]; 2], -/// } -/// ``` -#[proc_macro_derive(IterOracles)] -pub fn iter_oracle_derive(input: TokenStream) -> TokenStream { - let input = parse_macro_input!(input as DeriveInput); - let Data::Struct(data) = &input.data else { - panic!("#[derive(IterOracles)] is only defined for structs with named fields"); - }; - let Fields::Named(fields) = &data.fields else { - panic!("#[derive(IterOracles)] is only defined for structs with named fields"); - }; - - let name = &input.ident; - let (impl_generics, ty_generics, where_clause) = &input.generics.split_for_impl(); - - let oracles = fields - .named - .iter() - .filter_map(|f| { - let name = f.ident.clone(); - match &f.ty { - syn::Type::Path(type_path) if type_path.path.is_ident("OracleId") => { - Some(quote!(std::iter::once(self.#name))) - } - syn::Type::Array(array) => { - if let syn::Type::Path(type_path) = *array.elem.clone() { - type_path - .path - .is_ident("OracleId") - .then(|| quote!(self.#name.into_iter())) - } else { - None - } - } - _ => None, - } - }) - .collect::>(); - - quote! { - impl #impl_generics #name #ty_generics #where_clause { - pub fn iter_oracles(&self) -> impl Iterator { - std::iter::empty() - #(.chain(#oracles))* - } - } - } - .into() -} - -/// Implements `pub fn iter_polys(&self) -> impl Iterator>`. -/// -/// Supports `Vec

`, `[Vec

; N]`. Currently doesn't filter out fields from the struct, so you can't add any other fields. -/// -/// ``` -/// use binius_macros::IterPolys; -/// use binius_field::PackedField; -/// -/// #[derive(IterPolys)] -/// struct Witness { -/// x: Vec

, -/// y: [Vec

; 5], -/// z: [Vec

; 5*2], -/// } -/// ``` -#[proc_macro_derive(IterPolys)] -pub fn iter_witness_derive(input: TokenStream) -> TokenStream { - let input = parse_macro_input!(input as DeriveInput); - let Data::Struct(data) = &input.data else { - panic!("#[derive(IterPolys)] is only defined for structs with named fields"); - }; - let Fields::Named(fields) = &data.fields else { - panic!("#[derive(IterPolys)] is only defined for structs with named fields"); - }; - - let name = &input.ident; - let witnesses = fields - .named - .iter() - .map(|f| { - let name = f.ident.clone(); - match &f.ty { - syn::Type::Array(_) => quote!(self.#name.iter()), - _ => quote!(std::iter::once(&self.#name)), - } - }) - .collect::>(); - - let packed_field_vars = generic_vars_with_trait(&input.generics, "PackedField"); - assert_eq!(packed_field_vars.len(), 1, "Only a single packed field is supported for now"); - let p = packed_field_vars.first(); - let (impl_generics, ty_generics, where_clause) = &input.generics.split_for_impl(); - quote! { - impl #impl_generics #name #ty_generics #where_clause { - pub fn iter_polys(&self) -> impl Iterator> { - std::iter::empty() - #(.chain(#witnesses))* - .map(|values| binius_math::MultilinearExtension::from_values_slice(values.as_slice()).unwrap()) - } - } - } - .into() -} - -/// This will accept the generics definition of a struct (relevant for derive macros), -/// and return all the generic vars that are constrained by a specific trait identifier. -/// ``` -/// use binius_field::{PackedField, Field}; -/// struct Example(A, B, C); -/// ``` -/// In the above example, when matching against the trait_name "PackedField", -/// the identifiers A and B will be returned, but not C -pub(crate) fn generic_vars_with_trait( - vars: &syn::Generics, - trait_name: &str, -) -> BTreeSet { - vars.params - .iter() - .filter_map(|param| match param { - syn::GenericParam::Type(type_param) => { - let is_bounded_by_trait_name = type_param.bounds.iter().any(|bound| match bound { - syn::TypeParamBound::Trait(trait_bound) => { - if let Some(last_segment) = trait_bound.path.segments.last() { - last_segment.ident == trait_name - } else { - false - } - } - _ => false, - }); - is_bounded_by_trait_name.then(|| type_param.ident.clone()) - } - syn::GenericParam::Const(_) | syn::GenericParam::Lifetime(_) => None, - }) - .collect() -} From 3b253acbfb29174c43f0e0437fff1d2b035ca953 Mon Sep 17 00:00:00 2001 From: Thomas Coratger <60488569+tcoratger@users.noreply.github.com> Date: Fri, 14 Feb 2025 13:28:26 +0100 Subject: [PATCH 27/50] [matrix]: simplify scale_row (#31) --- crates/math/src/matrix.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/crates/math/src/matrix.rs b/crates/math/src/matrix.rs index 674087d97..408b79863 100644 --- a/crates/math/src/matrix.rs +++ b/crates/math/src/matrix.rs @@ -186,8 +186,6 @@ impl Matrix { } fn scale_row(&mut self, i: usize, scalar: F) { - assert!(i < self.m); - for x in self.row_mut(i) { *x *= scalar; } From d2faef54742e5bc61f637917bed021deccb77508 Mon Sep 17 00:00:00 2001 From: Dmytro Gordon Date: Fri, 14 Feb 2025 16:36:06 +0200 Subject: [PATCH 28/50] [field] Remove unnecessary `WithUnderlier` trait bound (#32) --- crates/field/src/packed_extension.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/crates/field/src/packed_extension.rs b/crates/field/src/packed_extension.rs index 9e03ac3a2..8f2aaf028 100644 --- a/crates/field/src/packed_extension.rs +++ b/crates/field/src/packed_extension.rs @@ -71,9 +71,7 @@ where /// In order for the above relation to be guaranteed, the memory representation of /// `PackedExtensionField` element must be the same as a slice of the underlying `PackedField` /// element. -pub trait PackedExtension: - PackedField> + WithUnderlier> -{ +pub trait PackedExtension: PackedField> { type PackedSubfield: PackedField; fn cast_bases(packed: &[Self]) -> &[Self::PackedSubfield]; From 2b0d641755311d58f9022823bd9365152ee412e0 Mon Sep 17 00:00:00 2001 From: Dmytro Gordon Date: Fri, 14 Feb 2025 17:33:45 +0200 Subject: [PATCH 29/50] [field] Optimize SIMD element access for Zen4 architecture as well. (#28) --- crates/field/src/arch/x86_64/m128.rs | 86 ++++++++++++++-------------- crates/field/src/arch/x86_64/m256.rs | 81 +++++++++++++------------- crates/field/src/arch/x86_64/m512.rs | 78 ++++++++++++++----------- 3 files changed, 128 insertions(+), 117 deletions(-) diff --git a/crates/field/src/arch/x86_64/m128.rs b/crates/field/src/arch/x86_64/m128.rs index 8223f63fa..c2827ae76 100644 --- a/crates/field/src/arch/x86_64/m128.rs +++ b/crates/field/src/arch/x86_64/m128.rs @@ -13,7 +13,7 @@ use subtle::{Choice, ConditionallySelectable, ConstantTimeEq}; use crate::{ arch::{ - binary_utils::{as_array_mut, make_func_to_i8}, + binary_utils::{as_array_mut, as_array_ref, make_func_to_i8}, portable::{ packed::{impl_pack_scalar, PackedPrimitiveType}, packed_arithmetic::{ @@ -420,39 +420,41 @@ impl UnderlierWithBitOps for M128 { #[inline(always)] unsafe fn get_subvalue(&self, i: usize) -> T where - T: WithUnderlier, - T::Underlier: NumCast, + T: UnderlierType + NumCast, { - match T::Underlier::BITS { - 1 | 2 | 4 | 8 | 16 | 32 | 64 => { - let elements_in_64 = 64 / T::Underlier::BITS; - let chunk_64 = unsafe { - if i >= elements_in_64 { - _mm_extract_epi64(self.0, 1) - } else { - _mm_extract_epi64(self.0, 0) - } - }; - - let result_64 = if T::Underlier::BITS == 64 { - chunk_64 - } else { - let ones = ((1u128 << T::Underlier::BITS) - 1) as u64; - let val_64 = (chunk_64 as u64) - >> (T::Underlier::BITS - * (if i >= elements_in_64 { - i - elements_in_64 - } else { - i - })) & ones; - - val_64 as i64 - }; - T::from_underlier(T::Underlier::num_cast_from(Self(unsafe { - _mm_set_epi64x(0, result_64) - }))) + match T::BITS { + 1 | 2 | 4 => { + let elements_in_8 = 8 / T::BITS; + let mut value_u8 = as_array_ref::<_, u8, 16, _>(self, |arr| unsafe { + *arr.get_unchecked(i / elements_in_8) + }); + + let shift = (i % elements_in_8) * T::BITS; + value_u8 >>= shift; + + T::from_underlier(T::num_cast_from(Self::from(value_u8))) + } + 8 => { + let value_u8 = + as_array_ref::<_, u8, 16, _>(self, |arr| unsafe { *arr.get_unchecked(i) }); + T::from_underlier(T::num_cast_from(Self::from(value_u8))) + } + 16 => { + let value_u16 = + as_array_ref::<_, u16, 8, _>(self, |arr| unsafe { *arr.get_unchecked(i) }); + T::from_underlier(T::num_cast_from(Self::from(value_u16))) + } + 32 => { + let value_u32 = + as_array_ref::<_, u32, 4, _>(self, |arr| unsafe { *arr.get_unchecked(i) }); + T::from_underlier(T::num_cast_from(Self::from(value_u32))) + } + 64 => { + let value_u64 = + as_array_ref::<_, u64, 2, _>(self, |arr| unsafe { *arr.get_unchecked(i) }); + T::from_underlier(T::num_cast_from(Self::from(value_u64))) } - 128 => T::from_underlier(T::Underlier::num_cast_from(*self)), + 128 => T::from_underlier(T::num_cast_from(*self)), _ => panic!("unsupported bit count"), } } @@ -471,23 +473,23 @@ impl UnderlierWithBitOps for M128 { let val = u8::num_cast_from(Self::from(val)) << shift; let mask = mask << shift; - as_array_mut::<_, u8, 16>(self, |array| { - let element = &mut array[i / elements_in_8]; + as_array_mut::<_, u8, 16>(self, |array| unsafe { + let element = array.get_unchecked_mut(i / elements_in_8); *element &= !mask; *element |= val; }); } - 8 => as_array_mut::<_, u8, 16>(self, |array| { - array[i] = u8::num_cast_from(Self::from(val)); + 8 => as_array_mut::<_, u8, 16>(self, |array| unsafe { + *array.get_unchecked_mut(i) = u8::num_cast_from(Self::from(val)); }), - 16 => as_array_mut::<_, u16, 8>(self, |array| { - array[i] = u16::num_cast_from(Self::from(val)); + 16 => as_array_mut::<_, u16, 8>(self, |array| unsafe { + *array.get_unchecked_mut(i) = u16::num_cast_from(Self::from(val)); }), - 32 => as_array_mut::<_, u32, 4>(self, |array| { - array[i] = u32::num_cast_from(Self::from(val)); + 32 => as_array_mut::<_, u32, 4>(self, |array| unsafe { + *array.get_unchecked_mut(i) = u32::num_cast_from(Self::from(val)); }), - 64 => as_array_mut::<_, u64, 2>(self, |array| { - array[i] = u64::num_cast_from(Self::from(val)); + 64 => as_array_mut::<_, u64, 2>(self, |array| unsafe { + *array.get_unchecked_mut(i) = u64::num_cast_from(Self::from(val)); }), 128 => { *self = Self::from(val); diff --git a/crates/field/src/arch/x86_64/m256.rs b/crates/field/src/arch/x86_64/m256.rs index 4f3663d7c..c548cc2f3 100644 --- a/crates/field/src/arch/x86_64/m256.rs +++ b/crates/field/src/arch/x86_64/m256.rs @@ -463,42 +463,41 @@ impl UnderlierWithBitOps for M256 { T: UnderlierType + NumCast, { match T::BITS { - 1 | 2 | 4 | 8 | 16 | 32 => { - let elements_in_64 = 64 / T::BITS; - let chunk_64 = unsafe { - match i / elements_in_64 { - 0 => _mm256_extract_epi64(self.0, 0), - 1 => _mm256_extract_epi64(self.0, 1), - 2 => _mm256_extract_epi64(self.0, 2), - _ => _mm256_extract_epi64(self.0, 3), - } - }; + 1 | 2 | 4 => { + let elements_in_8 = 8 / T::BITS; + let mut value_u8 = as_array_ref::<_, u8, 32, _>(self, |arr| unsafe { + *arr.get_unchecked(i / elements_in_8) + }); - let result_64 = if T::BITS == 64 { - chunk_64 - } else { - let ones = ((1u128 << T::BITS) - 1) as u64; - let val_64 = (chunk_64 as u64) >> (T::BITS * (i % elements_in_64)) & ones; + let shift = (i % elements_in_8) * T::BITS; + value_u8 >>= shift; - val_64 as i64 - }; - T::num_cast_from(Self(unsafe { _mm256_set_epi64x(0, 0, 0, result_64) })) + T::from_underlier(T::num_cast_from(Self::from(value_u8))) + } + 8 => { + let value_u8 = + as_array_ref::<_, u8, 32, _>(self, |arr| unsafe { *arr.get_unchecked(i) }); + T::from_underlier(T::num_cast_from(Self::from(value_u8))) + } + 16 => { + let value_u16 = + as_array_ref::<_, u16, 16, _>(self, |arr| unsafe { *arr.get_unchecked(i) }); + T::from_underlier(T::num_cast_from(Self::from(value_u16))) + } + 32 => { + let value_u32 = + as_array_ref::<_, u32, 8, _>(self, |arr| unsafe { *arr.get_unchecked(i) }); + T::from_underlier(T::num_cast_from(Self::from(value_u32))) } - // NOTE: benchmark show that this strategy is optimal for getting 64-bit subvalues from 256-bit register. - // However using similar code for 1..32 bits is slower than the version above. - // Also even getting `chunk_64` in the code above using this code shows worser benchmarks results. 64 => { - T::num_cast_from(as_array_ref::<_, u64, 4, _>(self, |array| Self::from(array[i]))) + let value_u64 = + as_array_ref::<_, u64, 4, _>(self, |arr| unsafe { *arr.get_unchecked(i) }); + T::from_underlier(T::num_cast_from(Self::from(value_u64))) } 128 => { - let chunk_128 = unsafe { - if i == 0 { - _mm256_extracti128_si256(self.0, 0) - } else { - _mm256_extracti128_si256(self.0, 1) - } - }; - T::num_cast_from(Self(unsafe { _mm256_set_m128i(_mm_setzero_si128(), chunk_128) })) + let value_u128 = + as_array_ref::<_, u128, 2, _>(self, |arr| unsafe { *arr.get_unchecked(i) }); + T::from_underlier(T::num_cast_from(Self::from(value_u128))) } _ => panic!("unsupported bit count"), } @@ -518,26 +517,26 @@ impl UnderlierWithBitOps for M256 { let val = u8::num_cast_from(Self::from(val)) << shift; let mask = mask << shift; - as_array_mut::<_, u8, 32>(self, |array| { - let element = &mut array[i / elements_in_8]; + as_array_mut::<_, u8, 32>(self, |array| unsafe { + let element = array.get_unchecked_mut(i / elements_in_8); *element &= !mask; *element |= val; }); } - 8 => as_array_mut::<_, u8, 32>(self, |array| { - array[i] = u8::num_cast_from(Self::from(val)); + 8 => as_array_mut::<_, u8, 32>(self, |array| unsafe { + *array.get_unchecked_mut(i) = u8::num_cast_from(Self::from(val)); }), - 16 => as_array_mut::<_, u16, 16>(self, |array| { - array[i] = u16::num_cast_from(Self::from(val)); + 16 => as_array_mut::<_, u16, 16>(self, |array| unsafe { + *array.get_unchecked_mut(i) = u16::num_cast_from(Self::from(val)); }), - 32 => as_array_mut::<_, u32, 8>(self, |array| { - array[i] = u32::num_cast_from(Self::from(val)); + 32 => as_array_mut::<_, u32, 8>(self, |array| unsafe { + *array.get_unchecked_mut(i) = u32::num_cast_from(Self::from(val)); }), - 64 => as_array_mut::<_, u64, 4>(self, |array| { - array[i] = u64::num_cast_from(Self::from(val)); + 64 => as_array_mut::<_, u64, 4>(self, |array| unsafe { + *array.get_unchecked_mut(i) = u64::num_cast_from(Self::from(val)); }), 128 => as_array_mut::<_, u128, 2>(self, |array| { - array[i] = u128::num_cast_from(Self::from(val)); + *array.get_unchecked_mut(i) = u128::num_cast_from(Self::from(val)); }), _ => panic!("unsupported bit count"), } diff --git a/crates/field/src/arch/x86_64/m512.rs b/crates/field/src/arch/x86_64/m512.rs index 80c4bccbf..a7de540e5 100644 --- a/crates/field/src/arch/x86_64/m512.rs +++ b/crates/field/src/arch/x86_64/m512.rs @@ -12,7 +12,7 @@ use subtle::{Choice, ConditionallySelectable, ConstantTimeEq}; use crate::{ arch::{ - binary_utils::{as_array_mut, make_func_to_i8}, + binary_utils::{as_array_mut, as_array_ref, make_func_to_i8}, portable::{ packed::{impl_pack_scalar, PackedPrimitiveType}, packed_arithmetic::{ @@ -617,31 +617,41 @@ impl UnderlierWithBitOps for M512 { T: UnderlierType + NumCast, { match T::BITS { - 1 | 2 | 4 | 8 | 16 | 32 | 64 => { - let elements_in_64 = 64 / T::BITS; - let shuffle = unsafe { _mm512_set1_epi64((i / elements_in_64) as i64) }; - let chunk_64 = - u64::num_cast_from(Self(unsafe { _mm512_permutexvar_epi64(shuffle, self.0) })); - - let result_64 = if T::BITS == 64 { - chunk_64 - } else { - let ones = ((1u128 << T::BITS) - 1) as u64; - (chunk_64 >> (T::BITS * (i % elements_in_64))) & ones - }; + 1 | 2 | 4 => { + let elements_in_8 = 8 / T::BITS; + let mut value_u8 = as_array_ref::<_, u8, 64, _>(self, |arr| unsafe { + *arr.get_unchecked(i / elements_in_8) + }); - T::num_cast_from(Self::from(result_64)) + let shift = (i % elements_in_8) * T::BITS; + value_u8 >>= shift; + + T::from_underlier(T::num_cast_from(Self::from(value_u8))) + } + 8 => { + let value_u8 = + as_array_ref::<_, u8, 64, _>(self, |arr| unsafe { *arr.get_unchecked(i) }); + T::from_underlier(T::num_cast_from(Self::from(value_u8))) + } + 16 => { + let value_u16 = + as_array_ref::<_, u16, 32, _>(self, |arr| unsafe { *arr.get_unchecked(i) }); + T::from_underlier(T::num_cast_from(Self::from(value_u16))) + } + 32 => { + let value_u32 = + as_array_ref::<_, u32, 16, _>(self, |arr| unsafe { *arr.get_unchecked(i) }); + T::from_underlier(T::num_cast_from(Self::from(value_u32))) + } + 64 => { + let value_u64 = + as_array_ref::<_, u64, 8, _>(self, |arr| unsafe { *arr.get_unchecked(i) }); + T::from_underlier(T::num_cast_from(Self::from(value_u64))) } 128 => { - let chunk_128 = unsafe { - match i { - 0 => _mm512_extracti32x4_epi32(self.0, 0), - 1 => _mm512_extracti32x4_epi32(self.0, 1), - 2 => _mm512_extracti32x4_epi32(self.0, 2), - _ => _mm512_extracti32x4_epi32(self.0, 3), - } - }; - T::num_cast_from(Self(unsafe { _mm512_castsi128_si512(chunk_128) })) + let value_u128 = + as_array_ref::<_, u128, 4, _>(self, |arr| unsafe { *arr.get_unchecked(i) }); + T::from_underlier(T::num_cast_from(Self::from(value_u128))) } _ => panic!("unsupported bit count"), } @@ -661,26 +671,26 @@ impl UnderlierWithBitOps for M512 { let val = u8::num_cast_from(Self::from(val)) << shift; let mask = mask << shift; - as_array_mut::<_, u8, 64>(self, |array| { - let element = &mut array[i / elements_in_8]; + as_array_mut::<_, u8, 64>(self, |array| unsafe { + let element = array.get_unchecked_mut(i / elements_in_8); *element &= !mask; *element |= val; }); } - 8 => as_array_mut::<_, u8, 64>(self, |array| { - array[i] = u8::num_cast_from(Self::from(val)); + 8 => as_array_mut::<_, u8, 64>(self, |array| unsafe { + *array.get_unchecked_mut(i) = u8::num_cast_from(Self::from(val)); }), - 16 => as_array_mut::<_, u16, 32>(self, |array| { - array[i] = u16::num_cast_from(Self::from(val)); + 16 => as_array_mut::<_, u16, 32>(self, |array| unsafe { + *array.get_unchecked_mut(i) = u16::num_cast_from(Self::from(val)); }), - 32 => as_array_mut::<_, u32, 16>(self, |array| { - array[i] = u32::num_cast_from(Self::from(val)); + 32 => as_array_mut::<_, u32, 16>(self, |array| unsafe { + *array.get_unchecked_mut(i) = u32::num_cast_from(Self::from(val)); }), - 64 => as_array_mut::<_, u64, 8>(self, |array| { - array[i] = u64::num_cast_from(Self::from(val)); + 64 => as_array_mut::<_, u64, 8>(self, |array| unsafe { + *array.get_unchecked_mut(i) = u64::num_cast_from(Self::from(val)); }), 128 => as_array_mut::<_, u128, 4>(self, |array| { - array[i] = u128::num_cast_from(Self::from(val)); + *array.get_unchecked_mut(i) = u128::num_cast_from(Self::from(val)); }), _ => panic!("unsupported bit count"), } From 5a9ffd6b5508d4d4bb55ff0a2a5ee75a512bfa8e Mon Sep 17 00:00:00 2001 From: Nikita Lesnikov Date: Sat, 15 Feb 2025 00:05:38 +0300 Subject: [PATCH 30/50] refactor: Use binary_tower_level for base field detection (#30) --- crates/core/src/constraint_system/prove.rs | 30 +++------------------- 1 file changed, 3 insertions(+), 27 deletions(-) diff --git a/crates/core/src/constraint_system/prove.rs b/crates/core/src/constraint_system/prove.rs index 6d7d3f83f..a04539d5e 100644 --- a/crates/core/src/constraint_system/prove.rs +++ b/crates/core/src/constraint_system/prove.rs @@ -12,7 +12,7 @@ use binius_field::{ use binius_hal::ComputationBackend; use binius_hash::PseudoCompressionFunction; use binius_math::{ - ArithExpr, EvaluationDomainFactory, IsomorphicEvaluationDomainFactory, MLEDirectAdapter, + EvaluationDomainFactory, IsomorphicEvaluationDomainFactory, MLEDirectAdapter, MultilinearExtension, MultilinearPoly, }; use binius_maybe_rayon::prelude::*; @@ -54,7 +54,7 @@ use crate::{ }, }, ring_switch, - tower::{PackedTop, ProverTowerFamily, ProverTowerUnderlier, TowerFamily}, + tower::{PackedTop, ProverTowerFamily, ProverTowerUnderlier}, transcript::ProverTranscript, witness::{MultilinearExtensionIndex, MultilinearWitness}, }; @@ -297,7 +297,7 @@ where .map(|multilinear| 7 - multilinear.log_extension_degree()), constraints .iter() - .map(|constraint| arith_expr_base_tower_level::(&constraint.composition)) + .map(|constraint| constraint.composition.binary_tower_level()) ) .max() .unwrap_or(0); @@ -452,30 +452,6 @@ where }) } -fn arith_expr_base_tower_level(composition: &ArithExpr>) -> usize { - if composition.try_convert_field::().is_ok() { - return 0; - } - - if composition.try_convert_field::().is_ok() { - return 3; - } - - if composition.try_convert_field::().is_ok() { - return 4; - } - - if composition.try_convert_field::().is_ok() { - return 5; - } - - if composition.try_convert_field::().is_ok() { - return 6; - } - - 7 -} - type TypeErasedUnivariateZerocheck<'a, F> = Box + 'a>; type TypeErasedSumcheck<'a, F> = Box + 'a>; type TypeErasedProver<'a, F> = From b3f307612030e24d06ae1e64936d3011af4dacdd Mon Sep 17 00:00:00 2001 From: Tobias Bergkvist Date: Mon, 17 Feb 2025 13:20:28 +0100 Subject: [PATCH 31/50] [serialization] impl SerializeCanonical, DeserializeCanonical for ConstraintSystem (#11) --- Cargo.toml | 1 + crates/core/Cargo.toml | 1 + crates/core/src/constraint_system/channel.rs | 2 +- crates/core/src/constraint_system/mod.rs | 20 ++- crates/core/src/lib.rs | 2 + crates/core/src/oracle/constraint.rs | 13 +- crates/core/src/oracle/multilinear.rs | 117 +++++++++++++++--- crates/core/src/polynomial/multivariate.rs | 13 +- .../src/protocols/sumcheck/prove/zerocheck.rs | 6 +- crates/core/src/transparent/constant.rs | 16 ++- crates/core/src/transparent/mod.rs | 1 + .../src/transparent/multilinear_extension.rs | 73 ++++++++++- crates/core/src/transparent/powers.rs | 20 ++- crates/core/src/transparent/select_row.rs | 14 ++- crates/core/src/transparent/serialization.rs | 78 ++++++++++++ crates/core/src/transparent/step_down.rs | 14 ++- crates/core/src/transparent/step_up.rs | 14 ++- crates/core/src/transparent/tower_basis.rs | 14 ++- crates/field/src/packed.rs | 8 ++ crates/field/src/packed_extension_ops.rs | 14 +-- crates/field/src/serialization/canonical.rs | 58 +++++++++ crates/field/src/serialization/error.rs | 6 + crates/macros/src/lib.rs | 38 +++++- crates/math/Cargo.toml | 1 + crates/math/src/arith_expr.rs | 27 +++- 25 files changed, 515 insertions(+), 56 deletions(-) create mode 100644 crates/core/src/transparent/serialization.rs diff --git a/Cargo.toml b/Cargo.toml index b6916ac36..2ee58dfc2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -100,6 +100,7 @@ generic-array = "0.14.7" getset = "0.1.2" groestl_crypto = { package = "groestl", version = "0.10.1" } hex-literal = "0.4.1" +inventory = "0.3.19" itertools = "0.13.0" lazy_static = "1.5.0" paste = "1.0.15" diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index e0033c614..7f74200ae 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -24,6 +24,7 @@ derive_more.workspace = true digest.workspace = true either.workspace = true getset.workspace = true +inventory.workspace = true itertools.workspace = true rand.workspace = true stackalloc.workspace = true diff --git a/crates/core/src/constraint_system/channel.rs b/crates/core/src/constraint_system/channel.rs index e73d83067..1242e781b 100644 --- a/crates/core/src/constraint_system/channel.rs +++ b/crates/core/src/constraint_system/channel.rs @@ -59,7 +59,7 @@ use crate::{oracle::OracleId, witness::MultilinearExtensionIndex}; pub type ChannelId = usize; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, SerializeCanonical, DeserializeCanonical)] pub struct Flush { pub oracles: Vec, pub channel_id: ChannelId, diff --git a/crates/core/src/constraint_system/mod.rs b/crates/core/src/constraint_system/mod.rs index de178b31f..2c040c41e 100644 --- a/crates/core/src/constraint_system/mod.rs +++ b/crates/core/src/constraint_system/mod.rs @@ -7,7 +7,8 @@ mod prove; pub mod validate; mod verify; -use binius_field::TowerField; +use binius_field::{serialization, BinaryField128b, DeserializeCanonical, TowerField}; +use binius_macros::SerializeCanonical; use channel::{ChannelId, Flush}; pub use prove::prove; pub use verify::verify; @@ -21,7 +22,7 @@ use crate::oracle::{ConstraintSet, MultilinearOracleSet, OracleId}; /// /// As a result, a ConstraintSystem allows us to validate all of these /// constraints against a witness, as well as enabling generic prove/verify -#[derive(Debug, Clone)] +#[derive(Debug, Clone, SerializeCanonical)] pub struct ConstraintSystem { pub oracles: MultilinearOracleSet, pub table_constraints: Vec>, @@ -30,6 +31,21 @@ pub struct ConstraintSystem { pub max_channel_id: ChannelId, } +impl DeserializeCanonical for ConstraintSystem { + fn deserialize_canonical(mut read_buf: impl bytes::Buf) -> Result + where + Self: Sized, + { + Ok(Self { + oracles: DeserializeCanonical::deserialize_canonical(&mut read_buf)?, + table_constraints: DeserializeCanonical::deserialize_canonical(&mut read_buf)?, + non_zero_oracle_ids: DeserializeCanonical::deserialize_canonical(&mut read_buf)?, + flushes: DeserializeCanonical::deserialize_canonical(&mut read_buf)?, + max_channel_id: DeserializeCanonical::deserialize_canonical(&mut read_buf)?, + }) + } +} + impl ConstraintSystem { pub const fn no_base_constraints(self) -> Self { self diff --git a/crates/core/src/lib.rs b/crates/core/src/lib.rs index 4ff8e8ff0..48f1f060d 100644 --- a/crates/core/src/lib.rs +++ b/crates/core/src/lib.rs @@ -28,3 +28,5 @@ pub mod tower; pub mod transcript; pub mod transparent; pub mod witness; + +pub use inventory; diff --git a/crates/core/src/oracle/constraint.rs b/crates/core/src/oracle/constraint.rs index 4a9d0225a..bd489381b 100644 --- a/crates/core/src/oracle/constraint.rs +++ b/crates/core/src/oracle/constraint.rs @@ -4,6 +4,7 @@ use core::iter::IntoIterator; use std::sync::Arc; use binius_field::{Field, TowerField}; +use binius_macros::{DeserializeCanonical, SerializeCanonical}; use binius_math::{ArithExpr, CompositionPolyOS}; use binius_utils::bail; use itertools::Itertools; @@ -15,23 +16,23 @@ use super::{Error, MultilinearOracleSet, MultilinearPolyVariant, OracleId}; pub type TypeErasedComposition

= Arc>; /// Constraint is a type erased composition along with a predicate on its values on the boolean hypercube -#[derive(Debug, Clone)] +#[derive(Debug, Clone, SerializeCanonical, DeserializeCanonical)] pub struct Constraint { - pub name: Arc, + pub name: String, pub composition: ArithExpr, pub predicate: ConstraintPredicate, } /// Predicate can either be a sum of values of a composition on the hypercube (sumcheck) or equality to zero /// on the hypercube (zerocheck) -#[derive(Clone, Debug)] +#[derive(Clone, Debug, SerializeCanonical, DeserializeCanonical)] pub enum ConstraintPredicate { Sum(F), Zero, } /// Constraint set is a group of constraints that operate over the same set of oracle-identified multilinears -#[derive(Debug, Clone)] +#[derive(Debug, Clone, SerializeCanonical, DeserializeCanonical)] pub struct ConstraintSet { pub n_vars: usize, pub oracle_ids: Vec, @@ -41,7 +42,7 @@ pub struct ConstraintSet { // A deferred constraint constructor that instantiates index composition after the superset of oracles is known #[allow(clippy::type_complexity)] struct UngroupedConstraint { - name: Arc, + name: String, oracle_ids: Vec, composition: ArithExpr, predicate: ConstraintPredicate, @@ -82,7 +83,7 @@ impl ConstraintSetBuilder { composition: ArithExpr, ) { self.constraints.push(UngroupedConstraint { - name: name.to_string().into(), + name: name.to_string(), oracle_ids: oracle_ids.into_iter().collect(), composition, predicate: ConstraintPredicate::Zero, diff --git a/crates/core/src/oracle/multilinear.rs b/crates/core/src/oracle/multilinear.rs index bfba63489..be06ca5da 100644 --- a/crates/core/src/oracle/multilinear.rs +++ b/crates/core/src/oracle/multilinear.rs @@ -2,7 +2,10 @@ use std::{array, fmt::Debug, sync::Arc}; -use binius_field::{Field, TowerField}; +use binius_field::{ + serialization, BinaryField128b, DeserializeCanonical, Field, SerializeCanonical, TowerField, +}; +use binius_macros::{DeserializeCanonical, SerializeCanonical}; use binius_utils::bail; use getset::{CopyGetters, Getters}; @@ -280,9 +283,20 @@ impl MultilinearOracleSetAddition<'_, F> { /// /// The oracle set also tracks the committed polynomial in batches where each batch is committed /// together with a polynomial commitment scheme. -#[derive(Default, Debug, Clone)] +#[derive(Default, Debug, Clone, SerializeCanonical)] pub struct MultilinearOracleSet { - oracles: Vec>>, + oracles: Vec>, +} + +impl binius_field::DeserializeCanonical for MultilinearOracleSet { + fn deserialize_canonical(mut read_buf: impl bytes::Buf) -> Result + where + Self: Sized, + { + Ok(Self { + oracles: DeserializeCanonical::deserialize_canonical(&mut read_buf)?, + }) + } } impl MultilinearOracleSet { @@ -323,12 +337,11 @@ impl MultilinearOracleSet { oracle: impl FnOnce(OracleId) -> MultilinearPolyOracle, ) -> OracleId { let id = self.oracles.len(); - - self.oracles.push(Arc::new(oracle(id))); + self.oracles.push(oracle(id)); id } - fn get_from_set(&self, id: OracleId) -> Arc> { + fn get_from_set(&self, id: OracleId) -> MultilinearPolyOracle { self.oracles[id].clone() } @@ -401,7 +414,7 @@ impl MultilinearOracleSet { } pub fn oracle(&self, id: OracleId) -> MultilinearPolyOracle { - (*self.oracles[id]).clone() + self.oracles[id].clone() } pub fn n_vars(&self, id: OracleId) -> usize { @@ -438,7 +451,7 @@ impl MultilinearOracleSet { /// other oracles. This is formalized in [DP23] Section 4. /// /// [DP23]: -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, SerializeCanonical)] pub struct MultilinearPolyOracle { pub id: OracleId, pub name: Option, @@ -447,7 +460,22 @@ pub struct MultilinearPolyOracle { pub variant: MultilinearPolyVariant, } -#[derive(Debug, Clone, PartialEq, Eq)] +impl DeserializeCanonical for MultilinearPolyOracle { + fn deserialize_canonical(mut read_buf: impl bytes::Buf) -> Result + where + Self: Sized, + { + Ok(Self { + id: DeserializeCanonical::deserialize_canonical(&mut read_buf)?, + name: DeserializeCanonical::deserialize_canonical(&mut read_buf)?, + n_vars: DeserializeCanonical::deserialize_canonical(&mut read_buf)?, + tower_level: DeserializeCanonical::deserialize_canonical(&mut read_buf)?, + variant: DeserializeCanonical::deserialize_canonical(&mut read_buf)?, + }) + } +} + +#[derive(Debug, Clone, PartialEq, Eq, SerializeCanonical)] pub enum MultilinearPolyVariant { Committed, Transparent(TransparentPolyOracle), @@ -459,6 +487,33 @@ pub enum MultilinearPolyVariant { ZeroPadded(OracleId), } +impl DeserializeCanonical for MultilinearPolyVariant { + fn deserialize_canonical(mut buf: impl bytes::Buf) -> Result + where + Self: Sized, + { + Ok(match u8::deserialize_canonical(&mut buf)? { + 0 => Self::Committed, + 1 => Self::Transparent(DeserializeCanonical::deserialize_canonical(&mut buf)?), + 2 => Self::Repeating { + id: DeserializeCanonical::deserialize_canonical(&mut buf)?, + log_count: DeserializeCanonical::deserialize_canonical(&mut buf)?, + }, + 3 => Self::Projected(DeserializeCanonical::deserialize_canonical(&mut buf)?), + 4 => Self::Shifted(DeserializeCanonical::deserialize_canonical(&mut buf)?), + 5 => Self::Packed(DeserializeCanonical::deserialize_canonical(&mut buf)?), + 6 => Self::LinearCombination(DeserializeCanonical::deserialize_canonical(&mut buf)?), + 7 => Self::ZeroPadded(DeserializeCanonical::deserialize_canonical(&mut buf)?), + variant_index => { + return Err(serialization::Error::UnknownEnumVariant { + name: "MultilinearPolyVariant", + index: variant_index, + }); + } + }) + } +} + /// A transparent multilinear polynomial oracle. /// /// See the [`MultilinearPolyOracle`] documentation for context. @@ -468,6 +523,28 @@ pub struct TransparentPolyOracle { poly: Arc>, } +impl SerializeCanonical for TransparentPolyOracle { + fn serialize_canonical( + &self, + mut write_buf: impl bytes::BufMut, + ) -> Result<(), binius_field::serialization::Error> { + self.poly.erased_serialize_canonical(&mut write_buf) + } +} + +impl DeserializeCanonical for TransparentPolyOracle { + fn deserialize_canonical( + mut read_buf: impl bytes::Buf, + ) -> Result + where + Self: Sized, + { + let poly: Box> = + DeserializeCanonical::deserialize_canonical(&mut read_buf)?; + Ok(Self { poly: poly.into() }) + } +} + impl TransparentPolyOracle { fn new(poly: Arc>) -> Result { if poly.binary_tower_level() > F::TOWER_LEVEL { @@ -494,13 +571,15 @@ impl PartialEq for TransparentPolyOracle { impl Eq for TransparentPolyOracle {} -#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, SerializeCanonical, DeserializeCanonical)] pub enum ProjectionVariant { FirstVars, LastVars, } -#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters)] +#[derive( + Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeCanonical, DeserializeCanonical, +)] pub struct Projected { #[get_copy = "pub"] id: OracleId, @@ -530,14 +609,16 @@ impl Projected { } } -#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, SerializeCanonical, DeserializeCanonical)] pub enum ShiftVariant { CircularLeft, LogicalLeft, LogicalRight, } -#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters)] +#[derive( + Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeCanonical, DeserializeCanonical, +)] pub struct Shifted { #[get_copy = "pub"] id: OracleId, @@ -579,7 +660,9 @@ impl Shifted { } } -#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters)] +#[derive( + Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeCanonical, DeserializeCanonical, +)] pub struct Packed { #[get_copy = "pub"] id: OracleId, @@ -593,7 +676,9 @@ pub struct Packed { log_degree: usize, } -#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters)] +#[derive( + Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeCanonical, DeserializeCanonical, +)] pub struct LinearCombination { #[get_copy = "pub"] n_vars: usize, @@ -606,7 +691,7 @@ impl LinearCombination { fn new( n_vars: usize, offset: F, - inner: impl IntoIterator>, F)>, + inner: impl IntoIterator, F)>, ) -> Result { let inner = inner .into_iter() diff --git a/crates/core/src/polynomial/multivariate.rs b/crates/core/src/polynomial/multivariate.rs index a8e6fdcf8..e154e8e25 100644 --- a/crates/core/src/polynomial/multivariate.rs +++ b/crates/core/src/polynomial/multivariate.rs @@ -2,11 +2,12 @@ use std::{borrow::Borrow, fmt::Debug, iter::repeat_with, marker::PhantomData, sync::Arc}; -use binius_field::{Field, PackedField}; +use binius_field::{serialization, Field, PackedField}; use binius_math::{ ArithExpr, CompositionPolyOS, MLEDirectAdapter, MultilinearPoly, MultilinearQueryRef, }; use binius_utils::bail; +use bytes::BufMut; use itertools::Itertools; use rand::{rngs::StdRng, SeedableRng}; @@ -28,6 +29,16 @@ pub trait MultivariatePoly

: Debug + Send + Sync { /// Returns the maximum binary tower level of all constants in the arithmetic expression. fn binary_tower_level(&self) -> usize; + + /// Serialize a type erased MultivariatePoly. + /// Since not every MultivariatePoly implements serialization, this defaults to returning an error. + fn erased_serialize_canonical( + &self, + write_buf: &mut dyn BufMut, + ) -> Result<(), serialization::Error> { + let _ = write_buf; + Err(serialization::Error::SerializationNotImplemented) + } } /// Identity composition function $g(X) = X$. diff --git a/crates/core/src/protocols/sumcheck/prove/zerocheck.rs b/crates/core/src/protocols/sumcheck/prove/zerocheck.rs index 866d7f77d..0e11eb8d0 100644 --- a/crates/core/src/protocols/sumcheck/prove/zerocheck.rs +++ b/crates/core/src/protocols/sumcheck/prove/zerocheck.rs @@ -41,7 +41,7 @@ use crate::{ pub fn validate_witness<'a, F, P, M, Composition>( multilinears: &[M], - zero_claims: impl IntoIterator, Composition)>, + zero_claims: impl IntoIterator, ) -> Result<(), Error> where F: Field, @@ -99,7 +99,7 @@ where #[getset(get = "pub")] multilinears: Vec, switchover_rounds: Vec, - compositions: Vec<(Arc, CompositionBase, Composition)>, + compositions: Vec<(String, CompositionBase, Composition)>, zerocheck_challenges: Vec, domains: Vec>, backend: &'a Backend, @@ -125,7 +125,7 @@ where { pub fn new( multilinears: Vec, - zero_claims: impl IntoIterator, CompositionBase, Composition)>, + zero_claims: impl IntoIterator, zerocheck_challenges: &[F], evaluation_domain_factory: impl EvaluationDomainFactory, switchover_fn: impl Fn(usize) -> usize, diff --git a/crates/core/src/transparent/constant.rs b/crates/core/src/transparent/constant.rs index 1c0108739..ea7bd735a 100644 --- a/crates/core/src/transparent/constant.rs +++ b/crates/core/src/transparent/constant.rs @@ -1,18 +1,29 @@ // Copyright 2024-2025 Irreducible Inc. use binius_field::{ExtensionField, TowerField}; +use binius_macros::{erased_serialize_canonical, DeserializeCanonical, SerializeCanonical}; use binius_utils::bail; use crate::polynomial::{Error, MultivariatePoly}; /// A constant polynomial. -#[derive(Debug, Copy, Clone)] -pub struct Constant { +#[derive(Debug, Copy, Clone, SerializeCanonical, DeserializeCanonical)] +pub struct Constant { n_vars: usize, value: F, tower_level: usize, } +inventory::submit! { + >::register_deserializer( + "Constant", + |buf: &mut dyn bytes::Buf| { + let deserialized = as binius_field::DeserializeCanonical>::deserialize_canonical(&mut *buf)?; + Ok(Box::new(deserialized)) + } + ) +} + impl Constant { pub fn new(n_vars: usize, value: FS) -> Self where @@ -26,6 +37,7 @@ impl Constant { } } +#[erased_serialize_canonical] impl MultivariatePoly for Constant { fn n_vars(&self) -> usize { self.n_vars diff --git a/crates/core/src/transparent/mod.rs b/crates/core/src/transparent/mod.rs index e95864fe8..a1e79368c 100644 --- a/crates/core/src/transparent/mod.rs +++ b/crates/core/src/transparent/mod.rs @@ -6,6 +6,7 @@ pub mod eq_ind; pub mod multilinear_extension; pub mod powers; pub mod select_row; +pub mod serialization; pub mod shift_ind; pub mod step_down; pub mod step_up; diff --git a/crates/core/src/transparent/multilinear_extension.rs b/crates/core/src/transparent/multilinear_extension.rs index 7df54751a..9c59486ee 100644 --- a/crates/core/src/transparent/multilinear_extension.rs +++ b/crates/core/src/transparent/multilinear_extension.rs @@ -2,8 +2,13 @@ use std::{fmt::Debug, ops::Deref}; -use binius_field::{ExtensionField, PackedField, RepackedExtension, TowerField}; +use binius_field::{ + arch::OptimalUnderlier, as_packed_field::PackedType, packed::pack_slice, BinaryField128b, + DeserializeCanonical, ExtensionField, PackedField, RepackedExtension, SerializeCanonical, + TowerField, +}; use binius_hal::{make_portable_backend, ComputationBackendExt}; +use binius_macros::erased_serialize_canonical; use binius_math::{MLEEmbeddingAdapter, MultilinearExtension, MultilinearPoly}; use crate::polynomial::{Error, MultivariatePoly}; @@ -26,6 +31,71 @@ where data: MLEEmbeddingAdapter, } +impl SerializeCanonical for MultilinearExtensionTransparent +where + P: PackedField, + PE: RepackedExtension

, + PE::Scalar: TowerField + ExtensionField, + Data: Deref + Debug + Send + Sync, +{ + fn serialize_canonical( + &self, + write_buf: impl bytes::BufMut, + ) -> Result<(), binius_field::serialization::Error> { + let elems = PE::iter_slice( + self.data + .packed_evals() + .expect("Evals should always be available here"), + ) + .collect::>(); + SerializeCanonical::serialize_canonical(&elems, write_buf) + } +} + +inventory::submit! { + >::register_deserializer( + "MultilinearExtensionTransparent", + |buf: &mut dyn bytes::Buf| { + type U = OptimalUnderlier; + type F = BinaryField128b; + type P = PackedType; + let hypercube_evals: Vec = DeserializeCanonical::deserialize_canonical(&mut *buf)?; + let result: Box> = if let Some(packed_evals) = try_pack_slice(&hypercube_evals) { + Box::new(MultilinearExtensionTransparent::, P, _>::from_values(packed_evals).unwrap()) + } else if let Some(packed_evals) = try_pack_slice(&hypercube_evals) { + Box::new(MultilinearExtensionTransparent::, P, _>::from_values(packed_evals).unwrap()) + } else if let Some(packed_evals) = try_pack_slice(&hypercube_evals) { + Box::new(MultilinearExtensionTransparent::, P, _>::from_values(packed_evals).unwrap()) + } else if let Some(packed_evals) = try_pack_slice(&hypercube_evals) { + Box::new(MultilinearExtensionTransparent::, P, _>::from_values(packed_evals).unwrap()) + } else if let Some(packed_evals) = try_pack_slice(&hypercube_evals) { + Box::new(MultilinearExtensionTransparent::, P, _>::from_values(packed_evals).unwrap()) + } else if let Some(packed_evals) = try_pack_slice(&hypercube_evals) { + Box::new(MultilinearExtensionTransparent::, P, _>::from_values(packed_evals).unwrap()) + } else if let Some(packed_evals) = try_pack_slice(&hypercube_evals) { + Box::new(MultilinearExtensionTransparent::, P, _>::from_values(packed_evals).unwrap()) + } else { + Box::new(MultilinearExtensionTransparent::::from_values(pack_slice(&hypercube_evals)).unwrap()) + }; + Ok(result) + } + ) +} + +fn try_pack_slice(xs: &[F]) -> Option> +where + PS: PackedField, + F: ExtensionField, +{ + Some(pack_slice( + &xs.iter() + .copied() + .map(TryInto::try_into) + .collect::, _>>() + .ok()?, + )) +} + impl MultilinearExtensionTransparent where P: PackedField, @@ -49,6 +119,7 @@ where } } +#[erased_serialize_canonical] impl MultivariatePoly for MultilinearExtensionTransparent where F: TowerField + ExtensionField, diff --git a/crates/core/src/transparent/powers.rs b/crates/core/src/transparent/powers.rs index c0c912d9f..960608706 100644 --- a/crates/core/src/transparent/powers.rs +++ b/crates/core/src/transparent/powers.rs @@ -2,7 +2,8 @@ use std::iter::successors; -use binius_field::{Field, PackedField, TowerField}; +use binius_field::{PackedField, TowerField}; +use binius_macros::{erased_serialize_canonical, DeserializeCanonical, SerializeCanonical}; use binius_math::MultilinearExtension; use binius_maybe_rayon::prelude::*; use binius_utils::bail; @@ -13,13 +14,23 @@ use crate::polynomial::{Error, MultivariatePoly}; /// A transparent multilinear polynomial whose evaluation at index $i$ is $g^i$ for /// some field element $g$. -#[derive(Debug)] -pub struct Powers { +#[derive(Debug, SerializeCanonical, DeserializeCanonical)] +pub struct Powers { n_vars: usize, base: F, } -impl Powers { +inventory::submit! { + >::register_deserializer( + "Powers", + |buf: &mut dyn bytes::Buf| { + let deserialized = as binius_field::DeserializeCanonical>::deserialize_canonical(&mut *buf)?; + Ok(Box::new(deserialized)) + } + ) +} + +impl Powers { pub const fn new(n_vars: usize, base: F) -> Self { Self { n_vars, base } } @@ -49,6 +60,7 @@ impl Powers { } } +#[erased_serialize_canonical] impl> MultivariatePoly

for Powers { fn n_vars(&self) -> usize { self.n_vars diff --git a/crates/core/src/transparent/select_row.rs b/crates/core/src/transparent/select_row.rs index fdcd32c5b..ed86f7d7f 100644 --- a/crates/core/src/transparent/select_row.rs +++ b/crates/core/src/transparent/select_row.rs @@ -1,6 +1,7 @@ // Copyright 2024-2025 Irreducible Inc. use binius_field::{packed::set_packed_slice, BinaryField1b, Field, PackedField}; +use binius_macros::{erased_serialize_canonical, DeserializeCanonical, SerializeCanonical}; use binius_math::MultilinearExtension; use binius_utils::bail; @@ -18,12 +19,22 @@ use crate::polynomial::{Error, MultivariatePoly}; /// ``` /// /// This is useful for defining boundary constraints -#[derive(Debug, Clone)] +#[derive(Debug, Clone, SerializeCanonical, DeserializeCanonical)] pub struct SelectRow { n_vars: usize, index: usize, } +inventory::submit! { + >::register_deserializer( + "SelectRow", + |buf: &mut dyn bytes::Buf| { + let deserialized = ::deserialize_canonical(&mut *buf)?; + Ok(Box::new(deserialized)) + } + ) +} + impl SelectRow { pub fn new(n_vars: usize, index: usize) -> Result { if index >= (1 << n_vars) { @@ -50,6 +61,7 @@ impl SelectRow { } } +#[erased_serialize_canonical] impl MultivariatePoly for SelectRow { fn degree(&self) -> usize { self.n_vars diff --git a/crates/core/src/transparent/serialization.rs b/crates/core/src/transparent/serialization.rs new file mode 100644 index 000000000..bc785713a --- /dev/null +++ b/crates/core/src/transparent/serialization.rs @@ -0,0 +1,78 @@ +// Copyright 2025 Irreducible Inc. + +//! The purpose of this module is to enable serialization/deserialization of generic MultivariatePoly implementations +//! +//! The simplest way to do this would be to create an enum with all the possible structs that implement MultivariatePoly +//! +//! This has a few problems, though: +//! - Third party code is not able to define custom transparent polynomials +//! - The enum would inherit, or be forced to enumerate possible type parameters of every struct variant + +use std::{collections::HashMap, sync::LazyLock}; + +use binius_field::{ + serialization::Error, BinaryField128b, DeserializeCanonical, SerializeCanonical, TowerField, +}; + +use crate::polynomial::MultivariatePoly; + +impl SerializeCanonical for Box> { + fn serialize_canonical( + &self, + mut write_buf: impl bytes::BufMut, + ) -> Result<(), binius_field::serialization::Error> { + self.erased_serialize_canonical(&mut write_buf) + } +} + +impl DeserializeCanonical for Box> { + fn deserialize_canonical(mut read_buf: impl bytes::Buf) -> Result + where + Self: Sized, + { + let name = String::deserialize_canonical(&mut read_buf)?; + match REGISTRY.get(name.as_str()) { + Some(Some(erased_deserialize_canonical)) => erased_deserialize_canonical(&mut read_buf), + Some(None) => Err(Error::DeserializerNameConflict { name }), + None => Err(Error::DeserializerNotImplented), + } + } +} + +// Using the inventory crate we can collect all deserializers before the main function runs +// This allows third party code to submit their own deserializers as well +inventory::collect!(DeserializerEntry); + +static REGISTRY: LazyLock< + HashMap<&'static str, Option>>, +> = LazyLock::new(|| { + let mut registry = HashMap::new(); + inventory::iter::> + .into_iter() + .for_each(|&DeserializerEntry { name, deserializer }| match registry.entry(name) { + std::collections::hash_map::Entry::Vacant(entry) => { + entry.insert(Some(deserializer)); + } + std::collections::hash_map::Entry::Occupied(mut entry) => { + entry.insert(None); + } + }); + registry +}); + +impl dyn MultivariatePoly { + pub const fn register_deserializer( + name: &'static str, + deserializer: ErasedDeserializeCanonical, + ) -> DeserializerEntry { + DeserializerEntry { name, deserializer } + } +} + +pub struct DeserializerEntry { + name: &'static str, + deserializer: ErasedDeserializeCanonical, +} + +type ErasedDeserializeCanonical = + fn(&mut dyn bytes::Buf) -> Result>, Error>; diff --git a/crates/core/src/transparent/step_down.rs b/crates/core/src/transparent/step_down.rs index 8c588f738..7e0d6e8bb 100644 --- a/crates/core/src/transparent/step_down.rs +++ b/crates/core/src/transparent/step_down.rs @@ -1,6 +1,7 @@ // Copyright 2024-2025 Irreducible Inc. use binius_field::{Field, PackedField}; +use binius_macros::{erased_serialize_canonical, DeserializeCanonical, SerializeCanonical}; use binius_math::MultilinearExtension; use binius_utils::bail; @@ -20,12 +21,22 @@ use crate::polynomial::{Error, MultivariatePoly}; /// ``` /// /// This is useful for making constraints that are not enforced at the last rows of the trace -#[derive(Debug, Clone)] +#[derive(Debug, Clone, SerializeCanonical, DeserializeCanonical)] pub struct StepDown { n_vars: usize, index: usize, } +inventory::submit! { + >::register_deserializer( + "StepDown", + |buf: &mut dyn bytes::Buf| { + let deserialized = ::deserialize_canonical(&mut *buf)?; + Ok(Box::new(deserialized)) + } + ) +} + impl StepDown { pub fn new(n_vars: usize, index: usize) -> Result { if index > 1 << n_vars { @@ -68,6 +79,7 @@ impl StepDown { } } +#[erased_serialize_canonical] impl MultivariatePoly for StepDown { fn degree(&self) -> usize { self.n_vars diff --git a/crates/core/src/transparent/step_up.rs b/crates/core/src/transparent/step_up.rs index 3a24b9f53..e764d0428 100644 --- a/crates/core/src/transparent/step_up.rs +++ b/crates/core/src/transparent/step_up.rs @@ -1,6 +1,7 @@ // Copyright 2024-2025 Irreducible Inc. use binius_field::{Field, PackedField}; +use binius_macros::{erased_serialize_canonical, DeserializeCanonical, SerializeCanonical}; use binius_math::MultilinearExtension; use binius_utils::bail; @@ -20,12 +21,22 @@ use crate::polynomial::{Error, MultivariatePoly}; /// ``` /// /// This is useful for making constraints that are not enforced at the first rows of the trace -#[derive(Debug, Clone)] +#[derive(Debug, Clone, SerializeCanonical, DeserializeCanonical)] pub struct StepUp { n_vars: usize, index: usize, } +inventory::submit! { + >::register_deserializer( + "StepUp", + |buf: &mut dyn bytes::Buf| { + let deserialized = ::deserialize_canonical(&mut *buf)?; + Ok(Box::new(deserialized)) + } + ) +} + impl StepUp { pub fn new(n_vars: usize, index: usize) -> Result { if index > 1 << n_vars { @@ -64,6 +75,7 @@ impl StepUp { } } +#[erased_serialize_canonical] impl MultivariatePoly for StepUp { fn degree(&self) -> usize { self.n_vars diff --git a/crates/core/src/transparent/tower_basis.rs b/crates/core/src/transparent/tower_basis.rs index 20992c467..471ef4e6c 100644 --- a/crates/core/src/transparent/tower_basis.rs +++ b/crates/core/src/transparent/tower_basis.rs @@ -3,6 +3,7 @@ use std::marker::PhantomData; use binius_field::{Field, PackedField, TowerField}; +use binius_macros::{erased_serialize_canonical, DeserializeCanonical, SerializeCanonical}; use binius_math::MultilinearExtension; use binius_utils::bail; @@ -20,13 +21,23 @@ use crate::polynomial::{Error, MultivariatePoly}; /// /// Thus, $\mathcal{T}_{\iota+k}$ has a $\mathcal{T}_{\iota}$-basis of size $2^k$: /// * $1, X_{\iota}, X_{\iota+1}, X_{\iota}X_{\iota+1}, X_{\iota+2}, \ldots, X_{\iota} X_{\iota+1} \ldots X_{\iota+k-1}$ -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Copy, Clone, SerializeCanonical, DeserializeCanonical)] pub struct TowerBasis { k: usize, iota: usize, _marker: PhantomData, } +inventory::submit! { + >::register_deserializer( + "TowerBasis", + |buf: &mut dyn bytes::Buf| { + let deserialized = as binius_field::DeserializeCanonical>::deserialize_canonical(&mut *buf)?; + Ok(Box::new(deserialized)) + } + ) +} + impl TowerBasis { pub fn new(k: usize, iota: usize) -> Result { if iota + k > F::TOWER_LEVEL { @@ -62,6 +73,7 @@ impl TowerBasis { } } +#[erased_serialize_canonical] impl MultivariatePoly for TowerBasis where F: TowerField, diff --git a/crates/field/src/packed.rs b/crates/field/src/packed.rs index b47a58212..9139b67f0 100644 --- a/crates/field/src/packed.rs +++ b/crates/field/src/packed.rs @@ -385,6 +385,14 @@ pub fn mul_by_subfield_scalar, FS: Field>(val: P, multipl } } +pub fn pack_slice(scalars: &[P::Scalar]) -> Vec

{ + let mut packed_slice = vec![P::default(); scalars.len() / P::WIDTH]; + for (i, scalar) in scalars.iter().enumerate() { + set_packed_slice(&mut packed_slice, i, *scalar); + } + packed_slice +} + impl Broadcast for F { fn broadcast(scalar: F) -> Self { scalar diff --git a/crates/field/src/packed_extension_ops.rs b/crates/field/src/packed_extension_ops.rs index 53dc4e0b6..30fad681f 100644 --- a/crates/field/src/packed_extension_ops.rs +++ b/crates/field/src/packed_extension_ops.rs @@ -106,10 +106,10 @@ mod tests { use crate::{ ext_base_mul, ext_base_mul_par, - packed::{get_packed_slice, set_packed_slice}, + packed::{get_packed_slice, pack_slice}, underlier::WithUnderlier, BinaryField128b, BinaryField16b, BinaryField8b, PackedBinaryField16x16b, - PackedBinaryField2x128b, PackedBinaryField32x8b, PackedField, + PackedBinaryField2x128b, PackedBinaryField32x8b, }; fn strategy_8b_scalars() -> impl Strategy { @@ -127,16 +127,6 @@ mod tests { .prop_map(|arr| arr.map(::from_underlier)) } - fn pack_slice(scalar_slice: &[P::Scalar]) -> Vec

{ - let mut packed_slice = vec![P::default(); scalar_slice.len() / P::WIDTH]; - - for (i, scalar) in scalar_slice.iter().enumerate() { - set_packed_slice(&mut packed_slice, i, *scalar); - } - - packed_slice - } - proptest! { #[test] fn test_base_ext_mul_8(base_scalars in strategy_8b_scalars(), ext_scalars in strategy_128b_scalars()){ diff --git a/crates/field/src/serialization/canonical.rs b/crates/field/src/serialization/canonical.rs index 394a7c409..6d7a85e4c 100644 --- a/crates/field/src/serialization/canonical.rs +++ b/crates/field/src/serialization/canonical.rs @@ -140,6 +140,21 @@ impl DeserializeCanonical for u8 { } } +impl SerializeCanonical for bool { + fn serialize_canonical(&self, write_buf: impl BufMut) -> Result<(), Error> { + u8::serialize_canonical(&(*self as u8), write_buf) + } +} + +impl DeserializeCanonical for bool { + fn deserialize_canonical(read_buf: impl Buf) -> Result + where + Self: Sized, + { + Ok(u8::deserialize_canonical(read_buf)? != 0) + } +} + impl SerializeCanonical for std::marker::PhantomData { fn serialize_canonical(&self, _write_buf: impl BufMut) -> Result<(), Error> { Ok(()) @@ -202,6 +217,49 @@ impl DeserializeCanonical for Vec { } } +impl SerializeCanonical for Option { + fn serialize_canonical(&self, mut write_buf: impl BufMut) -> Result<(), Error> { + match self { + Some(value) => { + SerializeCanonical::serialize_canonical(&true, &mut write_buf)?; + SerializeCanonical::serialize_canonical(value, &mut write_buf)?; + } + None => { + SerializeCanonical::serialize_canonical(&false, write_buf)?; + } + } + Ok(()) + } +} + +impl DeserializeCanonical for Option { + fn deserialize_canonical(mut read_buf: impl Buf) -> Result + where + Self: Sized, + { + Ok(match bool::deserialize_canonical(&mut read_buf)? { + true => Some(T::deserialize_canonical(&mut read_buf)?), + false => None, + }) + } +} + +impl SerializeCanonical for (U, V) { + fn serialize_canonical(&self, mut write_buf: impl BufMut) -> Result<(), Error> { + U::serialize_canonical(&self.0, &mut write_buf)?; + V::serialize_canonical(&self.1, write_buf) + } +} + +impl DeserializeCanonical for (U, V) { + fn deserialize_canonical(mut read_buf: impl Buf) -> Result + where + Self: Sized, + { + Ok((U::deserialize_canonical(&mut read_buf)?, V::deserialize_canonical(read_buf)?)) + } +} + impl> SerializeCanonical for GenericArray { fn serialize_canonical(&self, mut write_buf: impl BufMut) -> Result<(), Error> { assert_enough_space_for(&write_buf, N::USIZE)?; diff --git a/crates/field/src/serialization/error.rs b/crates/field/src/serialization/error.rs index bbab7b99c..2fa17e5e2 100644 --- a/crates/field/src/serialization/error.rs +++ b/crates/field/src/serialization/error.rs @@ -8,6 +8,12 @@ pub enum Error { NotEnoughBytes, #[error("Unknown enum variant index {name}::{index}")] UnknownEnumVariant { name: &'static str, index: u8 }, + #[error("Serialization has not been implemented")] + SerializationNotImplemented, + #[error("Deserializer has not been implemented")] + DeserializerNotImplented, + #[error("Multiple deserializers with the same name {name} has been registered")] + DeserializerNameConflict { name: String }, #[error("FromUtf8Error: {0}")] FromUtf8Error(#[from] std::string::FromUtf8Error), } diff --git a/crates/macros/src/lib.rs b/crates/macros/src/lib.rs index 2b8daf9f7..4bde1bf1b 100644 --- a/crates/macros/src/lib.rs +++ b/crates/macros/src/lib.rs @@ -7,7 +7,7 @@ mod composition_poly; use proc_macro::TokenStream; use quote::{quote, ToTokens}; -use syn::{parse_macro_input, parse_quote, spanned::Spanned, Data, DeriveInput, Fields}; +use syn::{parse_macro_input, parse_quote, spanned::Spanned, Data, DeriveInput, Fields, ItemImpl}; use crate::{ arith_circuit_poly::ArithCircuitPolyItem, arith_expr::ArithExprItem, @@ -277,6 +277,42 @@ pub fn derive_deserialize_canonical(input: TokenStream) -> TokenStream { .into() } +/// Use on an impl block for MultivariatePoly, to automatically implement erased_serialize_canonical. +/// +/// Importantly, this will serialize the concrete instance, prefixed by the identifier of the data type. +/// +/// This prefix can be used to figure out which concrete data type it should use for deserialization later. +#[proc_macro_attribute] +pub fn erased_serialize_canonical(_attr: TokenStream, item: TokenStream) -> TokenStream { + let mut item_impl: ItemImpl = parse_macro_input!(item); + let syn::Type::Path(p) = &*item_impl.self_ty else { + return syn::Error::new( + item_impl.span(), + "#[erased_serialize_canonical] can only be used on an impl for a concrete type", + ) + .into_compile_error() + .into(); + }; + let name = p.path.segments.last().unwrap().ident.to_string(); + + let method = parse_quote! { + fn erased_serialize_canonical( + &self, + write_buf: &mut dyn binius_field::bytes::BufMut, + ) -> Result<(), binius_field::serialization::Error> { + binius_field::SerializeCanonical::serialize_canonical(&#name, &mut *write_buf)?; + binius_field::SerializeCanonical::serialize_canonical(self, &mut *write_buf) + } + }; + + item_impl.items.push(syn::ImplItem::Fn(method)); + + quote! { + #item_impl + } + .into() +} + fn field_names(fields: Fields, positional_prefix: Option<&str>) -> Vec { match fields { Fields::Named(fields) => fields diff --git a/crates/math/Cargo.toml b/crates/math/Cargo.toml index d077173ab..d37e65281 100644 --- a/crates/math/Cargo.toml +++ b/crates/math/Cargo.toml @@ -9,6 +9,7 @@ workspace = true [dependencies] binius_field = { path = "../field" } +binius_macros = { path = "../macros" } binius_maybe_rayon = { path = "../maybe_rayon", default-features = false } binius_utils = { path = "../utils", default-features = false } auto_impl.workspace = true diff --git a/crates/math/src/arith_expr.rs b/crates/math/src/arith_expr.rs index 3c0e4adcb..c5cb0bbab 100644 --- a/crates/math/src/arith_expr.rs +++ b/crates/math/src/arith_expr.rs @@ -7,7 +7,8 @@ use std::{ ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign}, }; -use binius_field::{Field, PackedField, TowerField}; +use binius_field::{DeserializeCanonical, Field, PackedField, SerializeCanonical, TowerField}; +use binius_macros::{DeserializeCanonical, SerializeCanonical}; use super::error::Error; @@ -16,7 +17,7 @@ use super::error::Error; /// Arithmetic expressions are trees, where the leaves are either constants or variables, and the /// non-leaf nodes are arithmetic operations, such as addition, multiplication, etc. They are /// specific representations of multivariate polynomials. -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, SerializeCanonical, DeserializeCanonical)] pub enum ArithExpr { Const(F), Var(usize), @@ -25,6 +26,26 @@ pub enum ArithExpr { Pow(Box>, u64), } +impl SerializeCanonical for Box> { + fn serialize_canonical( + &self, + write_buf: impl binius_field::bytes::BufMut, + ) -> Result<(), binius_field::serialization::Error> { + ArithExpr::::serialize_canonical(&self.to_owned(), write_buf) + } +} + +impl DeserializeCanonical for Box> { + fn deserialize_canonical( + read_buf: impl binius_field::bytes::Buf, + ) -> Result + where + Self: Sized, + { + Ok(Self::new(ArithExpr::::deserialize_canonical(read_buf)?)) + } +} + impl Display for ArithExpr { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { @@ -136,7 +157,7 @@ impl ArithExpr { &self, ) -> Result, >::Error> { Ok(match self { - Self::Const(val) => ArithExpr::Const((*val).try_into()?), + Self::Const(val) => ArithExpr::Const(FTgt::try_from(*val)?), Self::Var(index) => ArithExpr::Var(*index), Self::Add(left, right) => { let new_left = left.try_convert_field()?; From 2aeb27eb47f4038592c293d4162ff70852edf7b9 Mon Sep 17 00:00:00 2001 From: Joseph Johnston Date: Mon, 17 Feb 2025 13:52:28 +0100 Subject: [PATCH 32/50] [circuits] Optimize plain_lookup using selector flushing (#29) --- crates/circuits/src/plain_lookup.rs | 145 ++++++---------------------- 1 file changed, 32 insertions(+), 113 deletions(-) diff --git a/crates/circuits/src/plain_lookup.rs b/crates/circuits/src/plain_lookup.rs index 8ec4b38b5..afcb77ac7 100644 --- a/crates/circuits/src/plain_lookup.rs +++ b/crates/circuits/src/plain_lookup.rs @@ -1,15 +1,11 @@ // Copyright 2024-2025 Irreducible Inc. -use binius_core::{ - constraint_system::channel::{Boundary, FlushDirection}, - oracle::OracleId, -}; +use binius_core::{constraint_system::channel::FlushDirection, oracle::OracleId}; use binius_field::{ as_packed_field::PackScalar, packed::set_packed_slice, BinaryField1b, ExtensionField, Field, TowerField, }; use bytemuck::Pod; -use itertools::izip; use crate::builder::{ types::{F, U}, @@ -27,45 +23,36 @@ use crate::builder::{ /// # Parameters /// - `builder`: a mutable reference to the `ConstraintSystemBuilder`. /// - `table`: an oracle holding the table of valid lookup values. -/// - `table_count`: only the first `table_count` values of `table` are considered valid lookup values. -/// - `balancer_value`: any valid table value, needed for balancing the channel. /// - `lookup_values`: an oracle holding the values to be looked up. /// - `lookup_values_count`: only the first `lookup_values_count` values in `lookup_values` will be looked up. /// -/// # Constraints -/// - no value in `lookup_values` can be looked only less than `1 << LOG_MAX_MULTIPLICITY` times, limiting completeness not soundness. -/// /// # How this Works /// We create a single channel for this lookup. /// We let the prover push all values in `lookup_values`, that is all values to be looked up, into the channel. /// We also must pull valid table values (i.e. values that appear in `table`) from the channel if the channel is to balance. /// By ensuring that only valid table values get pulled from the channel, and observing the channel to balance, we ensure that only valid table values get pushed (by the prover) into the channel. /// Therefore our construction is sound. -/// In order for the construction to be complete, allowing an honest prover to pass, we must pull each table value from the channel with exactly the same multiplicity (duplicate count) that the prover pushed that table value into the channel. +/// In order for the construction to be complete, allowing an honest prover to pass, we must pull each +/// table value from the channel with exactly the same multiplicity (duplicate count) that the prover pushed that table value into the channel. /// To do so, we allow the prover to commit information on the multiplicity of each table value. /// -/// The prover counts the multiplicity of each table value, and commits columns holding the bit-decomposition of the multiplicities. -/// Using these bit columns we create `component` columns the same height as the table, which select the table value where a multiplicity bit is 1 and select `balancer_value` where the bit is 0. -/// Pulling these component columns out of the channel with appropriate multiplicities, we pull out each table value from the channel with the multiplicity requested by the prover. -/// Due to the `balancer_value` appearing in the component columns, however, we will also pull the table value `balancer_value` from the channel many more times than needed. -/// To rectify this we put `balancer_value` in a boundary value and push this boundary value to the channel with a multiplicity that will balance the channel. -/// This boundary value is returned from the gadget. +/// The prover counts the multiplicity of each table value, and creates a bit column for +/// each of the LOG_MAX_MULTIPLICITY bits in the bit-decomposition of the multiplicities. +/// Then we flush the table values LOG_MAX_MULTIPLICITY times, each time using a different bit column as the 'selector' oracle to select which values in the +/// table actually get pushed into the channel flushed. When flushing the table with the i'th bit column as the selector, we flush with multiplicity 1 << i. /// pub fn plain_lookup( builder: &mut ConstraintSystemBuilder, table: OracleId, - table_count: usize, - balancer_value: FS, lookup_values: OracleId, lookup_values_count: usize, -) -> Result, anyhow::Error> +) -> Result<(), anyhow::Error> where U: PackScalar + Pod, F: ExtensionField, FS: TowerField + Pod, { let n_vars = builder.log_rows([table])?; - debug_assert!(table_count <= 1 << n_vars); let channel = builder.add_channel(); @@ -78,52 +65,24 @@ where let values_slice = witness.get::(lookup_values)?.as_slice::(); multiplicities = Some(count_multiplicities( - &table_slice[0..table_count], + &table_slice[0..1 << n_vars], &values_slice[0..lookup_values_count], false, )?); } - let components: [OracleId; LOG_MAX_MULTIPLICITY] = get_components::( - builder, - table, - table_count, - balancer_value, - multiplicities, - )?; - - components - .into_iter() - .enumerate() - .try_for_each(|(i, component)| { - builder.flush_with_multiplicity( - FlushDirection::Pull, - channel, - table_count, - [component], - 1 << i, - ) - })?; - - let balancer_value_multiplicity = - (((1 << LOG_MAX_MULTIPLICITY) - 1) * table_count - lookup_values_count) as u64; + let bits: [OracleId; LOG_MAX_MULTIPLICITY] = get_bits(builder, table, multiplicities)?; + bits.into_iter().enumerate().try_for_each(|(i, bit)| { + builder.flush_custom(FlushDirection::Pull, channel, bit, [table], 1 << i) + })?; - let boundary = Boundary { - values: vec![balancer_value.into()], - channel_id: channel, - direction: FlushDirection::Push, - multiplicity: balancer_value_multiplicity, - }; - - Ok(boundary) + Ok(()) } -// the `i`'th returned component holds values that are the product of the `table` values and the bits had by taking the `i`'th bit across the multiplicities. -fn get_components( +// the `i`'th returned bit column holds the `i`'th multiplicity bit. +fn get_bits( builder: &mut ConstraintSystemBuilder, table: OracleId, - table_count: usize, - balancer_value: FS, multiplicities: Option>, ) -> Result<[OracleId; LOG_MAX_MULTIPLICITY], anyhow::Error> where @@ -136,13 +95,10 @@ where let bits: [OracleId; LOG_MAX_MULTIPLICITY] = builder .add_committed_multiple::("bits", n_vars, BinaryField1b::TOWER_LEVEL); - let components: [OracleId; LOG_MAX_MULTIPLICITY] = builder - .add_committed_multiple::("components", n_vars, FS::TOWER_LEVEL); - if let Some(witness) = builder.witness() { let multiplicities = multiplicities.ok_or_else(|| anyhow::anyhow!("multiplicities empty for prover"))?; - debug_assert_eq!(table_count, multiplicities.len()); + debug_assert_eq!(1 << n_vars, multiplicities.len()); // check all multiplicities are in range if multiplicities @@ -157,19 +113,13 @@ where // create the columns for the bits let mut bit_cols = bits.map(|bit| witness.new_column::(bit)); let mut packed_bit_cols = bit_cols.each_mut().map(|bit_col| bit_col.packed()); - // create the columns for the components - let mut component_cols = components.map(|component| witness.new_column::(component)); - let mut packed_component_cols = component_cols - .each_mut() - .map(|component_col| component_col.packed()); - - let table_slice = witness.get::(table)?.as_slice::(); - izip!(table_slice, multiplicities).enumerate().for_each( - |(i, (table_val, multiplicity))| { - for j in 0..LOG_MAX_MULTIPLICITY { + multiplicities + .iter() + .enumerate() + .for_each(|(i, multiplicity)| { + (0..LOG_MAX_MULTIPLICITY).for_each(|j| { let bit_set = multiplicity & (1 << j) != 0; - // set the bit value set_packed_slice( packed_bit_cols[j], i, @@ -178,36 +128,11 @@ where false => BinaryField1b::ZERO, }, ); - // set the component value - set_packed_slice( - packed_component_cols[j], - i, - match bit_set { - true => *table_val, - false => balancer_value, - }, - ); - } - }, - ); + }) + }); } - let expression = { - use binius_math::ArithExpr as Expr; - let table = Expr::Var(0); - let bit = Expr::Var(1); - let component = Expr::Var(2); - component - (bit.clone() * table + (Expr::one() - bit) * Expr::Const(balancer_value)) - }; - (0..LOG_MAX_MULTIPLICITY).for_each(|i| { - builder.assert_zero( - format!("lookup_{i}"), - [table, bits[i], components[i]], - expression.convert_field(), - ); - }); - - Ok(components) + Ok(bits) } #[cfg(test)] @@ -247,21 +172,17 @@ pub mod test_plain_lookup { pub fn test_u8_mul_lookup( builder: &mut ConstraintSystemBuilder, log_lookup_count: usize, - ) -> Result, anyhow::Error> { + ) -> Result<(), anyhow::Error> { let table_values = generate_u8_mul_table(); let table = transparent::make_transparent( builder, "u8_mul_table", bytemuck::cast_slice::<_, BinaryField32b>(&table_values), )?; - let balancer_value = BinaryField32b::new(table_values[99]); // any table value let lookup_values = builder.add_committed("lookup_values", log_lookup_count, BinaryField32b::TOWER_LEVEL); - // reduce these if only some table values are valid - // or only some lookup_values are to be looked up - let table_count = table_values.len(); let lookup_values_count = 1 << log_lookup_count; if let Some(witness) = builder.witness() { @@ -270,16 +191,14 @@ pub mod test_plain_lookup { generate_random_u8_mul_claims(&mut mut_slice[0..lookup_values_count]); } - let boundary = plain_lookup::( + plain_lookup::( builder, table, - table_count, - balancer_value, lookup_values, lookup_values_count, )?; - Ok(boundary) + Ok(()) } } @@ -387,7 +306,7 @@ mod tests { let allocator = bumpalo::Bump::new(); let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); - let boundary = test_plain_lookup::test_u8_mul_lookup::( + test_plain_lookup::test_u8_mul_lookup::( &mut builder, log_lookup_count, ) @@ -412,7 +331,7 @@ mod tests { &constraint_system, log_inv_rate, security_bits, - &[boundary], + &[], witness, &domain_factory, &backend, @@ -424,7 +343,7 @@ mod tests { { let mut builder = ConstraintSystemBuilder::new(); - let boundary = test_plain_lookup::test_u8_mul_lookup::( + test_plain_lookup::test_u8_mul_lookup::( &mut builder, log_lookup_count, ) @@ -438,7 +357,7 @@ mod tests { Groestl256, Groestl256ByteCompression, HasherChallenger, - >(&constraint_system, log_inv_rate, security_bits, &[boundary], proof) + >(&constraint_system, log_inv_rate, security_bits, &[], proof) .unwrap(); } } From 1cfd900cdabe71026e3978ae44127e2916502f50 Mon Sep 17 00:00:00 2001 From: Anex007 Date: Mon, 17 Feb 2025 15:36:06 -0500 Subject: [PATCH 33/50] [scripts] Remove groestl run from benchmark script (#26) --- scripts/nightly_benchmarks.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/scripts/nightly_benchmarks.py b/scripts/nightly_benchmarks.py index 7aea6980d..1c2f9ad03 100755 --- a/scripts/nightly_benchmarks.py +++ b/scripts/nightly_benchmarks.py @@ -15,7 +15,6 @@ SAMPLE_SIZE = 5 KECCAKF_PERMS = 1 << 13 -GROESTLP_PERMS = 1 << 14 VISION32B_PERMS = 1 << 14 SHA256_PERMS = 1 << 14 NUM_BINARY_OPS = 1 << 22 @@ -29,13 +28,6 @@ "args": ["keccakf_circuit", "--", "--n-permutations"], "n_ops": KECCAKF_PERMS, }, - "groestlp": { - "type": "hasher", - "display": r"Grøstl P", - "export": "groestl-report.csv", - "args": ["groestl_circuit", "--", "--n-permutations"], - "n_ops": GROESTLP_PERMS, - }, "vision32b": { "type": "hasher", "display": r"Vision Mark-32", From af81f5f3d85574c658c17df55135d12856ab35a7 Mon Sep 17 00:00:00 2001 From: Thomas Coratger <60488569+tcoratger@users.noreply.github.com> Date: Tue, 18 Feb 2025 09:29:37 +0100 Subject: [PATCH 34/50] [arith_expr]: Statically compile exponentiation in ArithCircuitPoly (#15) --- crates/core/src/polynomial/arith_circuit.rs | 424 ++++++++++++++++++-- 1 file changed, 396 insertions(+), 28 deletions(-) diff --git a/crates/core/src/polynomial/arith_circuit.rs b/crates/core/src/polynomial/arith_circuit.rs index 69b4a4224..b0fe14f6f 100644 --- a/crates/core/src/polynomial/arith_circuit.rs +++ b/crates/core/src/polynomial/arith_circuit.rs @@ -50,10 +50,22 @@ fn circuit_steps_for_expr( result.push(CircuitStep::Mul(left, right)); CircuitStepArgument::Expr(CircuitNode::Slot(result.len() - 1)) } - ArithExpr::Pow(id, exp) => { - let id = to_circuit_inner(id, result); - result.push(CircuitStep::Pow(id, *exp)); - CircuitStepArgument::Expr(CircuitNode::Slot(result.len() - 1)) + ArithExpr::Pow(base, exp) => { + let mut acc = to_circuit_inner(base, result); + let base_expr = acc; + let highest_bit = exp.ilog2(); + + for i in (0..highest_bit).rev() { + result.push(CircuitStep::Square(acc)); + acc = CircuitStepArgument::Expr(CircuitNode::Slot(result.len() - 1)); + + if (exp >> i) & 1 != 0 { + result.push(CircuitStep::Mul(acc, base_expr)); + acc = CircuitStepArgument::Expr(CircuitNode::Slot(result.len() - 1)); + } + } + + acc } } } @@ -63,7 +75,7 @@ fn circuit_steps_for_expr( } /// Input of the circuit calculation step -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] enum CircuitNode { /// Input variable Var(usize), @@ -87,7 +99,7 @@ impl CircuitNode { } } -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] enum CircuitStepArgument { Expr(CircuitNode), Const(F), @@ -101,7 +113,7 @@ enum CircuitStepArgument { enum CircuitStep { Add(CircuitStepArgument, CircuitStepArgument), Mul(CircuitStepArgument, CircuitStepArgument), - Pow(CircuitStepArgument, u64), + Square(CircuitStepArgument), AddMul(usize, CircuitStepArgument, CircuitStepArgument), } @@ -231,8 +243,8 @@ impl CompositionPoly for ArithCircuitPoly { after, get_argument_value(*x, before) * get_argument_value(*y, before), ), - CircuitStep::Pow(id, exp) => { - write_result(after, get_argument_value(*id, before).pow(*exp)) + CircuitStep::Square(x) => { + write_result(after, get_argument_value(*x, before).square()) } }; } @@ -290,29 +302,31 @@ impl CompositionPoly for ArithCircuitPoly { }, ); } - CircuitStep::Pow(id, exp) => match id { - CircuitStepArgument::Expr(id) => { - let id = id.get_sparse_chunk(batch_query, before, row_len); - for j in 0..row_len { - // Safety: `current` and `id` have length equal to `row_len` - unsafe { - current - .get_unchecked_mut(j) - .write(id.get_unchecked(j).pow(*exp)); + CircuitStep::Square(arg) => { + match arg { + CircuitStepArgument::Expr(node) => { + let id_chunk = node.get_sparse_chunk(batch_query, before, row_len); + for j in 0..row_len { + // Safety: `current` and `id_chunk` have length equal to `row_len` + unsafe { + current + .get_unchecked_mut(j) + .write(id_chunk.get_unchecked(j).square()); + } } } - } - CircuitStepArgument::Const(id) => { - let id: P = P::broadcast((*id).into()); - let result = id.pow(*exp); - for j in 0..row_len { - // Safety: `current` has length equal to `row_len` - unsafe { - current.get_unchecked_mut(j).write(result); + CircuitStepArgument::Const(value) => { + let value: P = P::broadcast((*value).into()); + let result = value.square(); + for j in 0..row_len { + // Safety: `current` has length equal to `row_len` + unsafe { + current.get_unchecked_mut(j).write(result); + } } } } - }, + } CircuitStep::AddMul(target, left, right) => { let target = &before[row_len * target..(target + 1) * row_len]; // Safety: by construction of steps and evaluation order we know @@ -751,7 +765,7 @@ mod tests { // ((x0^2)^3)^4 let expr = ArithExpr::Var(0).pow(2).pow(3).pow(4); let circuit = ArithCircuitPoly::::new(expr); - assert_eq!(circuit.steps.len(), 1); + assert_eq!(circuit.steps.len(), 5); let typed_circuit: &dyn CompositionPolyOS

= &circuit; assert_eq!(typed_circuit.binary_tower_level(), 0); @@ -769,4 +783,358 @@ mod tests { P::from_scalars(felts!(BinaryField16b[0, 1, 1, 1, 20, 152, 41, 170])), ); } + + #[test] + fn test_circuit_steps_for_expr_constant() { + type F = BinaryField8b; + + let expr = ArithExpr::Const(F::new(5)); + let (steps, retval) = circuit_steps_for_expr(&expr); + + assert!(steps.is_empty(), "No steps should be generated for a constant"); + assert_eq!(retval, CircuitStepArgument::Const(F::new(5))); + } + + #[test] + fn test_circuit_steps_for_expr_variable() { + type F = BinaryField8b; + + let expr = ArithExpr::::Var(18); + let (steps, retval) = circuit_steps_for_expr(&expr); + + assert!(steps.is_empty(), "No steps should be generated for a variable"); + assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Var(18)))); + } + + #[test] + fn test_circuit_steps_for_expr_addition() { + type F = BinaryField8b; + + let expr = ArithExpr::::Var(14) + ArithExpr::::Var(56); + let (steps, retval) = circuit_steps_for_expr(&expr); + + assert_eq!(steps.len(), 1, "One addition step should be generated"); + assert!(matches!( + steps[0], + CircuitStep::Add( + CircuitStepArgument::Expr(CircuitNode::Var(14)), + CircuitStepArgument::Expr(CircuitNode::Var(56)) + ) + )); + assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(0)))); + } + + #[test] + fn test_circuit_steps_for_expr_multiplication() { + type F = BinaryField8b; + + let expr = ArithExpr::::Var(36) * ArithExpr::Var(26); + let (steps, retval) = circuit_steps_for_expr(&expr); + + assert_eq!(steps.len(), 1, "One multiplication step should be generated"); + assert!(matches!( + steps[0], + CircuitStep::Mul( + CircuitStepArgument::Expr(CircuitNode::Var(36)), + CircuitStepArgument::Expr(CircuitNode::Var(26)) + ) + )); + assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(0)))); + } + + #[test] + fn test_circuit_steps_for_expr_pow_1() { + type F = BinaryField8b; + + let expr = ArithExpr::::Var(12).pow(1); + let (steps, retval) = circuit_steps_for_expr(&expr); + + // No steps should be generated for x^1 + assert_eq!(steps.len(), 0, "Pow(1) should not generate any computation steps"); + + // The return value should just be the variable itself + assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Var(12)))); + } + + #[test] + fn test_circuit_steps_for_expr_pow_2() { + type F = BinaryField8b; + + let expr = ArithExpr::::Var(10).pow(2); + let (steps, retval) = circuit_steps_for_expr(&expr); + + assert_eq!(steps.len(), 1, "Pow(2) should generate one squaring step"); + assert!(matches!( + steps[0], + CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(10))) + )); + assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(0)))); + } + + #[test] + fn test_circuit_steps_for_expr_pow_3() { + type F = BinaryField8b; + + let expr = ArithExpr::::Var(5).pow(3); + let (steps, retval) = circuit_steps_for_expr(&expr); + + assert_eq!( + steps.len(), + 2, + "Pow(3) should generate one squaring and one multiplication step" + ); + assert!(matches!( + steps[0], + CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(5))) + )); + assert!(matches!( + steps[1], + CircuitStep::Mul( + CircuitStepArgument::Expr(CircuitNode::Slot(0)), + CircuitStepArgument::Expr(CircuitNode::Var(5)) + ) + )); + assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(1)))); + } + + #[test] + fn test_circuit_steps_for_expr_pow_4() { + type F = BinaryField8b; + + let expr = ArithExpr::::Var(7).pow(4); + let (steps, retval) = circuit_steps_for_expr(&expr); + + assert_eq!(steps.len(), 2, "Pow(4) should generate two squaring steps"); + assert!(matches!( + steps[0], + CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(7))) + )); + + assert!(matches!( + steps[1], + CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(0))) + )); + + assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(1)))); + } + + #[test] + fn test_circuit_steps_for_expr_pow_5() { + type F = BinaryField8b; + + let expr = ArithExpr::::Var(3).pow(5); + let (steps, retval) = circuit_steps_for_expr(&expr); + + assert_eq!( + steps.len(), + 3, + "Pow(5) should generate two squaring steps and one multiplication" + ); + assert!(matches!( + steps[0], + CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(3))) + )); + assert!(matches!( + steps[1], + CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(0))) + )); + assert!(matches!( + steps[2], + CircuitStep::Mul( + CircuitStepArgument::Expr(CircuitNode::Slot(1)), + CircuitStepArgument::Expr(CircuitNode::Var(3)) + ) + )); + + assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(2)))); + } + + #[test] + fn test_circuit_steps_for_expr_pow_8() { + type F = BinaryField8b; + + let expr = ArithExpr::::Var(4).pow(8); + let (steps, retval) = circuit_steps_for_expr(&expr); + + assert_eq!(steps.len(), 3, "Pow(8) should generate three squaring steps"); + assert!(matches!( + steps[0], + CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(4))) + )); + assert!(matches!( + steps[1], + CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(0))) + )); + assert!(matches!( + steps[2], + CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(1))) + )); + + assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(2)))); + } + + #[test] + fn test_circuit_steps_for_expr_pow_9() { + type F = BinaryField8b; + + let expr = ArithExpr::::Var(8).pow(9); + let (steps, retval) = circuit_steps_for_expr(&expr); + + assert_eq!( + steps.len(), + 4, + "Pow(9) should generate three squaring steps and one multiplication" + ); + assert!(matches!( + steps[0], + CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(8))) + )); + assert!(matches!( + steps[1], + CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(0))) + )); + assert!(matches!( + steps[2], + CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(1))) + )); + assert!(matches!( + steps[3], + CircuitStep::Mul( + CircuitStepArgument::Expr(CircuitNode::Slot(2)), + CircuitStepArgument::Expr(CircuitNode::Var(8)) + ) + )); + + assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(3)))); + } + + #[test] + fn test_circuit_steps_for_expr_pow_12() { + type F = BinaryField8b; + let expr = ArithExpr::::Var(6).pow(12); + let (steps, retval) = circuit_steps_for_expr(&expr); + + assert_eq!(steps.len(), 4, "Pow(12) should use 4 steps."); + + assert!(matches!( + steps[0], + CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(6))) + )); + assert!(matches!( + steps[1], + CircuitStep::Mul( + CircuitStepArgument::Expr(CircuitNode::Slot(0)), + CircuitStepArgument::Expr(CircuitNode::Var(6)) + ) + )); + assert!(matches!( + steps[2], + CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(1))) + )); + assert!(matches!( + steps[3], + CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(2))) + )); + + assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(3)))); + } + + #[test] + fn test_circuit_steps_for_expr_pow_13() { + type F = BinaryField8b; + let expr = ArithExpr::::Var(7).pow(13); + let (steps, retval) = circuit_steps_for_expr(&expr); + + assert_eq!(steps.len(), 5, "Pow(13) should use 5 steps."); + assert!(matches!( + steps[0], + CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(7))) + )); + assert!(matches!( + steps[1], + CircuitStep::Mul( + CircuitStepArgument::Expr(CircuitNode::Slot(0)), + CircuitStepArgument::Expr(CircuitNode::Var(7)) + ) + )); + assert!(matches!( + steps[2], + CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(1))) + )); + assert!(matches!( + steps[3], + CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(2))) + )); + assert!(matches!( + steps[4], + CircuitStep::Mul( + CircuitStepArgument::Expr(CircuitNode::Slot(3)), + CircuitStepArgument::Expr(CircuitNode::Var(7)) + ) + )); + assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(4)))); + } + + #[test] + fn test_circuit_steps_for_expr_complex() { + type F = BinaryField8b; + + let expr = (ArithExpr::::Var(0) * ArithExpr::Var(1)) + + (ArithExpr::Const(F::ONE) - ArithExpr::Var(0)) * ArithExpr::Var(2) + - ArithExpr::Var(3); + + let (steps, retval) = circuit_steps_for_expr(&expr); + + assert_eq!(steps.len(), 4, "Expression should generate 4 computation steps"); + + assert!( + matches!( + steps[0], + CircuitStep::Mul( + CircuitStepArgument::Expr(CircuitNode::Var(0)), + CircuitStepArgument::Expr(CircuitNode::Var(1)) + ) + ), + "First step should be multiplication x0 * x1" + ); + + assert!( + matches!( + steps[1], + CircuitStep::Add( + CircuitStepArgument::Const(F::ONE), + CircuitStepArgument::Expr(CircuitNode::Var(0)) + ) + ), + "Second step should be (1 - x0)" + ); + + assert!( + matches!( + steps[2], + CircuitStep::AddMul( + 0, + CircuitStepArgument::Expr(CircuitNode::Slot(1)), + CircuitStepArgument::Expr(CircuitNode::Var(2)) + ) + ), + "Third step should be (1 - x0) * x2" + ); + + assert!( + matches!( + steps[3], + CircuitStep::Add( + CircuitStepArgument::Expr(CircuitNode::Slot(0)), + CircuitStepArgument::Expr(CircuitNode::Var(3)) + ) + ), + "Fourth step should be x0 * x1 + (1 - x0) * x2 + x3" + ); + + assert!( + matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(3))), + "Final result should be stored in Slot(3)" + ); + } } From 401639f7a72c5bd0c925aa6633fe9ca50d1c0fe6 Mon Sep 17 00:00:00 2001 From: Tobias Bergkvist Date: Tue, 18 Feb 2025 10:11:06 +0100 Subject: [PATCH 35/50] [serialization] Introduce SerializationMode (#36) Changes: Adds SerializaitonMode that specifies whether to use native (fast) or canonical (needed for transcript) serialization/deserializtion. You need to use the same mode for serialization and deserialization. SerializeCanonical is renamed to SerializeBytes, and takes an extra argument of type SerializationMode DeserializeCanonical is renamed to DeserializeBytes and takes an extra argument of type SerializationMode SerializeBytes and DeserializeBytes are now required bounds for the Field trait, rather than being generically implemented for TowerField. u16, u32, u64, u128 now serialize to/deserialize from little-endian rather than big-endian byte order, to be consistent with BinaryField*b serialization. The serialization traits are moved back to binius_utils Automatic implementations of SerializeBytes for Box and &(T: SerializeBytes) Automatic implementation of DeserializeBytes for Box --- crates/core/src/constraint_system/channel.rs | 8 +- crates/core/src/constraint_system/mod.rs | 24 +- .../src/merkle_tree/binary_merkle_tree.rs | 9 +- crates/core/src/merkle_tree/scheme.rs | 6 +- crates/core/src/oracle/constraint.rs | 8 +- crates/core/src/oracle/multilinear.rs | 109 +++-- crates/core/src/piop/prove.rs | 8 +- crates/core/src/piop/tests.rs | 8 +- crates/core/src/piop/verify.rs | 8 +- crates/core/src/polynomial/multivariate.rs | 13 +- crates/core/src/protocols/fri/prove.rs | 8 +- crates/core/src/protocols/fri/verify.rs | 6 +- crates/core/src/ring_switch/tests.rs | 6 +- crates/core/src/transcript/error.rs | 4 +- crates/core/src/transcript/mod.rs | 32 +- crates/core/src/transparent/constant.rs | 17 +- .../src/transparent/multilinear_extension.rs | 38 +- crates/core/src/transparent/powers.rs | 17 +- crates/core/src/transparent/select_row.rs | 17 +- crates/core/src/transparent/serialization.rs | 70 +-- crates/core/src/transparent/step_down.rs | 17 +- crates/core/src/transparent/step_up.rs | 17 +- crates/core/src/transparent/tower_basis.rs | 17 +- crates/field/Cargo.toml | 2 - crates/field/src/aes_field.rs | 81 +++- crates/field/src/binary_field.rs | 107 ++--- crates/field/src/field.rs | 3 + crates/field/src/lib.rs | 3 - crates/field/src/polyval.rs | 53 ++- crates/field/src/serialization/bytes.rs | 62 --- crates/field/src/serialization/canonical.rs | 292 ------------ crates/field/src/serialization/error.rs | 19 - crates/field/src/serialization/mod.rs | 9 - crates/field/src/underlier/small_uint.rs | 26 +- crates/hash/src/serialization.rs | 8 +- crates/macros/Cargo.toml | 1 + crates/macros/src/lib.rs | 78 ++-- crates/math/src/arith_expr.rs | 26 +- crates/utils/Cargo.toml | 2 + crates/utils/src/lib.rs | 4 + crates/utils/src/serialization.rs | 437 ++++++++++++++++++ 41 files changed, 902 insertions(+), 778 deletions(-) delete mode 100644 crates/field/src/serialization/bytes.rs delete mode 100644 crates/field/src/serialization/canonical.rs delete mode 100644 crates/field/src/serialization/error.rs delete mode 100644 crates/field/src/serialization/mod.rs create mode 100644 crates/utils/src/serialization.rs diff --git a/crates/core/src/constraint_system/channel.rs b/crates/core/src/constraint_system/channel.rs index 1242e781b..51bf5d5c3 100644 --- a/crates/core/src/constraint_system/channel.rs +++ b/crates/core/src/constraint_system/channel.rs @@ -52,14 +52,14 @@ use std::collections::HashMap; use binius_field::{as_packed_field::PackScalar, underlier::UnderlierType, TowerField}; -use binius_macros::{DeserializeCanonical, SerializeCanonical}; +use binius_macros::{DeserializeBytes, SerializeBytes}; use super::error::{Error, VerificationError}; use crate::{oracle::OracleId, witness::MultilinearExtensionIndex}; pub type ChannelId = usize; -#[derive(Debug, Clone, SerializeCanonical, DeserializeCanonical)] +#[derive(Debug, Clone, SerializeBytes, DeserializeBytes)] pub struct Flush { pub oracles: Vec, pub channel_id: ChannelId, @@ -68,7 +68,7 @@ pub struct Flush { pub multiplicity: u64, } -#[derive(Debug, Clone, SerializeCanonical, DeserializeCanonical)] +#[derive(Debug, Clone, SerializeBytes, DeserializeBytes)] pub struct Boundary { pub values: Vec, pub channel_id: ChannelId, @@ -76,7 +76,7 @@ pub struct Boundary { pub multiplicity: u64, } -#[derive(Debug, Clone, Copy, SerializeCanonical, DeserializeCanonical)] +#[derive(Debug, Clone, Copy, SerializeBytes, DeserializeBytes)] pub enum FlushDirection { Push, Pull, diff --git a/crates/core/src/constraint_system/mod.rs b/crates/core/src/constraint_system/mod.rs index 2c040c41e..81719eac4 100644 --- a/crates/core/src/constraint_system/mod.rs +++ b/crates/core/src/constraint_system/mod.rs @@ -7,8 +7,9 @@ mod prove; pub mod validate; mod verify; -use binius_field::{serialization, BinaryField128b, DeserializeCanonical, TowerField}; -use binius_macros::SerializeCanonical; +use binius_field::{BinaryField128b, TowerField}; +use binius_macros::SerializeBytes; +use binius_utils::{DeserializeBytes, SerializationError, SerializationMode}; use channel::{ChannelId, Flush}; pub use prove::prove; pub use verify::verify; @@ -22,7 +23,7 @@ use crate::oracle::{ConstraintSet, MultilinearOracleSet, OracleId}; /// /// As a result, a ConstraintSystem allows us to validate all of these /// constraints against a witness, as well as enabling generic prove/verify -#[derive(Debug, Clone, SerializeCanonical)] +#[derive(Debug, Clone, SerializeBytes)] pub struct ConstraintSystem { pub oracles: MultilinearOracleSet, pub table_constraints: Vec>, @@ -31,17 +32,20 @@ pub struct ConstraintSystem { pub max_channel_id: ChannelId, } -impl DeserializeCanonical for ConstraintSystem { - fn deserialize_canonical(mut read_buf: impl bytes::Buf) -> Result +impl DeserializeBytes for ConstraintSystem { + fn deserialize( + mut read_buf: impl bytes::Buf, + mode: SerializationMode, + ) -> Result where Self: Sized, { Ok(Self { - oracles: DeserializeCanonical::deserialize_canonical(&mut read_buf)?, - table_constraints: DeserializeCanonical::deserialize_canonical(&mut read_buf)?, - non_zero_oracle_ids: DeserializeCanonical::deserialize_canonical(&mut read_buf)?, - flushes: DeserializeCanonical::deserialize_canonical(&mut read_buf)?, - max_channel_id: DeserializeCanonical::deserialize_canonical(&mut read_buf)?, + oracles: DeserializeBytes::deserialize(&mut read_buf, mode)?, + table_constraints: DeserializeBytes::deserialize(&mut read_buf, mode)?, + non_zero_oracle_ids: DeserializeBytes::deserialize(&mut read_buf, mode)?, + flushes: DeserializeBytes::deserialize(&mut read_buf, mode)?, + max_channel_id: DeserializeBytes::deserialize(&mut read_buf, mode)?, }) } } diff --git a/crates/core/src/merkle_tree/binary_merkle_tree.rs b/crates/core/src/merkle_tree/binary_merkle_tree.rs index 7cde01b8d..86383830d 100644 --- a/crates/core/src/merkle_tree/binary_merkle_tree.rs +++ b/crates/core/src/merkle_tree/binary_merkle_tree.rs @@ -2,10 +2,12 @@ use std::{array, fmt::Debug, mem::MaybeUninit}; -use binius_field::{SerializeCanonical, TowerField}; +use binius_field::TowerField; use binius_hash::{HashBuffer, PseudoCompressionFunction}; use binius_maybe_rayon::{prelude::*, slice::ParallelSlice}; -use binius_utils::{bail, checked_arithmetics::log2_strict_usize}; +use binius_utils::{ + bail, checked_arithmetics::log2_strict_usize, SerializationMode, SerializeBytes, +}; use digest::{crypto_common::BlockSizeUser, Digest, FixedOutputReset, Output}; use tracing::instrument; @@ -210,7 +212,8 @@ where { let mut hash_buffer = HashBuffer::new(hasher); for elem in elems { - SerializeCanonical::serialize_canonical(&elem, &mut hash_buffer) + let mode = SerializationMode::CanonicalTower; + SerializeBytes::serialize(&elem, &mut hash_buffer, mode) .expect("HashBuffer has infinite capacity"); } } diff --git a/crates/core/src/merkle_tree/scheme.rs b/crates/core/src/merkle_tree/scheme.rs index 63ef48380..5055e3549 100644 --- a/crates/core/src/merkle_tree/scheme.rs +++ b/crates/core/src/merkle_tree/scheme.rs @@ -2,11 +2,12 @@ use std::{array, fmt::Debug, marker::PhantomData}; -use binius_field::{SerializeCanonical, TowerField}; +use binius_field::TowerField; use binius_hash::{HashBuffer, PseudoCompressionFunction}; use binius_utils::{ bail, checked_arithmetics::{log2_ceil_usize, log2_strict_usize}, + SerializationMode, SerializeBytes, }; use bytes::Buf; use digest::{core_api::BlockSizeUser, Digest, Output}; @@ -173,7 +174,8 @@ where { let mut buffer = HashBuffer::new(&mut hasher); for elem in elems { - SerializeCanonical::serialize_canonical(elem, &mut buffer) + let mode = SerializationMode::CanonicalTower; + SerializeBytes::serialize(elem, &mut buffer, mode) .expect("HashBuffer has infinite capacity"); } } diff --git a/crates/core/src/oracle/constraint.rs b/crates/core/src/oracle/constraint.rs index bd489381b..9edd31f5d 100644 --- a/crates/core/src/oracle/constraint.rs +++ b/crates/core/src/oracle/constraint.rs @@ -4,7 +4,7 @@ use core::iter::IntoIterator; use std::sync::Arc; use binius_field::{Field, TowerField}; -use binius_macros::{DeserializeCanonical, SerializeCanonical}; +use binius_macros::{DeserializeBytes, SerializeBytes}; use binius_math::{ArithExpr, CompositionPolyOS}; use binius_utils::bail; use itertools::Itertools; @@ -16,7 +16,7 @@ use super::{Error, MultilinearOracleSet, MultilinearPolyVariant, OracleId}; pub type TypeErasedComposition

= Arc>; /// Constraint is a type erased composition along with a predicate on its values on the boolean hypercube -#[derive(Debug, Clone, SerializeCanonical, DeserializeCanonical)] +#[derive(Debug, Clone, SerializeBytes, DeserializeBytes)] pub struct Constraint { pub name: String, pub composition: ArithExpr, @@ -25,14 +25,14 @@ pub struct Constraint { /// Predicate can either be a sum of values of a composition on the hypercube (sumcheck) or equality to zero /// on the hypercube (zerocheck) -#[derive(Clone, Debug, SerializeCanonical, DeserializeCanonical)] +#[derive(Clone, Debug, SerializeBytes, DeserializeBytes)] pub enum ConstraintPredicate { Sum(F), Zero, } /// Constraint set is a group of constraints that operate over the same set of oracle-identified multilinears -#[derive(Debug, Clone, SerializeCanonical, DeserializeCanonical)] +#[derive(Debug, Clone, SerializeBytes, DeserializeBytes)] pub struct ConstraintSet { pub n_vars: usize, pub oracle_ids: Vec, diff --git a/crates/core/src/oracle/multilinear.rs b/crates/core/src/oracle/multilinear.rs index be06ca5da..46d391681 100644 --- a/crates/core/src/oracle/multilinear.rs +++ b/crates/core/src/oracle/multilinear.rs @@ -2,11 +2,10 @@ use std::{array, fmt::Debug, sync::Arc}; -use binius_field::{ - serialization, BinaryField128b, DeserializeCanonical, Field, SerializeCanonical, TowerField, -}; -use binius_macros::{DeserializeCanonical, SerializeCanonical}; -use binius_utils::bail; +use binius_field::{BinaryField128b, Field, TowerField}; +use binius_macros::{DeserializeBytes, SerializeBytes}; +use binius_utils::{bail, DeserializeBytes, SerializationError, SerializationMode, SerializeBytes}; +use bytes::Buf; use getset::{CopyGetters, Getters}; use crate::{ @@ -283,18 +282,18 @@ impl MultilinearOracleSetAddition<'_, F> { /// /// The oracle set also tracks the committed polynomial in batches where each batch is committed /// together with a polynomial commitment scheme. -#[derive(Default, Debug, Clone, SerializeCanonical)] +#[derive(Default, Debug, Clone, SerializeBytes)] pub struct MultilinearOracleSet { oracles: Vec>, } -impl binius_field::DeserializeCanonical for MultilinearOracleSet { - fn deserialize_canonical(mut read_buf: impl bytes::Buf) -> Result +impl DeserializeBytes for MultilinearOracleSet { + fn deserialize(read_buf: impl Buf, mode: SerializationMode) -> Result where Self: Sized, { Ok(Self { - oracles: DeserializeCanonical::deserialize_canonical(&mut read_buf)?, + oracles: DeserializeBytes::deserialize(read_buf, mode)?, }) } } @@ -451,7 +450,7 @@ impl MultilinearOracleSet { /// other oracles. This is formalized in [DP23] Section 4. /// /// [DP23]: -#[derive(Debug, Clone, PartialEq, Eq, SerializeCanonical)] +#[derive(Debug, Clone, PartialEq, Eq, SerializeBytes)] pub struct MultilinearPolyOracle { pub id: OracleId, pub name: Option, @@ -460,22 +459,25 @@ pub struct MultilinearPolyOracle { pub variant: MultilinearPolyVariant, } -impl DeserializeCanonical for MultilinearPolyOracle { - fn deserialize_canonical(mut read_buf: impl bytes::Buf) -> Result +impl DeserializeBytes for MultilinearPolyOracle { + fn deserialize( + mut read_buf: impl bytes::Buf, + mode: SerializationMode, + ) -> Result where Self: Sized, { Ok(Self { - id: DeserializeCanonical::deserialize_canonical(&mut read_buf)?, - name: DeserializeCanonical::deserialize_canonical(&mut read_buf)?, - n_vars: DeserializeCanonical::deserialize_canonical(&mut read_buf)?, - tower_level: DeserializeCanonical::deserialize_canonical(&mut read_buf)?, - variant: DeserializeCanonical::deserialize_canonical(&mut read_buf)?, + id: DeserializeBytes::deserialize(&mut read_buf, mode)?, + name: DeserializeBytes::deserialize(&mut read_buf, mode)?, + n_vars: DeserializeBytes::deserialize(&mut read_buf, mode)?, + tower_level: DeserializeBytes::deserialize(&mut read_buf, mode)?, + variant: DeserializeBytes::deserialize(&mut read_buf, mode)?, }) } } -#[derive(Debug, Clone, PartialEq, Eq, SerializeCanonical)] +#[derive(Debug, Clone, PartialEq, Eq, SerializeBytes)] pub enum MultilinearPolyVariant { Committed, Transparent(TransparentPolyOracle), @@ -487,25 +489,28 @@ pub enum MultilinearPolyVariant { ZeroPadded(OracleId), } -impl DeserializeCanonical for MultilinearPolyVariant { - fn deserialize_canonical(mut buf: impl bytes::Buf) -> Result +impl DeserializeBytes for MultilinearPolyVariant { + fn deserialize( + mut buf: impl bytes::Buf, + mode: SerializationMode, + ) -> Result where Self: Sized, { - Ok(match u8::deserialize_canonical(&mut buf)? { + Ok(match u8::deserialize(&mut buf, mode)? { 0 => Self::Committed, - 1 => Self::Transparent(DeserializeCanonical::deserialize_canonical(&mut buf)?), + 1 => Self::Transparent(DeserializeBytes::deserialize(buf, mode)?), 2 => Self::Repeating { - id: DeserializeCanonical::deserialize_canonical(&mut buf)?, - log_count: DeserializeCanonical::deserialize_canonical(&mut buf)?, + id: DeserializeBytes::deserialize(&mut buf, mode)?, + log_count: DeserializeBytes::deserialize(buf, mode)?, }, - 3 => Self::Projected(DeserializeCanonical::deserialize_canonical(&mut buf)?), - 4 => Self::Shifted(DeserializeCanonical::deserialize_canonical(&mut buf)?), - 5 => Self::Packed(DeserializeCanonical::deserialize_canonical(&mut buf)?), - 6 => Self::LinearCombination(DeserializeCanonical::deserialize_canonical(&mut buf)?), - 7 => Self::ZeroPadded(DeserializeCanonical::deserialize_canonical(&mut buf)?), + 3 => Self::Projected(DeserializeBytes::deserialize(buf, mode)?), + 4 => Self::Shifted(DeserializeBytes::deserialize(buf, mode)?), + 5 => Self::Packed(DeserializeBytes::deserialize(buf, mode)?), + 6 => Self::LinearCombination(DeserializeBytes::deserialize(buf, mode)?), + 7 => Self::ZeroPadded(DeserializeBytes::deserialize(buf, mode)?), variant_index => { - return Err(serialization::Error::UnknownEnumVariant { + return Err(SerializationError::UnknownEnumVariant { name: "MultilinearPolyVariant", index: variant_index, }); @@ -523,25 +528,27 @@ pub struct TransparentPolyOracle { poly: Arc>, } -impl SerializeCanonical for TransparentPolyOracle { - fn serialize_canonical( +impl SerializeBytes for TransparentPolyOracle { + fn serialize( &self, mut write_buf: impl bytes::BufMut, - ) -> Result<(), binius_field::serialization::Error> { - self.poly.erased_serialize_canonical(&mut write_buf) + mode: SerializationMode, + ) -> Result<(), SerializationError> { + self.poly.erased_serialize(&mut write_buf, mode) } } -impl DeserializeCanonical for TransparentPolyOracle { - fn deserialize_canonical( - mut read_buf: impl bytes::Buf, - ) -> Result +impl DeserializeBytes for TransparentPolyOracle { + fn deserialize( + read_buf: impl bytes::Buf, + mode: SerializationMode, + ) -> Result where Self: Sized, { - let poly: Box> = - DeserializeCanonical::deserialize_canonical(&mut read_buf)?; - Ok(Self { poly: poly.into() }) + Ok(Self { + poly: Box::>::deserialize(read_buf, mode)?.into(), + }) } } @@ -571,15 +578,13 @@ impl PartialEq for TransparentPolyOracle { impl Eq for TransparentPolyOracle {} -#[derive(Debug, Copy, Clone, PartialEq, Eq, SerializeCanonical, DeserializeCanonical)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, SerializeBytes, DeserializeBytes)] pub enum ProjectionVariant { FirstVars, LastVars, } -#[derive( - Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeCanonical, DeserializeCanonical, -)] +#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes, DeserializeBytes)] pub struct Projected { #[get_copy = "pub"] id: OracleId, @@ -609,16 +614,14 @@ impl Projected { } } -#[derive(Debug, Copy, Clone, PartialEq, Eq, SerializeCanonical, DeserializeCanonical)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, SerializeBytes, DeserializeBytes)] pub enum ShiftVariant { CircularLeft, LogicalLeft, LogicalRight, } -#[derive( - Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeCanonical, DeserializeCanonical, -)] +#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes, DeserializeBytes)] pub struct Shifted { #[get_copy = "pub"] id: OracleId, @@ -660,9 +663,7 @@ impl Shifted { } } -#[derive( - Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeCanonical, DeserializeCanonical, -)] +#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes, DeserializeBytes)] pub struct Packed { #[get_copy = "pub"] id: OracleId, @@ -676,9 +677,7 @@ pub struct Packed { log_degree: usize, } -#[derive( - Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeCanonical, DeserializeCanonical, -)] +#[derive(Debug, Clone, PartialEq, Eq, Getters, CopyGetters, SerializeBytes, DeserializeBytes)] pub struct LinearCombination { #[get_copy = "pub"] n_vars: usize, diff --git a/crates/core/src/piop/prove.rs b/crates/core/src/piop/prove.rs index 27e480d38..2d1331ec6 100644 --- a/crates/core/src/piop/prove.rs +++ b/crates/core/src/piop/prove.rs @@ -2,7 +2,7 @@ use binius_field::{ packed::set_packed_slice, BinaryField, Field, PackedExtension, PackedField, - PackedFieldIndexable, SerializeCanonical, TowerField, + PackedFieldIndexable, TowerField, }; use binius_hal::ComputationBackend; use binius_math::{ @@ -10,7 +10,7 @@ use binius_math::{ }; use binius_maybe_rayon::{iter::IntoParallelIterator, prelude::*}; use binius_ntt::{NTTOptions, ThreadingSettings}; -use binius_utils::{bail, sorting::is_sorted_ascending}; +use binius_utils::{bail, sorting::is_sorted_ascending, SerializeBytes}; use either::Either; use itertools::{chain, Itertools}; @@ -175,7 +175,7 @@ where + PackedExtension, M: MultilinearPoly

+ Send + Sync, DomainFactory: EvaluationDomainFactory, - MTScheme: MerkleTreeScheme, + MTScheme: MerkleTreeScheme, MTProver: MerkleTreeProver, Challenger_: Challenger, Backend: ComputationBackend, @@ -254,7 +254,7 @@ where F: TowerField, FEncode: BinaryField, P: PackedFieldIndexable + PackedExtension, - MTScheme: MerkleTreeScheme, + MTScheme: MerkleTreeScheme, MTProver: MerkleTreeProver, Challenger_: Challenger, { diff --git a/crates/core/src/piop/tests.rs b/crates/core/src/piop/tests.rs index 88ee6caa5..67a52273b 100644 --- a/crates/core/src/piop/tests.rs +++ b/crates/core/src/piop/tests.rs @@ -3,15 +3,15 @@ use std::iter::repeat_with; use binius_field::{ - BinaryField, BinaryField16b, BinaryField8b, DeserializeCanonical, Field, - PackedBinaryField2x128b, PackedExtension, PackedField, PackedFieldIndexable, - SerializeCanonical, TowerField, + BinaryField, BinaryField16b, BinaryField8b, Field, PackedBinaryField2x128b, PackedExtension, + PackedField, PackedFieldIndexable, TowerField, }; use binius_hal::make_portable_backend; use binius_hash::compress::Groestl256ByteCompression; use binius_math::{ DefaultEvaluationDomainFactory, MLEDirectAdapter, MultilinearExtension, MultilinearPoly, }; +use binius_utils::{DeserializeBytes, SerializeBytes}; use groestl_crypto::Groestl256; use rand::{rngs::StdRng, Rng, SeedableRng}; @@ -111,7 +111,7 @@ fn commit_prove_verify( + PackedExtension + PackedExtension + PackedExtension, - MTScheme: MerkleTreeScheme, + MTScheme: MerkleTreeScheme, { let merkle_scheme = merkle_prover.scheme(); diff --git a/crates/core/src/piop/verify.rs b/crates/core/src/piop/verify.rs index 8761a0a0b..a0c5a5e7c 100644 --- a/crates/core/src/piop/verify.rs +++ b/crates/core/src/piop/verify.rs @@ -2,10 +2,10 @@ use std::{borrow::Borrow, cmp::Ordering, iter, ops::Range}; -use binius_field::{BinaryField, DeserializeCanonical, ExtensionField, Field, TowerField}; +use binius_field::{BinaryField, ExtensionField, Field, TowerField}; use binius_math::evaluate_piecewise_multilinear; use binius_ntt::NTTOptions; -use binius_utils::bail; +use binius_utils::{bail, DeserializeBytes}; use getset::CopyGetters; use tracing::instrument; @@ -291,7 +291,7 @@ where F: TowerField + ExtensionField, FEncode: BinaryField, Challenger_: Challenger, - MTScheme: MerkleTreeScheme, + MTScheme: MerkleTreeScheme, { // Map of n_vars to sumcheck claim descriptions let sumcheck_claim_descs = make_sumcheck_claim_descs( @@ -412,7 +412,7 @@ where F: TowerField + ExtensionField, FEncode: BinaryField, Challenger_: Challenger, - MTScheme: MerkleTreeScheme, + MTScheme: MerkleTreeScheme, { let mut arities_iter = fri_params.fold_arities().iter(); let mut fri_commitments = Vec::with_capacity(fri_params.n_oracles()); diff --git a/crates/core/src/polynomial/multivariate.rs b/crates/core/src/polynomial/multivariate.rs index e154e8e25..8a4ca9c4b 100644 --- a/crates/core/src/polynomial/multivariate.rs +++ b/crates/core/src/polynomial/multivariate.rs @@ -2,11 +2,11 @@ use std::{borrow::Borrow, fmt::Debug, iter::repeat_with, marker::PhantomData, sync::Arc}; -use binius_field::{serialization, Field, PackedField}; +use binius_field::{Field, PackedField}; use binius_math::{ ArithExpr, CompositionPolyOS, MLEDirectAdapter, MultilinearPoly, MultilinearQueryRef, }; -use binius_utils::bail; +use binius_utils::{bail, SerializationError, SerializationMode}; use bytes::BufMut; use itertools::Itertools; use rand::{rngs::StdRng, SeedableRng}; @@ -32,12 +32,13 @@ pub trait MultivariatePoly

: Debug + Send + Sync { /// Serialize a type erased MultivariatePoly. /// Since not every MultivariatePoly implements serialization, this defaults to returning an error. - fn erased_serialize_canonical( + fn erased_serialize( &self, write_buf: &mut dyn BufMut, - ) -> Result<(), serialization::Error> { - let _ = write_buf; - Err(serialization::Error::SerializationNotImplemented) + mode: SerializationMode, + ) -> Result<(), SerializationError> { + let _ = (write_buf, mode); + Err(SerializationError::SerializationNotImplemented) } } diff --git a/crates/core/src/protocols/fri/prove.rs b/crates/core/src/protocols/fri/prove.rs index 2bb66fd6a..2a3b9f49f 100644 --- a/crates/core/src/protocols/fri/prove.rs +++ b/crates/core/src/protocols/fri/prove.rs @@ -1,11 +1,9 @@ // Copyright 2024-2025 Irreducible Inc. -use binius_field::{ - BinaryField, ExtensionField, PackedExtension, PackedField, SerializeCanonical, TowerField, -}; +use binius_field::{BinaryField, ExtensionField, PackedExtension, PackedField, TowerField}; use binius_hal::{make_portable_backend, ComputationBackend}; use binius_maybe_rayon::prelude::*; -use binius_utils::bail; +use binius_utils::{bail, SerializeBytes}; use bytemuck::zeroed_vec; use bytes::BufMut; use itertools::izip; @@ -287,7 +285,7 @@ where F: TowerField + ExtensionField, FA: BinaryField, MerkleProver: MerkleTreeProver, - VCS: MerkleTreeScheme, + VCS: MerkleTreeScheme, { /// Constructs a new folder. pub fn new( diff --git a/crates/core/src/protocols/fri/verify.rs b/crates/core/src/protocols/fri/verify.rs index 69c22b88d..85e548144 100644 --- a/crates/core/src/protocols/fri/verify.rs +++ b/crates/core/src/protocols/fri/verify.rs @@ -2,9 +2,9 @@ use std::iter; -use binius_field::{BinaryField, DeserializeCanonical, ExtensionField, TowerField}; +use binius_field::{BinaryField, ExtensionField, TowerField}; use binius_hal::{make_portable_backend, ComputationBackend}; -use binius_utils::bail; +use binius_utils::{bail, DeserializeBytes}; use bytes::Buf; use itertools::izip; use tracing::instrument; @@ -44,7 +44,7 @@ impl<'a, F, FA, VCS> FRIVerifier<'a, F, FA, VCS> where F: TowerField + ExtensionField, FA: BinaryField, - VCS: MerkleTreeScheme, + VCS: MerkleTreeScheme, { #[allow(clippy::too_many_arguments)] pub fn new( diff --git a/crates/core/src/ring_switch/tests.rs b/crates/core/src/ring_switch/tests.rs index b8c361cff..2cb0ce77d 100644 --- a/crates/core/src/ring_switch/tests.rs +++ b/crates/core/src/ring_switch/tests.rs @@ -6,8 +6,7 @@ use binius_field::{ arch::OptimalUnderlier128b, as_packed_field::{PackScalar, PackedType}, underlier::UnderlierType, - DeserializeCanonical, ExtensionField, Field, PackedField, PackedFieldIndexable, - SerializeCanonical, TowerField, + ExtensionField, Field, PackedField, PackedFieldIndexable, TowerField, }; use binius_hal::make_portable_backend; use binius_hash::compress::Groestl256ByteCompression; @@ -15,6 +14,7 @@ use binius_math::{ DefaultEvaluationDomainFactory, MLEEmbeddingAdapter, MultilinearExtension, MultilinearPoly, MultilinearQuery, }; +use binius_utils::{DeserializeBytes, SerializeBytes}; use groestl_crypto::Groestl256; use rand::prelude::*; @@ -269,7 +269,7 @@ fn commit_prove_verify_piop( Tower: TowerFamily, PackedType>: PackedFieldIndexable, FExt: PackedTop, - MTScheme: MerkleTreeScheme, Digest: SerializeCanonical + DeserializeCanonical>, + MTScheme: MerkleTreeScheme, Digest: SerializeBytes + DeserializeBytes>, MTProver: MerkleTreeProver, Scheme = MTScheme>, { let mut rng = StdRng::seed_from_u64(0); diff --git a/crates/core/src/transcript/error.rs b/crates/core/src/transcript/error.rs index 97b6754fd..ee57c376f 100644 --- a/crates/core/src/transcript/error.rs +++ b/crates/core/src/transcript/error.rs @@ -1,7 +1,5 @@ // Copyright 2024-2025 Irreducible Inc. -use binius_field::serialization::Error as SerializationError; - #[derive(Debug, thiserror::Error)] pub enum Error { #[error("Transcript is not empty, {remaining} bytes")] @@ -9,5 +7,5 @@ pub enum Error { #[error("Not enough bytes in the buffer")] NotEnoughBytes, #[error("Serialization error: {0}")] - Serialization(#[from] SerializationError), + Serialization(#[from] binius_utils::SerializationError), } diff --git a/crates/core/src/transcript/mod.rs b/crates/core/src/transcript/mod.rs index 08089985f..d014ea9a0 100644 --- a/crates/core/src/transcript/mod.rs +++ b/crates/core/src/transcript/mod.rs @@ -16,7 +16,8 @@ mod error; use std::{iter::repeat_with, slice}; -use binius_field::{DeserializeCanonical, PackedField, SerializeCanonical, TowerField}; +use binius_field::{PackedField, TowerField}; +use binius_utils::{DeserializeBytes, SerializationMode, SerializeBytes}; use bytes::{buf::UninitSlice, Buf, BufMut, Bytes, BytesMut}; pub use error::Error; use tracing::warn; @@ -257,13 +258,15 @@ impl TranscriptReader<'_, B> { self.buffer } - pub fn read(&mut self) -> Result { - T::deserialize_canonical(self.buffer()).map_err(Into::into) + pub fn read(&mut self) -> Result { + let mode = SerializationMode::CanonicalTower; + T::deserialize(self.buffer(), mode).map_err(Into::into) } - pub fn read_vec(&mut self, n: usize) -> Result, Error> { + pub fn read_vec(&mut self, n: usize) -> Result, Error> { + let mode = SerializationMode::CanonicalTower; let mut buffer = self.buffer(); - repeat_with(move || T::deserialize_canonical(&mut buffer).map_err(Into::into)) + repeat_with(move || T::deserialize(&mut buffer, mode).map_err(Into::into)) .take(n) .collect() } @@ -286,7 +289,8 @@ impl TranscriptReader<'_, B> { pub fn read_scalar_slice_into(&mut self, buf: &mut [F]) -> Result<(), Error> { let mut buffer = self.buffer(); for elem in buf { - *elem = DeserializeCanonical::deserialize_canonical(&mut buffer)?; + let mode = SerializationMode::CanonicalTower; + *elem = DeserializeBytes::deserialize(&mut buffer, mode)?; } Ok(()) } @@ -332,19 +336,19 @@ impl TranscriptWriter<'_, B> { self.buffer } - pub fn write(&mut self, value: &T) { + pub fn write(&mut self, value: &T) { self.proof_size_event_wrapper(|buffer| { value - .serialize_canonical(buffer) + .serialize(buffer, SerializationMode::CanonicalTower) .expect("TODO: propagate error"); }); } - pub fn write_slice(&mut self, values: &[T]) { + pub fn write_slice(&mut self, values: &[T]) { self.proof_size_event_wrapper(|buffer| { for value in values { value - .serialize_canonical(&mut *buffer) + .serialize(&mut *buffer, SerializationMode::CanonicalTower) .expect("TODO: propagate error"); } }); @@ -363,7 +367,7 @@ impl TranscriptWriter<'_, B> { pub fn write_scalar_slice(&mut self, elems: &[F]) { self.proof_size_event_wrapper(|buffer| { for elem in elems { - SerializeCanonical::serialize_canonical(elem, &mut *buffer) + SerializeBytes::serialize(elem, &mut *buffer, SerializationMode::CanonicalTower) .expect("TODO: propagate error"); } }); @@ -402,7 +406,8 @@ where Challenger_: Challenger, { fn sample(&mut self) -> F { - DeserializeCanonical::deserialize_canonical(self.combined.challenger.sampler()) + let mode = SerializationMode::CanonicalTower; + DeserializeBytes::deserialize(self.combined.challenger.sampler(), mode) .expect("challenger has infinite buffer") } } @@ -413,7 +418,8 @@ where Challenger_: Challenger, { fn sample(&mut self) -> F { - DeserializeCanonical::deserialize_canonical(self.combined.challenger.sampler()) + let mode = SerializationMode::CanonicalTower; + DeserializeBytes::deserialize(self.combined.challenger.sampler(), mode) .expect("challenger has infinite buffer") } } diff --git a/crates/core/src/transparent/constant.rs b/crates/core/src/transparent/constant.rs index ea7bd735a..860ced321 100644 --- a/crates/core/src/transparent/constant.rs +++ b/crates/core/src/transparent/constant.rs @@ -1,13 +1,13 @@ // Copyright 2024-2025 Irreducible Inc. -use binius_field::{ExtensionField, TowerField}; -use binius_macros::{erased_serialize_canonical, DeserializeCanonical, SerializeCanonical}; -use binius_utils::bail; +use binius_field::{BinaryField128b, ExtensionField, TowerField}; +use binius_macros::{erased_serialize_bytes, DeserializeBytes, SerializeBytes}; +use binius_utils::{bail, DeserializeBytes}; use crate::polynomial::{Error, MultivariatePoly}; /// A constant polynomial. -#[derive(Debug, Copy, Clone, SerializeCanonical, DeserializeCanonical)] +#[derive(Debug, Copy, Clone, SerializeBytes, DeserializeBytes)] pub struct Constant { n_vars: usize, value: F, @@ -15,12 +15,9 @@ pub struct Constant { } inventory::submit! { - >::register_deserializer( + >::register_deserializer( "Constant", - |buf: &mut dyn bytes::Buf| { - let deserialized = as binius_field::DeserializeCanonical>::deserialize_canonical(&mut *buf)?; - Ok(Box::new(deserialized)) - } + |buf, mode| Ok(Box::new(Constant::::deserialize(&mut *buf, mode)?)) ) } @@ -37,7 +34,7 @@ impl Constant { } } -#[erased_serialize_canonical] +#[erased_serialize_bytes] impl MultivariatePoly for Constant { fn n_vars(&self) -> usize { self.n_vars diff --git a/crates/core/src/transparent/multilinear_extension.rs b/crates/core/src/transparent/multilinear_extension.rs index 9c59486ee..ef55e4d76 100644 --- a/crates/core/src/transparent/multilinear_extension.rs +++ b/crates/core/src/transparent/multilinear_extension.rs @@ -4,12 +4,13 @@ use std::{fmt::Debug, ops::Deref}; use binius_field::{ arch::OptimalUnderlier, as_packed_field::PackedType, packed::pack_slice, BinaryField128b, - DeserializeCanonical, ExtensionField, PackedField, RepackedExtension, SerializeCanonical, - TowerField, + BinaryField16b, BinaryField1b, BinaryField2b, BinaryField32b, BinaryField4b, BinaryField64b, + BinaryField8b, ExtensionField, PackedField, RepackedExtension, TowerField, }; use binius_hal::{make_portable_backend, ComputationBackendExt}; -use binius_macros::erased_serialize_canonical; +use binius_macros::erased_serialize_bytes; use binius_math::{MLEEmbeddingAdapter, MultilinearExtension, MultilinearPoly}; +use binius_utils::{DeserializeBytes, SerializationError, SerializationMode, SerializeBytes}; use crate::polynomial::{Error, MultivariatePoly}; @@ -31,49 +32,50 @@ where data: MLEEmbeddingAdapter, } -impl SerializeCanonical for MultilinearExtensionTransparent +impl SerializeBytes for MultilinearExtensionTransparent where P: PackedField, PE: RepackedExtension

, PE::Scalar: TowerField + ExtensionField, Data: Deref + Debug + Send + Sync, { - fn serialize_canonical( + fn serialize( &self, write_buf: impl bytes::BufMut, - ) -> Result<(), binius_field::serialization::Error> { + mode: SerializationMode, + ) -> Result<(), SerializationError> { let elems = PE::iter_slice( self.data .packed_evals() .expect("Evals should always be available here"), ) .collect::>(); - SerializeCanonical::serialize_canonical(&elems, write_buf) + SerializeBytes::serialize(&elems, write_buf, mode) } } inventory::submit! { - >::register_deserializer( + >::register_deserializer( "MultilinearExtensionTransparent", - |buf: &mut dyn bytes::Buf| { + |buf, mode| { type U = OptimalUnderlier; type F = BinaryField128b; type P = PackedType; - let hypercube_evals: Vec = DeserializeCanonical::deserialize_canonical(&mut *buf)?; + let hypercube_evals = Vec::::deserialize(&mut *buf, mode)?; let result: Box> = if let Some(packed_evals) = try_pack_slice(&hypercube_evals) { - Box::new(MultilinearExtensionTransparent::, P, _>::from_values(packed_evals).unwrap()) + Box::new(MultilinearExtensionTransparent::, P, _>::from_values(packed_evals).unwrap()) } else if let Some(packed_evals) = try_pack_slice(&hypercube_evals) { - Box::new(MultilinearExtensionTransparent::, P, _>::from_values(packed_evals).unwrap()) + Box::new(MultilinearExtensionTransparent::, P, _>::from_values(packed_evals).unwrap()) } else if let Some(packed_evals) = try_pack_slice(&hypercube_evals) { - Box::new(MultilinearExtensionTransparent::, P, _>::from_values(packed_evals).unwrap()) + Box::new(MultilinearExtensionTransparent::, P, _>::from_values(packed_evals).unwrap()) } else if let Some(packed_evals) = try_pack_slice(&hypercube_evals) { - Box::new(MultilinearExtensionTransparent::, P, _>::from_values(packed_evals).unwrap()) + Box::new(MultilinearExtensionTransparent::, P, _>::from_values(packed_evals).unwrap()) } else if let Some(packed_evals) = try_pack_slice(&hypercube_evals) { - Box::new(MultilinearExtensionTransparent::, P, _>::from_values(packed_evals).unwrap()) + Box::new(MultilinearExtensionTransparent::, P, _>::from_values(packed_evals).unwrap()) } else if let Some(packed_evals) = try_pack_slice(&hypercube_evals) { - Box::new(MultilinearExtensionTransparent::, P, _>::from_values(packed_evals).unwrap()) + Box::new(MultilinearExtensionTransparent::, P, _>::from_values(packed_evals).unwrap()) } else if let Some(packed_evals) = try_pack_slice(&hypercube_evals) { - Box::new(MultilinearExtensionTransparent::, P, _>::from_values(packed_evals).unwrap()) + Box::new(MultilinearExtensionTransparent::, P, _>::from_values(packed_evals).unwrap()) } else { Box::new(MultilinearExtensionTransparent::::from_values(pack_slice(&hypercube_evals)).unwrap()) }; @@ -119,7 +121,7 @@ where } } -#[erased_serialize_canonical] +#[erased_serialize_bytes] impl MultivariatePoly for MultilinearExtensionTransparent where F: TowerField + ExtensionField, diff --git a/crates/core/src/transparent/powers.rs b/crates/core/src/transparent/powers.rs index 960608706..01a925a10 100644 --- a/crates/core/src/transparent/powers.rs +++ b/crates/core/src/transparent/powers.rs @@ -2,11 +2,11 @@ use std::iter::successors; -use binius_field::{PackedField, TowerField}; -use binius_macros::{erased_serialize_canonical, DeserializeCanonical, SerializeCanonical}; +use binius_field::{BinaryField128b, PackedField, TowerField}; +use binius_macros::{erased_serialize_bytes, DeserializeBytes, SerializeBytes}; use binius_math::MultilinearExtension; use binius_maybe_rayon::prelude::*; -use binius_utils::bail; +use binius_utils::{bail, DeserializeBytes}; use bytemuck::zeroed_vec; use itertools::{izip, Itertools}; @@ -14,19 +14,16 @@ use crate::polynomial::{Error, MultivariatePoly}; /// A transparent multilinear polynomial whose evaluation at index $i$ is $g^i$ for /// some field element $g$. -#[derive(Debug, SerializeCanonical, DeserializeCanonical)] +#[derive(Debug, SerializeBytes, DeserializeBytes)] pub struct Powers { n_vars: usize, base: F, } inventory::submit! { - >::register_deserializer( + >::register_deserializer( "Powers", - |buf: &mut dyn bytes::Buf| { - let deserialized = as binius_field::DeserializeCanonical>::deserialize_canonical(&mut *buf)?; - Ok(Box::new(deserialized)) - } + |buf, mode| Ok(Box::new(Powers::::deserialize(&mut *buf, mode)?)) ) } @@ -60,7 +57,7 @@ impl Powers { } } -#[erased_serialize_canonical] +#[erased_serialize_bytes] impl> MultivariatePoly

for Powers { fn n_vars(&self) -> usize { self.n_vars diff --git a/crates/core/src/transparent/select_row.rs b/crates/core/src/transparent/select_row.rs index ed86f7d7f..9ef3bf0e6 100644 --- a/crates/core/src/transparent/select_row.rs +++ b/crates/core/src/transparent/select_row.rs @@ -1,9 +1,9 @@ // Copyright 2024-2025 Irreducible Inc. -use binius_field::{packed::set_packed_slice, BinaryField1b, Field, PackedField}; -use binius_macros::{erased_serialize_canonical, DeserializeCanonical, SerializeCanonical}; +use binius_field::{packed::set_packed_slice, BinaryField128b, BinaryField1b, Field, PackedField}; +use binius_macros::{erased_serialize_bytes, DeserializeBytes, SerializeBytes}; use binius_math::MultilinearExtension; -use binius_utils::bail; +use binius_utils::{bail, DeserializeBytes}; use crate::polynomial::{Error, MultivariatePoly}; @@ -19,19 +19,16 @@ use crate::polynomial::{Error, MultivariatePoly}; /// ``` /// /// This is useful for defining boundary constraints -#[derive(Debug, Clone, SerializeCanonical, DeserializeCanonical)] +#[derive(Debug, Clone, SerializeBytes, DeserializeBytes)] pub struct SelectRow { n_vars: usize, index: usize, } inventory::submit! { - >::register_deserializer( + >::register_deserializer( "SelectRow", - |buf: &mut dyn bytes::Buf| { - let deserialized = ::deserialize_canonical(&mut *buf)?; - Ok(Box::new(deserialized)) - } + |buf, mode| Ok(Box::new(SelectRow::deserialize(&mut *buf, mode)?)) ) } @@ -61,7 +58,7 @@ impl SelectRow { } } -#[erased_serialize_canonical] +#[erased_serialize_bytes] impl MultivariatePoly for SelectRow { fn degree(&self) -> usize { self.n_vars diff --git a/crates/core/src/transparent/serialization.rs b/crates/core/src/transparent/serialization.rs index bc785713a..2d5ef14f5 100644 --- a/crates/core/src/transparent/serialization.rs +++ b/crates/core/src/transparent/serialization.rs @@ -10,31 +10,34 @@ use std::{collections::HashMap, sync::LazyLock}; -use binius_field::{ - serialization::Error, BinaryField128b, DeserializeCanonical, SerializeCanonical, TowerField, -}; +use binius_field::{BinaryField128b, TowerField}; +use binius_utils::{DeserializeBytes, SerializationError, SerializationMode, SerializeBytes}; use crate::polynomial::MultivariatePoly; -impl SerializeCanonical for Box> { - fn serialize_canonical( +impl SerializeBytes for Box> { + fn serialize( &self, mut write_buf: impl bytes::BufMut, - ) -> Result<(), binius_field::serialization::Error> { - self.erased_serialize_canonical(&mut write_buf) + mode: SerializationMode, + ) -> Result<(), SerializationError> { + self.erased_serialize(&mut write_buf, mode) } } -impl DeserializeCanonical for Box> { - fn deserialize_canonical(mut read_buf: impl bytes::Buf) -> Result +impl DeserializeBytes for Box> { + fn deserialize( + mut read_buf: impl bytes::Buf, + mode: SerializationMode, + ) -> Result where Self: Sized, { - let name = String::deserialize_canonical(&mut read_buf)?; + let name = String::deserialize(&mut read_buf, mode)?; match REGISTRY.get(name.as_str()) { - Some(Some(erased_deserialize_canonical)) => erased_deserialize_canonical(&mut read_buf), - Some(None) => Err(Error::DeserializerNameConflict { name }), - None => Err(Error::DeserializerNotImplented), + Some(Some(erased_deserialize)) => erased_deserialize(&mut read_buf, mode), + Some(None) => Err(SerializationError::DeserializerNameConflict { name }), + None => Err(SerializationError::DeserializerNotImplented), } } } @@ -43,27 +46,26 @@ impl DeserializeCanonical for Box> { // This allows third party code to submit their own deserializers as well inventory::collect!(DeserializerEntry); -static REGISTRY: LazyLock< - HashMap<&'static str, Option>>, -> = LazyLock::new(|| { - let mut registry = HashMap::new(); - inventory::iter::> - .into_iter() - .for_each(|&DeserializerEntry { name, deserializer }| match registry.entry(name) { - std::collections::hash_map::Entry::Vacant(entry) => { - entry.insert(Some(deserializer)); - } - std::collections::hash_map::Entry::Occupied(mut entry) => { - entry.insert(None); - } - }); - registry -}); +static REGISTRY: LazyLock>>> = + LazyLock::new(|| { + let mut registry = HashMap::new(); + inventory::iter::> + .into_iter() + .for_each(|&DeserializerEntry { name, deserializer }| match registry.entry(name) { + std::collections::hash_map::Entry::Vacant(entry) => { + entry.insert(Some(deserializer)); + } + std::collections::hash_map::Entry::Occupied(mut entry) => { + entry.insert(None); + } + }); + registry + }); impl dyn MultivariatePoly { pub const fn register_deserializer( name: &'static str, - deserializer: ErasedDeserializeCanonical, + deserializer: ErasedDeserializeBytes, ) -> DeserializerEntry { DeserializerEntry { name, deserializer } } @@ -71,8 +73,10 @@ impl dyn MultivariatePoly { pub struct DeserializerEntry { name: &'static str, - deserializer: ErasedDeserializeCanonical, + deserializer: ErasedDeserializeBytes, } -type ErasedDeserializeCanonical = - fn(&mut dyn bytes::Buf) -> Result>, Error>; +type ErasedDeserializeBytes = fn( + &mut dyn bytes::Buf, + mode: SerializationMode, +) -> Result>, SerializationError>; diff --git a/crates/core/src/transparent/step_down.rs b/crates/core/src/transparent/step_down.rs index 7e0d6e8bb..00dec6771 100644 --- a/crates/core/src/transparent/step_down.rs +++ b/crates/core/src/transparent/step_down.rs @@ -1,9 +1,9 @@ // Copyright 2024-2025 Irreducible Inc. -use binius_field::{Field, PackedField}; -use binius_macros::{erased_serialize_canonical, DeserializeCanonical, SerializeCanonical}; +use binius_field::{BinaryField128b, Field, PackedField}; +use binius_macros::{erased_serialize_bytes, DeserializeBytes, SerializeBytes}; use binius_math::MultilinearExtension; -use binius_utils::bail; +use binius_utils::{bail, DeserializeBytes}; use crate::polynomial::{Error, MultivariatePoly}; @@ -21,19 +21,16 @@ use crate::polynomial::{Error, MultivariatePoly}; /// ``` /// /// This is useful for making constraints that are not enforced at the last rows of the trace -#[derive(Debug, Clone, SerializeCanonical, DeserializeCanonical)] +#[derive(Debug, Clone, SerializeBytes, DeserializeBytes)] pub struct StepDown { n_vars: usize, index: usize, } inventory::submit! { - >::register_deserializer( + >::register_deserializer( "StepDown", - |buf: &mut dyn bytes::Buf| { - let deserialized = ::deserialize_canonical(&mut *buf)?; - Ok(Box::new(deserialized)) - } + |buf, mode| Ok(Box::new(StepDown::deserialize(&mut *buf, mode)?)) ) } @@ -79,7 +76,7 @@ impl StepDown { } } -#[erased_serialize_canonical] +#[erased_serialize_bytes] impl MultivariatePoly for StepDown { fn degree(&self) -> usize { self.n_vars diff --git a/crates/core/src/transparent/step_up.rs b/crates/core/src/transparent/step_up.rs index e764d0428..ad022df9a 100644 --- a/crates/core/src/transparent/step_up.rs +++ b/crates/core/src/transparent/step_up.rs @@ -1,9 +1,9 @@ // Copyright 2024-2025 Irreducible Inc. -use binius_field::{Field, PackedField}; -use binius_macros::{erased_serialize_canonical, DeserializeCanonical, SerializeCanonical}; +use binius_field::{BinaryField128b, Field, PackedField}; +use binius_macros::{erased_serialize_bytes, DeserializeBytes, SerializeBytes}; use binius_math::MultilinearExtension; -use binius_utils::bail; +use binius_utils::{bail, DeserializeBytes}; use crate::polynomial::{Error, MultivariatePoly}; @@ -21,19 +21,16 @@ use crate::polynomial::{Error, MultivariatePoly}; /// ``` /// /// This is useful for making constraints that are not enforced at the first rows of the trace -#[derive(Debug, Clone, SerializeCanonical, DeserializeCanonical)] +#[derive(Debug, Clone, SerializeBytes, DeserializeBytes)] pub struct StepUp { n_vars: usize, index: usize, } inventory::submit! { - >::register_deserializer( + >::register_deserializer( "StepUp", - |buf: &mut dyn bytes::Buf| { - let deserialized = ::deserialize_canonical(&mut *buf)?; - Ok(Box::new(deserialized)) - } + |buf, mode| Ok(Box::new(StepUp::deserialize(&mut *buf, mode)?)) ) } @@ -75,7 +72,7 @@ impl StepUp { } } -#[erased_serialize_canonical] +#[erased_serialize_bytes] impl MultivariatePoly for StepUp { fn degree(&self) -> usize { self.n_vars diff --git a/crates/core/src/transparent/tower_basis.rs b/crates/core/src/transparent/tower_basis.rs index 471ef4e6c..8b8d32bf0 100644 --- a/crates/core/src/transparent/tower_basis.rs +++ b/crates/core/src/transparent/tower_basis.rs @@ -2,10 +2,10 @@ use std::marker::PhantomData; -use binius_field::{Field, PackedField, TowerField}; -use binius_macros::{erased_serialize_canonical, DeserializeCanonical, SerializeCanonical}; +use binius_field::{BinaryField128b, Field, PackedField, TowerField}; +use binius_macros::{erased_serialize_bytes, DeserializeBytes, SerializeBytes}; use binius_math::MultilinearExtension; -use binius_utils::bail; +use binius_utils::{bail, DeserializeBytes}; use crate::polynomial::{Error, MultivariatePoly}; @@ -21,7 +21,7 @@ use crate::polynomial::{Error, MultivariatePoly}; /// /// Thus, $\mathcal{T}_{\iota+k}$ has a $\mathcal{T}_{\iota}$-basis of size $2^k$: /// * $1, X_{\iota}, X_{\iota+1}, X_{\iota}X_{\iota+1}, X_{\iota+2}, \ldots, X_{\iota} X_{\iota+1} \ldots X_{\iota+k-1}$ -#[derive(Debug, Copy, Clone, SerializeCanonical, DeserializeCanonical)] +#[derive(Debug, Copy, Clone, SerializeBytes, DeserializeBytes)] pub struct TowerBasis { k: usize, iota: usize, @@ -29,12 +29,9 @@ pub struct TowerBasis { } inventory::submit! { - >::register_deserializer( + >::register_deserializer( "TowerBasis", - |buf: &mut dyn bytes::Buf| { - let deserialized = as binius_field::DeserializeCanonical>::deserialize_canonical(&mut *buf)?; - Ok(Box::new(deserialized)) - } + |buf, mode| Ok(Box::new(TowerBasis::::deserialize(&mut *buf, mode)?)) ) } @@ -73,7 +70,7 @@ impl TowerBasis { } } -#[erased_serialize_canonical] +#[erased_serialize_bytes] impl MultivariatePoly for TowerBasis where F: TowerField, diff --git a/crates/field/Cargo.toml b/crates/field/Cargo.toml index 7b0a1b421..11903020b 100644 --- a/crates/field/Cargo.toml +++ b/crates/field/Cargo.toml @@ -11,10 +11,8 @@ workspace = true binius_maybe_rayon = { path = "../maybe_rayon", default-features = false } binius_utils = { path = "../utils", default-features = false } bytemuck.workspace = true -bytes.workspace = true cfg-if.workspace = true derive_more.workspace = true -generic-array.workspace = true rand.workspace = true seq-macro.workspace = true subtle.workspace = true diff --git a/crates/field/src/aes_field.rs b/crates/field/src/aes_field.rs index f2d480bff..74f7e76df 100644 --- a/crates/field/src/aes_field.rs +++ b/crates/field/src/aes_field.rs @@ -8,6 +8,10 @@ use std::{ ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}, }; +use binius_utils::{ + bytes::{Buf, BufMut}, + DeserializeBytes, SerializationError, SerializationMode, SerializeBytes, +}; use bytemuck::{Pod, Zeroable}; use rand::RngCore; use subtle::{Choice, ConditionallySelectable, ConstantTimeEq}; @@ -283,18 +287,61 @@ impl_tower_field_conversion!(AESTowerField32b, BinaryField32b); impl_tower_field_conversion!(AESTowerField64b, BinaryField64b); impl_tower_field_conversion!(AESTowerField128b, BinaryField128b); +macro_rules! serialize_deserialize_non_canonical { + ($field:ident, canonical=$canonical:ident) => { + impl SerializeBytes for $field { + fn serialize( + &self, + write_buf: impl BufMut, + mode: SerializationMode, + ) -> Result<(), SerializationError> { + match mode { + SerializationMode::Native => self.0.serialize(write_buf, mode), + SerializationMode::CanonicalTower => { + $canonical::from(*self).serialize(write_buf, mode) + } + } + } + } + + impl DeserializeBytes for $field { + fn deserialize( + read_buf: impl Buf, + mode: SerializationMode, + ) -> Result + where + Self: Sized, + { + match mode { + SerializationMode::Native => { + Ok(Self(DeserializeBytes::deserialize(read_buf, mode)?)) + } + SerializationMode::CanonicalTower => { + Ok(Self::from($canonical::deserialize(read_buf, mode)?)) + } + } + } + } + }; +} + +serialize_deserialize_non_canonical!(AESTowerField8b, canonical = BinaryField8b); +serialize_deserialize_non_canonical!(AESTowerField16b, canonical = BinaryField16b); +serialize_deserialize_non_canonical!(AESTowerField32b, canonical = BinaryField32b); +serialize_deserialize_non_canonical!(AESTowerField64b, canonical = BinaryField64b); +serialize_deserialize_non_canonical!(AESTowerField128b, canonical = BinaryField128b); + #[cfg(test)] mod tests { - use bytes::BytesMut; + use binius_utils::{bytes::BytesMut, SerializationMode, SerializeBytes}; use proptest::{arbitrary::any, proptest}; use rand::thread_rng; use super::*; use crate::{ binary_field::tests::is_binary_field_valid_generator, underlier::WithUnderlier, - DeserializeCanonical, PackedAESBinaryField16x32b, PackedAESBinaryField4x32b, - PackedAESBinaryField8x32b, PackedBinaryField16x32b, PackedBinaryField4x32b, - PackedBinaryField8x32b, SerializeCanonical, + PackedAESBinaryField16x32b, PackedAESBinaryField4x32b, PackedAESBinaryField8x32b, + PackedBinaryField16x32b, PackedBinaryField4x32b, PackedBinaryField8x32b, }; fn check_square(f: impl Field) { @@ -593,22 +640,24 @@ mod tests { let aes64 = ::random(&mut rng); let aes128 = ::random(&mut rng); - SerializeCanonical::serialize_canonical(&aes8, &mut buffer).unwrap(); - SerializeCanonical::serialize_canonical(&aes16, &mut buffer).unwrap(); - SerializeCanonical::serialize_canonical(&aes32, &mut buffer).unwrap(); - SerializeCanonical::serialize_canonical(&aes64, &mut buffer).unwrap(); - SerializeCanonical::serialize_canonical(&aes128, &mut buffer).unwrap(); + let mode = SerializationMode::CanonicalTower; + + SerializeBytes::serialize(&aes8, &mut buffer, mode).unwrap(); + SerializeBytes::serialize(&aes16, &mut buffer, mode).unwrap(); + SerializeBytes::serialize(&aes32, &mut buffer, mode).unwrap(); + SerializeBytes::serialize(&aes64, &mut buffer, mode).unwrap(); + SerializeBytes::serialize(&aes128, &mut buffer, mode).unwrap(); - SerializeCanonical::serialize_canonical(&aes128, &mut buffer).unwrap(); + SerializeBytes::serialize(&aes128, &mut buffer, mode).unwrap(); let mut read_buffer = buffer.freeze(); - assert_eq!(AESTowerField8b::deserialize_canonical(&mut read_buffer).unwrap(), aes8); - assert_eq!(AESTowerField16b::deserialize_canonical(&mut read_buffer).unwrap(), aes16); - assert_eq!(AESTowerField32b::deserialize_canonical(&mut read_buffer).unwrap(), aes32); - assert_eq!(AESTowerField64b::deserialize_canonical(&mut read_buffer).unwrap(), aes64); - assert_eq!(AESTowerField128b::deserialize_canonical(&mut read_buffer).unwrap(), aes128); + assert_eq!(AESTowerField8b::deserialize(&mut read_buffer, mode).unwrap(), aes8); + assert_eq!(AESTowerField16b::deserialize(&mut read_buffer, mode).unwrap(), aes16); + assert_eq!(AESTowerField32b::deserialize(&mut read_buffer, mode).unwrap(), aes32); + assert_eq!(AESTowerField64b::deserialize(&mut read_buffer, mode).unwrap(), aes64); + assert_eq!(AESTowerField128b::deserialize(&mut read_buffer, mode).unwrap(), aes128); - assert_eq!(BinaryField128b::deserialize_canonical(&mut read_buffer).unwrap(), aes128.into()) + assert_eq!(BinaryField128b::deserialize(&mut read_buffer, mode).unwrap(), aes128.into()) } } diff --git a/crates/field/src/binary_field.rs b/crates/field/src/binary_field.rs index 6f2c6e6f0..31f30f32d 100644 --- a/crates/field/src/binary_field.rs +++ b/crates/field/src/binary_field.rs @@ -7,8 +7,11 @@ use std::{ ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}, }; +use binius_utils::{ + bytes::{Buf, BufMut}, + DeserializeBytes, SerializationError, SerializationMode, SerializeBytes, +}; use bytemuck::{Pod, Zeroable}; -use bytes::{Buf, BufMut}; use rand::RngCore; use subtle::{Choice, ConditionallySelectable, ConstantTimeEq}; @@ -16,8 +19,7 @@ use super::{ binary_field_arithmetic::TowerFieldArithmetic, error::Error, extension::ExtensionField, }; use crate::{ - serialization::{DeserializeBytes, Error as SerializationError, SerializeBytes}, - underlier::{SmallU, U1, U2, U4}, + underlier::{U1, U2, U4}, Field, }; @@ -733,60 +735,36 @@ pub fn is_canonical_tower() -> bool { } macro_rules! serialize_deserialize { - ($bin_type:ty, SmallU<$U:literal>) => { + ($bin_type:ty) => { impl SerializeBytes for $bin_type { - fn serialize(&self, mut write_buf: impl BufMut) -> Result<(), SerializationError> { - if write_buf.remaining_mut() < 1 { - ::binius_utils::bail!(SerializationError::WriteBufferFull); - } - let b = self.0.val(); - write_buf.put_u8(b); - Ok(()) + fn serialize( + &self, + write_buf: impl BufMut, + mode: SerializationMode, + ) -> Result<(), SerializationError> { + self.0.serialize(write_buf, mode) } } impl DeserializeBytes for $bin_type { - fn deserialize(mut read_buf: impl Buf) -> Result { - if read_buf.remaining() < 1 { - ::binius_utils::bail!(SerializationError::NotEnoughBytes); - } - let b: u8 = read_buf.get_u8(); - Ok(Self(SmallU::<$U>::new(b))) - } - } - }; - ($bin_type:ty, $inner_type:ty) => { - impl SerializeBytes for $bin_type { - fn serialize(&self, mut write_buf: impl BufMut) -> Result<(), SerializationError> { - if write_buf.remaining_mut() < (<$inner_type>::BITS / 8) as usize { - ::binius_utils::bail!(SerializationError::WriteBufferFull); - } - write_buf.put_slice(&self.0.to_le_bytes()); - Ok(()) - } - } - - impl DeserializeBytes for $bin_type { - fn deserialize(mut read_buf: impl Buf) -> Result { - let mut inner = <$inner_type>::default().to_le_bytes(); - if read_buf.remaining() < inner.len() { - ::binius_utils::bail!(SerializationError::NotEnoughBytes); - } - read_buf.copy_to_slice(&mut inner); - Ok(Self(<$inner_type>::from_le_bytes(inner))) + fn deserialize( + read_buf: impl Buf, + mode: SerializationMode, + ) -> Result { + Ok(Self(DeserializeBytes::deserialize(read_buf, mode)?)) } } }; } -serialize_deserialize!(BinaryField1b, SmallU<1>); -serialize_deserialize!(BinaryField2b, SmallU<2>); -serialize_deserialize!(BinaryField4b, SmallU<4>); -serialize_deserialize!(BinaryField8b, u8); -serialize_deserialize!(BinaryField16b, u16); -serialize_deserialize!(BinaryField32b, u32); -serialize_deserialize!(BinaryField64b, u64); -serialize_deserialize!(BinaryField128b, u128); +serialize_deserialize!(BinaryField1b); +serialize_deserialize!(BinaryField2b); +serialize_deserialize!(BinaryField4b); +serialize_deserialize!(BinaryField8b); +serialize_deserialize!(BinaryField16b); +serialize_deserialize!(BinaryField32b); +serialize_deserialize!(BinaryField64b); +serialize_deserialize!(BinaryField128b); impl From for Choice { fn from(val: BinaryField1b) -> Self { @@ -877,7 +855,7 @@ impl From for u8 { #[cfg(test)] pub(crate) mod tests { - use bytes::BytesMut; + use binius_utils::{bytes::BytesMut, SerializationMode}; use proptest::prelude::*; use super::{ @@ -1246,6 +1224,7 @@ pub(crate) mod tests { #[test] fn test_serialization() { + let mode = SerializationMode::CanonicalTower; let mut buffer = BytesMut::new(); let b1 = BinaryField1b::from(0x1); let b8 = BinaryField8b::new(0x12); @@ -1256,25 +1235,25 @@ pub(crate) mod tests { let b64 = BinaryField64b::new(0x13579BDF02468ACE); let b128 = BinaryField128b::new(0x147AD0369CF258BE8899AABBCCDDEEFF); - b1.serialize(&mut buffer).unwrap(); - b8.serialize(&mut buffer).unwrap(); - b2.serialize(&mut buffer).unwrap(); - b16.serialize(&mut buffer).unwrap(); - b32.serialize(&mut buffer).unwrap(); - b4.serialize(&mut buffer).unwrap(); - b64.serialize(&mut buffer).unwrap(); - b128.serialize(&mut buffer).unwrap(); + b1.serialize(&mut buffer, mode).unwrap(); + b8.serialize(&mut buffer, mode).unwrap(); + b2.serialize(&mut buffer, mode).unwrap(); + b16.serialize(&mut buffer, mode).unwrap(); + b32.serialize(&mut buffer, mode).unwrap(); + b4.serialize(&mut buffer, mode).unwrap(); + b64.serialize(&mut buffer, mode).unwrap(); + b128.serialize(&mut buffer, mode).unwrap(); let mut read_buffer = buffer.freeze(); - assert_eq!(BinaryField1b::deserialize(&mut read_buffer).unwrap(), b1); - assert_eq!(BinaryField8b::deserialize(&mut read_buffer).unwrap(), b8); - assert_eq!(BinaryField2b::deserialize(&mut read_buffer).unwrap(), b2); - assert_eq!(BinaryField16b::deserialize(&mut read_buffer).unwrap(), b16); - assert_eq!(BinaryField32b::deserialize(&mut read_buffer).unwrap(), b32); - assert_eq!(BinaryField4b::deserialize(&mut read_buffer).unwrap(), b4); - assert_eq!(BinaryField64b::deserialize(&mut read_buffer).unwrap(), b64); - assert_eq!(BinaryField128b::deserialize(&mut read_buffer).unwrap(), b128); + assert_eq!(BinaryField1b::deserialize(&mut read_buffer, mode).unwrap(), b1); + assert_eq!(BinaryField8b::deserialize(&mut read_buffer, mode).unwrap(), b8); + assert_eq!(BinaryField2b::deserialize(&mut read_buffer, mode).unwrap(), b2); + assert_eq!(BinaryField16b::deserialize(&mut read_buffer, mode).unwrap(), b16); + assert_eq!(BinaryField32b::deserialize(&mut read_buffer, mode).unwrap(), b32); + assert_eq!(BinaryField4b::deserialize(&mut read_buffer, mode).unwrap(), b4); + assert_eq!(BinaryField64b::deserialize(&mut read_buffer, mode).unwrap(), b64); + assert_eq!(BinaryField128b::deserialize(&mut read_buffer, mode).unwrap(), b128); } #[test] diff --git a/crates/field/src/field.rs b/crates/field/src/field.rs index 10395e9c2..4a5469516 100644 --- a/crates/field/src/field.rs +++ b/crates/field/src/field.rs @@ -7,6 +7,7 @@ use std::{ ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}, }; +use binius_utils::{DeserializeBytes, SerializeBytes}; use rand::RngCore; use crate::{ @@ -49,6 +50,8 @@ pub trait Field: + InvertOrZero // `Underlier: PackScalar` is an obvious property but it can't be deduced by the compiler so we are id here. + WithUnderlier> + + SerializeBytes + + DeserializeBytes { /// The zero element of the field, the additive identity. const ZERO: Self; diff --git a/crates/field/src/lib.rs b/crates/field/src/lib.rs index f09a44b53..76414c038 100644 --- a/crates/field/src/lib.rs +++ b/crates/field/src/lib.rs @@ -33,7 +33,6 @@ pub mod packed_extension; pub mod packed_extension_ops; mod packed_polyval; pub mod polyval; -pub mod serialization; #[cfg(test)] mod tests; pub mod tower_levels; @@ -45,7 +44,6 @@ pub mod util; pub use aes_field::*; pub use arch::byte_sliced::*; pub use binary_field::*; -pub use bytes; pub use error::*; pub use extension::*; pub use field::Field; @@ -56,5 +54,4 @@ pub use packed_extension::*; pub use packed_extension_ops::*; pub use packed_polyval::*; pub use polyval::*; -pub use serialization::{DeserializeCanonical, SerializeCanonical}; pub use transpose::{square_transpose, transpose_scalars, Error as TransposeError}; diff --git a/crates/field/src/polyval.rs b/crates/field/src/polyval.rs index bdc15111f..cd3d944a6 100644 --- a/crates/field/src/polyval.rs +++ b/crates/field/src/polyval.rs @@ -9,7 +9,11 @@ use std::{ ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}, }; -use binius_utils::iter::IterExtensions; +use binius_utils::{ + bytes::{Buf, BufMut}, + iter::IterExtensions, + DeserializeBytes, SerializationError, SerializationMode, SerializeBytes, +}; use bytemuck::{Pod, TransparentWrapper, Zeroable}; use rand::{Rng, RngCore}; use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption}; @@ -450,6 +454,35 @@ impl ExtensionField for BinaryField128bPolyval { } } +impl SerializeBytes for BinaryField128bPolyval { + fn serialize( + &self, + write_buf: impl BufMut, + mode: SerializationMode, + ) -> Result<(), SerializationError> { + match mode { + SerializationMode::Native => self.0.serialize(write_buf, mode), + SerializationMode::CanonicalTower => { + BinaryField128b::from(*self).serialize(write_buf, mode) + } + } + } +} + +impl DeserializeBytes for BinaryField128bPolyval { + fn deserialize(read_buf: impl Buf, mode: SerializationMode) -> Result + where + Self: Sized, + { + match mode { + SerializationMode::Native => Ok(Self(DeserializeBytes::deserialize(read_buf, mode)?)), + SerializationMode::CanonicalTower => { + Ok(Self::from(BinaryField128b::deserialize(read_buf, mode)?)) + } + } + } +} + impl BinaryField for BinaryField128bPolyval { const MULTIPLICATIVE_GENERATOR: Self = Self(0x72bdf2504ce49c03105433c1c25a4a7); } @@ -1031,7 +1064,7 @@ pub fn is_polyval_tower() -> bool { #[cfg(test)] mod tests { - use bytes::BytesMut; + use binius_utils::{bytes::BytesMut, SerializationMode, SerializeBytes}; use proptest::prelude::*; use rand::thread_rng; @@ -1043,9 +1076,9 @@ mod tests { }, binary_field::tests::is_binary_field_valid_generator, linear_transformation::PackedTransformationFactory, - AESTowerField128b, DeserializeCanonical, PackedAESBinaryField1x128b, - PackedAESBinaryField2x128b, PackedAESBinaryField4x128b, PackedBinaryField1x128b, - PackedBinaryField2x128b, PackedBinaryField4x128b, PackedField, SerializeCanonical, + AESTowerField128b, PackedAESBinaryField1x128b, PackedAESBinaryField2x128b, + PackedAESBinaryField4x128b, PackedBinaryField1x128b, PackedBinaryField2x128b, + PackedBinaryField4x128b, PackedField, }; #[test] @@ -1188,23 +1221,25 @@ mod tests { #[test] fn test_canonical_serialization() { + let mode = SerializationMode::CanonicalTower; let mut buffer = BytesMut::new(); let mut rng = thread_rng(); let b128_poly1 = ::random(&mut rng); let b128_poly2 = ::random(&mut rng); - SerializeCanonical::serialize_canonical(&b128_poly1, &mut buffer).unwrap(); - SerializeCanonical::serialize_canonical(&b128_poly2, &mut buffer).unwrap(); + SerializeBytes::serialize(&b128_poly1, &mut buffer, mode).unwrap(); + SerializeBytes::serialize(&b128_poly2, &mut buffer, mode).unwrap(); + let mode = SerializationMode::CanonicalTower; let mut read_buffer = buffer.freeze(); assert_eq!( - BinaryField128bPolyval::deserialize_canonical(&mut read_buffer).unwrap(), + BinaryField128bPolyval::deserialize(&mut read_buffer, mode).unwrap(), b128_poly1 ); assert_eq!( - BinaryField128bPolyval::deserialize_canonical(&mut read_buffer).unwrap(), + BinaryField128bPolyval::deserialize(&mut read_buffer, mode).unwrap(), b128_poly2 ); } diff --git a/crates/field/src/serialization/bytes.rs b/crates/field/src/serialization/bytes.rs deleted file mode 100644 index 59335f69d..000000000 --- a/crates/field/src/serialization/bytes.rs +++ /dev/null @@ -1,62 +0,0 @@ -// Copyright 2024-2025 Irreducible Inc. - -use bytes::{Buf, BufMut}; -use generic_array::{ArrayLength, GenericArray}; - -use super::Error; - -/// Represents type that can be serialized to a byte buffer. -pub trait SerializeBytes { - fn serialize(&self, write_buf: impl BufMut) -> Result<(), Error>; -} - -/// Represents type that can be deserialized from a byte buffer. -pub trait DeserializeBytes { - fn deserialize(read_buf: impl Buf) -> Result - where - Self: Sized; -} - -impl> SerializeBytes for GenericArray { - fn serialize(&self, mut write_buf: impl BufMut) -> Result<(), Error> { - if write_buf.remaining_mut() < N::USIZE { - return Err(Error::WriteBufferFull); - } - write_buf.put_slice(self); - Ok(()) - } -} - -impl> DeserializeBytes for GenericArray { - fn deserialize(mut read_buf: impl Buf) -> Result { - if read_buf.remaining() < N::USIZE { - return Err(Error::NotEnoughBytes); - } - - let mut ret = Self::default(); - read_buf.copy_to_slice(&mut ret); - Ok(ret) - } -} - -#[cfg(test)] -mod tests { - use generic_array::typenum::U32; - use rand::{rngs::StdRng, RngCore, SeedableRng}; - - use super::*; - - #[test] - fn test_generic_array_serialize_deserialize() { - let mut rng = StdRng::seed_from_u64(0); - - let mut data = GenericArray::::default(); - rng.fill_bytes(&mut data); - - let mut buf = Vec::new(); - data.serialize(&mut buf).unwrap(); - - let data_deserialized = GenericArray::::deserialize(&mut buf.as_slice()).unwrap(); - assert_eq!(data_deserialized, data); - } -} diff --git a/crates/field/src/serialization/canonical.rs b/crates/field/src/serialization/canonical.rs deleted file mode 100644 index 6d7a85e4c..000000000 --- a/crates/field/src/serialization/canonical.rs +++ /dev/null @@ -1,292 +0,0 @@ -// Copyright 2025 Irreducible Inc. - -use bytes::{Buf, BufMut}; -use generic_array::{ArrayLength, GenericArray}; - -use super::{DeserializeBytes, Error, SerializeBytes}; -use crate::TowerField; - -/// Serialization where [`TowerField`] elements are written with canonical encoding. -pub trait SerializeCanonical { - fn serialize_canonical(&self, write_buf: impl BufMut) -> Result<(), Error>; -} - -/// Deserialization where [`TowerField`] elements are read with a canonical encoding. -pub trait DeserializeCanonical { - fn deserialize_canonical(read_buf: impl Buf) -> Result - where - Self: Sized; -} - -impl SerializeCanonical for F { - fn serialize_canonical(&self, mut write_buf: impl BufMut) -> Result<(), Error> { - SerializeBytes::serialize(&F::Canonical::from(*self), &mut write_buf) - } -} - -impl DeserializeCanonical for F { - fn deserialize_canonical(read_buf: impl Buf) -> Result - where - Self: Sized, - { - let canonical: F::Canonical = DeserializeBytes::deserialize(read_buf)?; - Ok(F::from(canonical)) - } -} - -impl SerializeCanonical for usize { - fn serialize_canonical(&self, mut write_buf: impl BufMut) -> Result<(), Error> { - SerializeCanonical::serialize_canonical(&(*self as u64), &mut write_buf) - } -} - -impl DeserializeCanonical for usize { - fn deserialize_canonical(mut read_buf: impl Buf) -> Result - where - Self: Sized, - { - let value: u64 = DeserializeCanonical::deserialize_canonical(&mut read_buf)?; - Ok(value as Self) - } -} - -impl SerializeCanonical for u128 { - fn serialize_canonical(&self, mut write_buf: impl BufMut) -> Result<(), Error> { - assert_enough_space_for(&write_buf, std::mem::size_of::())?; - write_buf.put_u128(*self); - Ok(()) - } -} - -impl DeserializeCanonical for u128 { - fn deserialize_canonical(mut read_buf: impl Buf) -> Result - where - Self: Sized, - { - assert_enough_data_for(&read_buf, std::mem::size_of::())?; - Ok(read_buf.get_u128()) - } -} - -impl SerializeCanonical for u64 { - fn serialize_canonical(&self, mut write_buf: impl BufMut) -> Result<(), Error> { - assert_enough_space_for(&write_buf, std::mem::size_of::())?; - write_buf.put_u64(*self); - Ok(()) - } -} - -impl DeserializeCanonical for u64 { - fn deserialize_canonical(mut read_buf: impl Buf) -> Result - where - Self: Sized, - { - assert_enough_data_for(&read_buf, std::mem::size_of::())?; - Ok(read_buf.get_u64()) - } -} - -impl SerializeCanonical for u32 { - fn serialize_canonical(&self, mut write_buf: impl BufMut) -> Result<(), Error> { - assert_enough_space_for(&write_buf, std::mem::size_of::())?; - write_buf.put_u32(*self); - Ok(()) - } -} - -impl DeserializeCanonical for u32 { - fn deserialize_canonical(mut read_buf: impl Buf) -> Result - where - Self: Sized, - { - assert_enough_data_for(&read_buf, std::mem::size_of::())?; - Ok(read_buf.get_u32()) - } -} - -impl SerializeCanonical for u16 { - fn serialize_canonical(&self, mut write_buf: impl BufMut) -> Result<(), Error> { - assert_enough_space_for(&write_buf, std::mem::size_of::())?; - write_buf.put_u16(*self); - Ok(()) - } -} - -impl DeserializeCanonical for u16 { - fn deserialize_canonical(mut read_buf: impl Buf) -> Result - where - Self: Sized, - { - assert_enough_data_for(&read_buf, std::mem::size_of::())?; - Ok(read_buf.get_u16()) - } -} - -impl SerializeCanonical for u8 { - fn serialize_canonical(&self, mut write_buf: impl BufMut) -> Result<(), Error> { - assert_enough_space_for(&write_buf, std::mem::size_of::())?; - write_buf.put_u8(*self); - Ok(()) - } -} - -impl DeserializeCanonical for u8 { - fn deserialize_canonical(mut read_buf: impl Buf) -> Result - where - Self: Sized, - { - assert_enough_data_for(&read_buf, std::mem::size_of::())?; - Ok(read_buf.get_u8()) - } -} - -impl SerializeCanonical for bool { - fn serialize_canonical(&self, write_buf: impl BufMut) -> Result<(), Error> { - u8::serialize_canonical(&(*self as u8), write_buf) - } -} - -impl DeserializeCanonical for bool { - fn deserialize_canonical(read_buf: impl Buf) -> Result - where - Self: Sized, - { - Ok(u8::deserialize_canonical(read_buf)? != 0) - } -} - -impl SerializeCanonical for std::marker::PhantomData { - fn serialize_canonical(&self, _write_buf: impl BufMut) -> Result<(), Error> { - Ok(()) - } -} - -impl DeserializeCanonical for std::marker::PhantomData { - fn deserialize_canonical(_read_buf: impl Buf) -> Result - where - Self: Sized, - { - Ok(Self) - } -} - -impl SerializeCanonical for &str { - fn serialize_canonical(&self, mut write_buf: impl BufMut) -> Result<(), Error> { - let bytes = self.as_bytes(); - SerializeCanonical::serialize_canonical(&bytes.len(), &mut write_buf)?; - assert_enough_space_for(&write_buf, bytes.len())?; - write_buf.put_slice(bytes); - Ok(()) - } -} - -impl SerializeCanonical for String { - fn serialize_canonical(&self, mut write_buf: impl BufMut) -> Result<(), Error> { - SerializeCanonical::serialize_canonical(&self.as_str(), &mut write_buf) - } -} - -impl DeserializeCanonical for String { - fn deserialize_canonical(mut read_buf: impl Buf) -> Result - where - Self: Sized, - { - let len = DeserializeCanonical::deserialize_canonical(&mut read_buf)?; - assert_enough_data_for(&read_buf, len)?; - Ok(Self::from_utf8(read_buf.copy_to_bytes(len).to_vec())?) - } -} - -impl SerializeCanonical for Vec { - fn serialize_canonical(&self, mut write_buf: impl BufMut) -> Result<(), Error> { - SerializeCanonical::serialize_canonical(&self.len(), &mut write_buf)?; - self.iter() - .try_for_each(|item| SerializeCanonical::serialize_canonical(item, &mut write_buf)) - } -} - -impl DeserializeCanonical for Vec { - fn deserialize_canonical(mut read_buf: impl Buf) -> Result - where - Self: Sized, - { - let len: usize = DeserializeCanonical::deserialize_canonical(&mut read_buf)?; - (0..len) - .map(|_| DeserializeCanonical::deserialize_canonical(&mut read_buf)) - .collect() - } -} - -impl SerializeCanonical for Option { - fn serialize_canonical(&self, mut write_buf: impl BufMut) -> Result<(), Error> { - match self { - Some(value) => { - SerializeCanonical::serialize_canonical(&true, &mut write_buf)?; - SerializeCanonical::serialize_canonical(value, &mut write_buf)?; - } - None => { - SerializeCanonical::serialize_canonical(&false, write_buf)?; - } - } - Ok(()) - } -} - -impl DeserializeCanonical for Option { - fn deserialize_canonical(mut read_buf: impl Buf) -> Result - where - Self: Sized, - { - Ok(match bool::deserialize_canonical(&mut read_buf)? { - true => Some(T::deserialize_canonical(&mut read_buf)?), - false => None, - }) - } -} - -impl SerializeCanonical for (U, V) { - fn serialize_canonical(&self, mut write_buf: impl BufMut) -> Result<(), Error> { - U::serialize_canonical(&self.0, &mut write_buf)?; - V::serialize_canonical(&self.1, write_buf) - } -} - -impl DeserializeCanonical for (U, V) { - fn deserialize_canonical(mut read_buf: impl Buf) -> Result - where - Self: Sized, - { - Ok((U::deserialize_canonical(&mut read_buf)?, V::deserialize_canonical(read_buf)?)) - } -} - -impl> SerializeCanonical for GenericArray { - fn serialize_canonical(&self, mut write_buf: impl BufMut) -> Result<(), Error> { - assert_enough_space_for(&write_buf, N::USIZE)?; - write_buf.put_slice(self); - Ok(()) - } -} - -impl> DeserializeCanonical for GenericArray { - fn deserialize_canonical(mut read_buf: impl Buf) -> Result { - assert_enough_data_for(&read_buf, N::USIZE)?; - let mut ret = Self::default(); - read_buf.copy_to_slice(&mut ret); - Ok(ret) - } -} - -fn assert_enough_space_for(write_buf: &impl BufMut, size: usize) -> Result<(), Error> { - if write_buf.remaining_mut() < size { - return Err(Error::WriteBufferFull); - } - Ok(()) -} - -fn assert_enough_data_for(read_buf: &impl Buf, size: usize) -> Result<(), Error> { - if read_buf.remaining() < size { - return Err(Error::NotEnoughBytes); - } - Ok(()) -} diff --git a/crates/field/src/serialization/error.rs b/crates/field/src/serialization/error.rs deleted file mode 100644 index 2fa17e5e2..000000000 --- a/crates/field/src/serialization/error.rs +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright 2024-2025 Irreducible Inc. - -#[derive(Clone, thiserror::Error, Debug)] -pub enum Error { - #[error("Write buffer is full")] - WriteBufferFull, - #[error("Not enough data in read buffer to deserialize")] - NotEnoughBytes, - #[error("Unknown enum variant index {name}::{index}")] - UnknownEnumVariant { name: &'static str, index: u8 }, - #[error("Serialization has not been implemented")] - SerializationNotImplemented, - #[error("Deserializer has not been implemented")] - DeserializerNotImplented, - #[error("Multiple deserializers with the same name {name} has been registered")] - DeserializerNameConflict { name: String }, - #[error("FromUtf8Error: {0}")] - FromUtf8Error(#[from] std::string::FromUtf8Error), -} diff --git a/crates/field/src/serialization/mod.rs b/crates/field/src/serialization/mod.rs deleted file mode 100644 index 99e080769..000000000 --- a/crates/field/src/serialization/mod.rs +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright 2024-2025 Irreducible Inc. - -mod bytes; -mod canonical; -mod error; - -pub use bytes::{DeserializeBytes, SerializeBytes}; -pub use canonical::{DeserializeCanonical, SerializeCanonical}; -pub use error::Error; diff --git a/crates/field/src/underlier/small_uint.rs b/crates/field/src/underlier/small_uint.rs index 2e99f211c..33a1e9b1c 100644 --- a/crates/field/src/underlier/small_uint.rs +++ b/crates/field/src/underlier/small_uint.rs @@ -6,7 +6,12 @@ use std::{ ops::{Not, Shl, Shr}, }; -use binius_utils::checked_arithmetics::checked_log_2; +use binius_utils::{ + bytes::{Buf, BufMut}, + checked_arithmetics::checked_log_2, + serialization::DeserializeBytes, + SerializationError, SerializationMode, SerializeBytes, +}; use bytemuck::{NoUninit, Zeroable}; use derive_more::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign}; use rand::{ @@ -222,3 +227,22 @@ impl From> for SmallU<4> { pub type U1 = SmallU<1>; pub type U2 = SmallU<2>; pub type U4 = SmallU<4>; + +impl SerializeBytes for SmallU { + fn serialize( + &self, + write_buf: impl BufMut, + mode: SerializationMode, + ) -> Result<(), SerializationError> { + self.val().serialize(write_buf, mode) + } +} + +impl DeserializeBytes for SmallU { + fn deserialize(read_buf: impl Buf, mode: SerializationMode) -> Result + where + Self: Sized, + { + Ok(Self::new(DeserializeBytes::deserialize(read_buf, mode)?)) + } +} diff --git a/crates/hash/src/serialization.rs b/crates/hash/src/serialization.rs index d2a3699e6..f1292c646 100644 --- a/crates/hash/src/serialization.rs +++ b/crates/hash/src/serialization.rs @@ -2,7 +2,7 @@ use std::{borrow::Borrow, cmp::min}; -use binius_field::SerializeCanonical; +use binius_utils::{SerializationMode, SerializeBytes}; use bytes::{buf::UninitSlice, BufMut}; use digest::{ core_api::{Block, BlockSizeUser}, @@ -11,7 +11,7 @@ use digest::{ /// Adapter that wraps a [`Digest`] references and exposes the [`BufMut`] interface. /// -/// This adapter is useful so that structs that implement [`SerializeCanonical`] can be serialized +/// This adapter is useful so that structs that implement [`SerializeBytes`] can be serialized /// directly to a hasher. #[derive(Debug)] pub struct HashBuffer<'a, D: Digest + BlockSizeUser> { @@ -67,7 +67,7 @@ impl Drop for HashBuffer<'_, D> { /// Hashes a sequence of serializable items. pub fn hash_serialize(items: impl IntoIterator>) -> Output where - T: SerializeCanonical, + T: SerializeBytes, D: Digest + BlockSizeUser, { let mut hasher = D::new(); @@ -75,7 +75,7 @@ where let mut buffer = HashBuffer::new(&mut hasher); for item in items { item.borrow() - .serialize_canonical(&mut buffer) + .serialize(&mut buffer, SerializationMode::CanonicalTower) .expect("HashBuffer has infinite capacity"); } } diff --git a/crates/macros/Cargo.toml b/crates/macros/Cargo.toml index c0293a3cc..c3b8defb1 100644 --- a/crates/macros/Cargo.toml +++ b/crates/macros/Cargo.toml @@ -16,6 +16,7 @@ proc-macro2.workspace = true binius_core = { path = "../core" } binius_field = { path = "../field" } binius_math = { path = "../math" } +binius_utils = { path = "../utils" } paste.workspace = true rand.workspace = true diff --git a/crates/macros/src/lib.rs b/crates/macros/src/lib.rs index 4bde1bf1b..f717405e1 100644 --- a/crates/macros/src/lib.rs +++ b/crates/macros/src/lib.rs @@ -74,11 +74,11 @@ pub fn arith_circuit_poly(input: TokenStream) -> TokenStream { .into() } -/// Derives the trait binius_field::SerializeCanonical for a struct or enum +/// Derives the trait binius_utils::DeserializeBytes for a struct or enum /// -/// See the DeserializeCanonical derive macro docs for examples/tests -#[proc_macro_derive(SerializeCanonical)] -pub fn derive_serialize_canonical(input: TokenStream) -> TokenStream { +/// See the DeserializeBytes derive macro docs for examples/tests +#[proc_macro_derive(SerializeBytes)] +pub fn derive_serialize_bytes(input: TokenStream) -> TokenStream { let input: DeriveInput = parse_macro_input!(input); let span = input.span(); let name = input.ident; @@ -86,7 +86,7 @@ pub fn derive_serialize_canonical(input: TokenStream) -> TokenStream { generics.type_params_mut().for_each(|type_param| { type_param .bounds - .push(parse_quote!(binius_field::SerializeCanonical)) + .push(parse_quote!(binius_utils::SerializeBytes)) }); let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); let body = match input.data { @@ -94,7 +94,7 @@ pub fn derive_serialize_canonical(input: TokenStream) -> TokenStream { Data::Struct(data) => { let fields = field_names(data.fields, None); quote! { - #(binius_field::SerializeCanonical::serialize_canonical(&self.#fields, &mut write_buf)?;)* + #(binius_utils::SerializeBytes::serialize(&self.#fields, &mut write_buf, mode)?;)* } } Data::Enum(data) => { @@ -107,8 +107,8 @@ pub fn derive_serialize_canonical(input: TokenStream) -> TokenStream { let variant_index = i as u8; let fields = field_names(variant.fields.clone(), Some("field_")); let serialize_variant = quote! { - binius_field::SerializeCanonical::serialize_canonical(&#variant_index, &mut write_buf)?; - #(binius_field::SerializeCanonical::serialize_canonical(#fields, &mut write_buf)?;)* + binius_utils::SerializeBytes::serialize(&#variant_index, &mut write_buf, mode)?; + #(binius_utils::SerializeBytes::serialize(#fields, &mut write_buf, mode)?;)* }; match variant.fields { Fields::Named(_) => quote! { @@ -138,8 +138,8 @@ pub fn derive_serialize_canonical(input: TokenStream) -> TokenStream { } }; quote! { - impl #impl_generics binius_field::SerializeCanonical for #name #ty_generics #where_clause { - fn serialize_canonical(&self, mut write_buf: impl binius_field::bytes::BufMut) -> Result<(), binius_field::serialization::Error> { + impl #impl_generics binius_utils::SerializeBytes for #name #ty_generics #where_clause { + fn serialize(&self, mut write_buf: impl binius_utils::bytes::BufMut, mode: binius_utils::SerializationMode) -> Result<(), binius_utils::SerializationError> { #body Ok(()) } @@ -147,13 +147,14 @@ pub fn derive_serialize_canonical(input: TokenStream) -> TokenStream { }.into() } -/// Derives the trait binius_field::DeserializeCanonical for a struct or enum +/// Derives the trait binius_utils::DeserializeBytes for a struct or enum /// /// ``` -/// use binius_field::{BinaryField128b, SerializeCanonical, DeserializeCanonical}; -/// use binius_macros::{SerializeCanonical, DeserializeCanonical}; +/// use binius_field::BinaryField128b; +/// use binius_utils::{SerializeBytes, DeserializeBytes, SerializationMode}; +/// use binius_macros::{SerializeBytes, DeserializeBytes}; /// -/// #[derive(Debug, PartialEq, SerializeCanonical, DeserializeCanonical)] +/// #[derive(Debug, PartialEq, SerializeBytes, DeserializeBytes)] /// enum MyEnum { /// A(usize), /// B { x: u32, y: u32 }, @@ -163,14 +164,14 @@ pub fn derive_serialize_canonical(input: TokenStream) -> TokenStream { /// /// let mut buf = vec![]; /// let value = MyEnum::B { x: 42, y: 1337 }; -/// MyEnum::serialize_canonical(&value, &mut buf).unwrap(); +/// MyEnum::serialize(&value, &mut buf, SerializationMode::Native).unwrap(); /// assert_eq!( -/// MyEnum::deserialize_canonical(buf.as_slice()).unwrap(), +/// MyEnum::deserialize(buf.as_slice(), SerializationMode::Native).unwrap(), /// value /// ); /// /// -/// #[derive(Debug, PartialEq, SerializeCanonical, DeserializeCanonical)] +/// #[derive(Debug, PartialEq, SerializeBytes, DeserializeBytes)] /// struct MyStruct { /// data: Vec /// } @@ -179,14 +180,14 @@ pub fn derive_serialize_canonical(input: TokenStream) -> TokenStream { /// let value = MyStruct { /// data: vec![BinaryField128b::new(1234), BinaryField128b::new(5678)] /// }; -/// MyStruct::serialize_canonical(&value, &mut buf).unwrap(); +/// MyStruct::serialize(&value, &mut buf, SerializationMode::CanonicalTower).unwrap(); /// assert_eq!( -/// MyStruct::::deserialize_canonical(buf.as_slice()).unwrap(), +/// MyStruct::::deserialize(buf.as_slice(), SerializationMode::CanonicalTower).unwrap(), /// value /// ); /// ``` -#[proc_macro_derive(DeserializeCanonical)] -pub fn derive_deserialize_canonical(input: TokenStream) -> TokenStream { +#[proc_macro_derive(DeserializeBytes)] +pub fn derive_deserialize_bytes(input: TokenStream) -> TokenStream { let input: DeriveInput = parse_macro_input!(input); let span = input.span(); let name = input.ident; @@ -194,11 +195,11 @@ pub fn derive_deserialize_canonical(input: TokenStream) -> TokenStream { generics.type_params_mut().for_each(|type_param| { type_param .bounds - .push(parse_quote!(binius_field::DeserializeCanonical)) + .push(parse_quote!(binius_utils::DeserializeBytes)) }); let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); let deserialize_value = quote! { - binius_field::DeserializeCanonical::deserialize_canonical(&mut read_buf)? + binius_utils::DeserializeBytes::deserialize(&mut read_buf, mode)? }; let body = match input.data { Data::Union(_) => syn::Error::new(span, "Unions are not supported").into_compile_error(), @@ -255,7 +256,7 @@ pub fn derive_deserialize_canonical(input: TokenStream) -> TokenStream { Ok(match variant_index { #(#variants,)* _ => { - return Err(binius_field::serialization::Error::UnknownEnumVariant { + return Err(binius_utils::SerializationError::UnknownEnumVariant { name: #name, index: variant_index }) @@ -265,8 +266,8 @@ pub fn derive_deserialize_canonical(input: TokenStream) -> TokenStream { } }; quote! { - impl #impl_generics binius_field::DeserializeCanonical for #name #ty_generics #where_clause { - fn deserialize_canonical(mut read_buf: impl binius_field::bytes::Buf) -> Result + impl #impl_generics binius_utils::DeserializeBytes for #name #ty_generics #where_clause { + fn deserialize(mut read_buf: impl binius_utils::bytes::Buf, mode: binius_utils::SerializationMode) -> Result where Self: Sized { @@ -277,36 +278,33 @@ pub fn derive_deserialize_canonical(input: TokenStream) -> TokenStream { .into() } -/// Use on an impl block for MultivariatePoly, to automatically implement erased_serialize_canonical. +/// Use on an impl block for MultivariatePoly, to automatically implement erased_serialize_bytes. /// /// Importantly, this will serialize the concrete instance, prefixed by the identifier of the data type. /// /// This prefix can be used to figure out which concrete data type it should use for deserialization later. #[proc_macro_attribute] -pub fn erased_serialize_canonical(_attr: TokenStream, item: TokenStream) -> TokenStream { +pub fn erased_serialize_bytes(_attr: TokenStream, item: TokenStream) -> TokenStream { let mut item_impl: ItemImpl = parse_macro_input!(item); let syn::Type::Path(p) = &*item_impl.self_ty else { return syn::Error::new( item_impl.span(), - "#[erased_serialize_canonical] can only be used on an impl for a concrete type", + "#[erased_serialize_bytes] can only be used on an impl for a concrete type", ) .into_compile_error() .into(); }; let name = p.path.segments.last().unwrap().ident.to_string(); - - let method = parse_quote! { - fn erased_serialize_canonical( + item_impl.items.push(syn::ImplItem::Fn(parse_quote! { + fn erased_serialize( &self, - write_buf: &mut dyn binius_field::bytes::BufMut, - ) -> Result<(), binius_field::serialization::Error> { - binius_field::SerializeCanonical::serialize_canonical(&#name, &mut *write_buf)?; - binius_field::SerializeCanonical::serialize_canonical(self, &mut *write_buf) + write_buf: &mut dyn binius_utils::bytes::BufMut, + mode: binius_utils::SerializationMode, + ) -> Result<(), binius_utils::SerializationError> { + binius_utils::SerializeBytes::serialize(&#name, &mut *write_buf, mode)?; + binius_utils::SerializeBytes::serialize(self, &mut *write_buf, mode) } - }; - - item_impl.items.push(syn::ImplItem::Fn(method)); - + })); quote! { #item_impl } diff --git a/crates/math/src/arith_expr.rs b/crates/math/src/arith_expr.rs index c5cb0bbab..84cb05cc7 100644 --- a/crates/math/src/arith_expr.rs +++ b/crates/math/src/arith_expr.rs @@ -7,8 +7,8 @@ use std::{ ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign}, }; -use binius_field::{DeserializeCanonical, Field, PackedField, SerializeCanonical, TowerField}; -use binius_macros::{DeserializeCanonical, SerializeCanonical}; +use binius_field::{Field, PackedField, TowerField}; +use binius_macros::{DeserializeBytes, SerializeBytes}; use super::error::Error; @@ -17,7 +17,7 @@ use super::error::Error; /// Arithmetic expressions are trees, where the leaves are either constants or variables, and the /// non-leaf nodes are arithmetic operations, such as addition, multiplication, etc. They are /// specific representations of multivariate polynomials. -#[derive(Debug, Clone, PartialEq, Eq, SerializeCanonical, DeserializeCanonical)] +#[derive(Debug, Clone, PartialEq, Eq, SerializeBytes, DeserializeBytes)] pub enum ArithExpr { Const(F), Var(usize), @@ -26,26 +26,6 @@ pub enum ArithExpr { Pow(Box>, u64), } -impl SerializeCanonical for Box> { - fn serialize_canonical( - &self, - write_buf: impl binius_field::bytes::BufMut, - ) -> Result<(), binius_field::serialization::Error> { - ArithExpr::::serialize_canonical(&self.to_owned(), write_buf) - } -} - -impl DeserializeCanonical for Box> { - fn deserialize_canonical( - read_buf: impl binius_field::bytes::Buf, - ) -> Result - where - Self: Sized, - { - Ok(Self::new(ArithExpr::::deserialize_canonical(read_buf)?)) - } -} - impl Display for ArithExpr { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { diff --git a/crates/utils/Cargo.toml b/crates/utils/Cargo.toml index 5bdc5ad42..c0e418404 100644 --- a/crates/utils/Cargo.toml +++ b/crates/utils/Cargo.toml @@ -8,8 +8,10 @@ authors.workspace = true workspace = true [dependencies] +auto_impl.workspace = true binius_maybe_rayon = { path = "../maybe_rayon", default-features = false } bytemuck = { workspace = true, features = ["extern_crate_alloc"] } +bytes.workspace = true cfg-if.workspace = true generic-array.workspace = true itertools.workspace = true diff --git a/crates/utils/src/lib.rs b/crates/utils/src/lib.rs index 70606d3f0..3ac565f0b 100644 --- a/crates/utils/src/lib.rs +++ b/crates/utils/src/lib.rs @@ -13,6 +13,10 @@ pub mod felts; pub mod graph; pub mod iter; pub mod rayon; +pub mod serialization; pub mod sorting; pub mod sparse_index; pub mod thread_local_mut; + +pub use bytes; +pub use serialization::{DeserializeBytes, SerializationError, SerializationMode, SerializeBytes}; diff --git a/crates/utils/src/serialization.rs b/crates/utils/src/serialization.rs new file mode 100644 index 000000000..a6782b367 --- /dev/null +++ b/crates/utils/src/serialization.rs @@ -0,0 +1,437 @@ +// Copyright 2024-2025 Irreducible Inc. + +use auto_impl::auto_impl; +use bytes::{Buf, BufMut}; +use thiserror::Error; + +/// Serialize data according to Mode param +#[auto_impl(Box, &)] +pub trait SerializeBytes { + fn serialize( + &self, + write_buf: impl BufMut, + mode: SerializationMode, + ) -> Result<(), SerializationError>; +} + +/// Deserialize data according to Mode param +pub trait DeserializeBytes { + fn deserialize(read_buf: impl Buf, mode: SerializationMode) -> Result + where + Self: Sized; +} + +/// Specifies serialization/deserialization behavior +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SerializationMode { + /// This mode is faster, and serializes to the underlying bytes + Native, + /// Will first convert any tower fields into the Fan-Paar field equivalent + CanonicalTower, +} + +#[derive(Error, Debug, Clone)] +pub enum SerializationError { + #[error("Write buffer is full")] + WriteBufferFull, + #[error("Not enough data in read buffer to deserialize")] + NotEnoughBytes, + #[error("Unknown enum variant index {name}::{index}")] + UnknownEnumVariant { name: &'static str, index: u8 }, + #[error("Serialization has not been implemented")] + SerializationNotImplemented, + #[error("Deserializer has not been implemented")] + DeserializerNotImplented, + #[error("Multiple deserializers with the same name {name} has been registered")] + DeserializerNameConflict { name: String }, + #[error("FromUtf8Error: {0}")] + FromUtf8Error(#[from] std::string::FromUtf8Error), +} + +// Copyright 2025 Irreducible Inc. + +use generic_array::{ArrayLength, GenericArray}; + +impl DeserializeBytes for Box { + fn deserialize(read_buf: impl Buf, mode: SerializationMode) -> Result + where + Self: Sized, + { + Ok(Self::new(T::deserialize(read_buf, mode)?)) + } +} + +impl SerializeBytes for usize { + fn serialize( + &self, + mut write_buf: impl BufMut, + mode: SerializationMode, + ) -> Result<(), SerializationError> { + SerializeBytes::serialize(&(*self as u64), &mut write_buf, mode) + } +} + +impl DeserializeBytes for usize { + fn deserialize( + mut read_buf: impl Buf, + mode: SerializationMode, + ) -> Result + where + Self: Sized, + { + let value: u64 = DeserializeBytes::deserialize(&mut read_buf, mode)?; + Ok(value as Self) + } +} + +impl SerializeBytes for u128 { + fn serialize( + &self, + mut write_buf: impl BufMut, + _mode: SerializationMode, + ) -> Result<(), SerializationError> { + assert_enough_space_for(&write_buf, std::mem::size_of::())?; + write_buf.put_u128_le(*self); + Ok(()) + } +} + +impl DeserializeBytes for u128 { + fn deserialize( + mut read_buf: impl Buf, + _mode: SerializationMode, + ) -> Result + where + Self: Sized, + { + assert_enough_data_for(&read_buf, std::mem::size_of::())?; + Ok(read_buf.get_u128_le()) + } +} + +impl SerializeBytes for u64 { + fn serialize( + &self, + mut write_buf: impl BufMut, + _mode: SerializationMode, + ) -> Result<(), SerializationError> { + assert_enough_space_for(&write_buf, std::mem::size_of::())?; + write_buf.put_u64_le(*self); + Ok(()) + } +} + +impl DeserializeBytes for u64 { + fn deserialize( + mut read_buf: impl Buf, + _mode: SerializationMode, + ) -> Result + where + Self: Sized, + { + assert_enough_data_for(&read_buf, std::mem::size_of::())?; + Ok(read_buf.get_u64_le()) + } +} + +impl SerializeBytes for u32 { + fn serialize( + &self, + mut write_buf: impl BufMut, + _mode: SerializationMode, + ) -> Result<(), SerializationError> { + assert_enough_space_for(&write_buf, std::mem::size_of::())?; + write_buf.put_u32_le(*self); + Ok(()) + } +} + +impl DeserializeBytes for u32 { + fn deserialize( + mut read_buf: impl Buf, + _mode: SerializationMode, + ) -> Result + where + Self: Sized, + { + assert_enough_data_for(&read_buf, std::mem::size_of::())?; + Ok(read_buf.get_u32_le()) + } +} + +impl SerializeBytes for u16 { + fn serialize( + &self, + mut write_buf: impl BufMut, + _mode: SerializationMode, + ) -> Result<(), SerializationError> { + assert_enough_space_for(&write_buf, std::mem::size_of::())?; + write_buf.put_u16_le(*self); + Ok(()) + } +} + +impl DeserializeBytes for u16 { + fn deserialize( + mut read_buf: impl Buf, + _mode: SerializationMode, + ) -> Result + where + Self: Sized, + { + assert_enough_data_for(&read_buf, std::mem::size_of::())?; + Ok(read_buf.get_u16_le()) + } +} + +impl SerializeBytes for u8 { + fn serialize( + &self, + mut write_buf: impl BufMut, + _mode: SerializationMode, + ) -> Result<(), SerializationError> { + assert_enough_space_for(&write_buf, std::mem::size_of::())?; + write_buf.put_u8(*self); + Ok(()) + } +} + +impl DeserializeBytes for u8 { + fn deserialize( + mut read_buf: impl Buf, + _mode: SerializationMode, + ) -> Result + where + Self: Sized, + { + assert_enough_data_for(&read_buf, std::mem::size_of::())?; + Ok(read_buf.get_u8()) + } +} + +impl SerializeBytes for bool { + fn serialize( + &self, + write_buf: impl BufMut, + mode: SerializationMode, + ) -> Result<(), SerializationError> { + u8::serialize(&(*self as u8), write_buf, mode) + } +} + +impl DeserializeBytes for bool { + fn deserialize(read_buf: impl Buf, mode: SerializationMode) -> Result + where + Self: Sized, + { + Ok(u8::deserialize(read_buf, mode)? != 0) + } +} + +impl SerializeBytes for std::marker::PhantomData { + fn serialize( + &self, + _write_buf: impl BufMut, + _mode: SerializationMode, + ) -> Result<(), SerializationError> { + Ok(()) + } +} + +impl DeserializeBytes for std::marker::PhantomData { + fn deserialize( + _read_buf: impl Buf, + _mode: SerializationMode, + ) -> Result + where + Self: Sized, + { + Ok(Self) + } +} + +impl SerializeBytes for &str { + fn serialize( + &self, + mut write_buf: impl BufMut, + mode: SerializationMode, + ) -> Result<(), SerializationError> { + let bytes = self.as_bytes(); + SerializeBytes::serialize(&bytes.len(), &mut write_buf, mode)?; + assert_enough_space_for(&write_buf, bytes.len())?; + write_buf.put_slice(bytes); + Ok(()) + } +} + +impl SerializeBytes for String { + fn serialize( + &self, + mut write_buf: impl BufMut, + mode: SerializationMode, + ) -> Result<(), SerializationError> { + SerializeBytes::serialize(&self.as_str(), &mut write_buf, mode) + } +} + +impl DeserializeBytes for String { + fn deserialize( + mut read_buf: impl Buf, + mode: SerializationMode, + ) -> Result + where + Self: Sized, + { + let len = DeserializeBytes::deserialize(&mut read_buf, mode)?; + assert_enough_data_for(&read_buf, len)?; + Ok(Self::from_utf8(read_buf.copy_to_bytes(len).to_vec())?) + } +} + +impl SerializeBytes for Vec { + fn serialize( + &self, + mut write_buf: impl BufMut, + mode: SerializationMode, + ) -> Result<(), SerializationError> { + SerializeBytes::serialize(&self.len(), &mut write_buf, mode)?; + self.iter() + .try_for_each(|item| SerializeBytes::serialize(item, &mut write_buf, mode)) + } +} + +impl DeserializeBytes for Vec { + fn deserialize( + mut read_buf: impl Buf, + mode: SerializationMode, + ) -> Result + where + Self: Sized, + { + let len: usize = DeserializeBytes::deserialize(&mut read_buf, mode)?; + (0..len) + .map(|_| DeserializeBytes::deserialize(&mut read_buf, mode)) + .collect() + } +} + +impl SerializeBytes for Option { + fn serialize( + &self, + mut write_buf: impl BufMut, + mode: SerializationMode, + ) -> Result<(), SerializationError> { + match self { + Some(value) => { + SerializeBytes::serialize(&true, &mut write_buf, mode)?; + SerializeBytes::serialize(value, &mut write_buf, mode)?; + } + None => { + SerializeBytes::serialize(&false, write_buf, mode)?; + } + } + Ok(()) + } +} + +impl DeserializeBytes for Option { + fn deserialize( + mut read_buf: impl Buf, + mode: SerializationMode, + ) -> Result + where + Self: Sized, + { + Ok(match bool::deserialize(&mut read_buf, mode)? { + true => Some(T::deserialize(&mut read_buf, mode)?), + false => None, + }) + } +} + +impl SerializeBytes for (U, V) { + fn serialize( + &self, + mut write_buf: impl BufMut, + mode: SerializationMode, + ) -> Result<(), SerializationError> { + U::serialize(&self.0, &mut write_buf, mode)?; + V::serialize(&self.1, write_buf, mode) + } +} + +impl DeserializeBytes for (U, V) { + fn deserialize( + mut read_buf: impl Buf, + mode: SerializationMode, + ) -> Result + where + Self: Sized, + { + Ok((U::deserialize(&mut read_buf, mode)?, V::deserialize(read_buf, mode)?)) + } +} + +impl> SerializeBytes for GenericArray { + fn serialize( + &self, + mut write_buf: impl BufMut, + _mode: SerializationMode, + ) -> Result<(), SerializationError> { + assert_enough_space_for(&write_buf, N::USIZE)?; + write_buf.put_slice(self); + Ok(()) + } +} + +impl> DeserializeBytes for GenericArray { + fn deserialize( + mut read_buf: impl Buf, + _mode: SerializationMode, + ) -> Result { + assert_enough_data_for(&read_buf, N::USIZE)?; + let mut ret = Self::default(); + read_buf.copy_to_slice(&mut ret); + Ok(ret) + } +} + +#[inline] +fn assert_enough_space_for(write_buf: &impl BufMut, size: usize) -> Result<(), SerializationError> { + if write_buf.remaining_mut() < size { + return Err(SerializationError::WriteBufferFull); + } + Ok(()) +} + +#[inline] +fn assert_enough_data_for(read_buf: &impl Buf, size: usize) -> Result<(), SerializationError> { + if read_buf.remaining() < size { + return Err(SerializationError::NotEnoughBytes); + } + Ok(()) +} + +#[cfg(test)] +mod tests { + use generic_array::typenum::U32; + use rand::{rngs::StdRng, RngCore, SeedableRng}; + + use super::*; + + #[test] + fn test_generic_array_serialize_deserialize() { + let mut rng = StdRng::seed_from_u64(0); + + let mut data = GenericArray::::default(); + rng.fill_bytes(&mut data); + + let mut buf = Vec::new(); + data.serialize(&mut buf, SerializationMode::Native).unwrap(); + + let data_deserialized = + GenericArray::::deserialize(&mut buf.as_slice(), SerializationMode::Native) + .unwrap(); + assert_eq!(data_deserialized, data); + } +} From bfd5d47ad3b67b9003d083961d254ead704850c7 Mon Sep 17 00:00:00 2001 From: Aliaksei Dziadziuk Date: Tue, 18 Feb 2025 12:56:26 +0100 Subject: [PATCH 36/50] [gkr_int_mul] Fix type bounds (#34) --- .../protocols/gkr_gpa/gpa_sumcheck/prove.rs | 4 +-- .../core/src/protocols/gkr_int_mul/error.rs | 8 +++++ .../gkr_int_mul/generator_exponent/prove.rs | 20 ++++++------ .../gkr_int_mul/generator_exponent/tests.rs | 2 +- .../gkr_int_mul/generator_exponent/verify.rs | 5 ++- .../gkr_int_mul/generator_exponent/witness.rs | 31 +++++++++---------- crates/core/src/protocols/gkr_int_mul/mod.rs | 2 +- 7 files changed, 38 insertions(+), 34 deletions(-) diff --git a/crates/core/src/protocols/gkr_gpa/gpa_sumcheck/prove.rs b/crates/core/src/protocols/gkr_gpa/gpa_sumcheck/prove.rs index 90bb03462..b7997ef63 100644 --- a/crates/core/src/protocols/gkr_gpa/gpa_sumcheck/prove.rs +++ b/crates/core/src/protocols/gkr_gpa/gpa_sumcheck/prove.rs @@ -48,9 +48,7 @@ impl<'a, F, FDomain, P, Composition, M, Backend> GPAProver<'a, FDomain, P, Compo where F: Field, FDomain: Field, - P: PackedFieldIndexable - + PackedExtension - + PackedExtension, + P: PackedFieldIndexable + PackedExtension, Composition: CompositionPolyOS

, M: MultilinearPoly

+ Send + Sync, Backend: ComputationBackend, diff --git a/crates/core/src/protocols/gkr_int_mul/error.rs b/crates/core/src/protocols/gkr_int_mul/error.rs index caa824f00..6da9cf33d 100644 --- a/crates/core/src/protocols/gkr_int_mul/error.rs +++ b/crates/core/src/protocols/gkr_int_mul/error.rs @@ -13,4 +13,12 @@ pub enum Error { SumcheckError(#[from] SumcheckError), #[error("polynomial error: {0}")] Polynomial(#[from] PolynomialError), + #[error("verification failure: {0}")] + Verification(#[from] VerificationError), +} + +#[derive(Debug, thiserror::Error)] +pub enum VerificationError { + #[error("the proof contains an incorrect evaluation of the eq indicator")] + IncorrectEqIndEvaluation, } diff --git a/crates/core/src/protocols/gkr_int_mul/generator_exponent/prove.rs b/crates/core/src/protocols/gkr_int_mul/generator_exponent/prove.rs index 070d0637e..f6c0ae388 100644 --- a/crates/core/src/protocols/gkr_int_mul/generator_exponent/prove.rs +++ b/crates/core/src/protocols/gkr_int_mul/generator_exponent/prove.rs @@ -3,7 +3,7 @@ use std::array; use binius_field::{ - BinaryField, ExtensionField, Field, PackedExtension, PackedField, PackedFieldIndexable, + BinaryField1b, ExtensionField, Field, PackedExtension, PackedField, PackedFieldIndexable, TowerField, }; use binius_hal::ComputationBackend; @@ -40,18 +40,17 @@ pub fn prove< backend: &Backend, ) -> Result, Error> where + F: ExtensionField + ExtensionField + TowerField, FDomain: Field, - PBits: PackedField, - PGenerator: PackedExtension - + PackedFieldIndexable + FGenerator: TowerField + ExtensionField + ExtensionField, + PBits: PackedField, + PGenerator: PackedField + + PackedExtension + PackedExtension, - PGenerator::Scalar: ExtensionField + ExtensionField, - PChallenge: PackedField - + PackedFieldIndexable + PChallenge: PackedFieldIndexable + + PackedExtension + PackedExtension + PackedExtension, - F: ExtensionField + ExtensionField + BinaryField + TowerField, - FGenerator: Field + TowerField, Backend: ComputationBackend, Challenger_: Challenger, { @@ -60,10 +59,11 @@ where let mut eval_point = claim.eval_point.clone(); let mut eval = claim.eval; + for exponent_bit_number in (1..EXPONENT_BIT_WIDTH).rev() { let this_round_exponent_bit = witness.exponent[exponent_bit_number].clone(); let this_round_generator_power_constant = - F::from(FGenerator::MULTIPLICATIVE_GENERATOR.pow([1 << exponent_bit_number])); + F::from(FGenerator::MULTIPLICATIVE_GENERATOR.pow(1 << exponent_bit_number)); let this_round_input_data = witness.single_bit_output_layers_data[exponent_bit_number - 1].clone(); diff --git a/crates/core/src/protocols/gkr_int_mul/generator_exponent/tests.rs b/crates/core/src/protocols/gkr_int_mul/generator_exponent/tests.rs index 97508b59f..421b1efd7 100644 --- a/crates/core/src/protocols/gkr_int_mul/generator_exponent/tests.rs +++ b/crates/core/src/protocols/gkr_int_mul/generator_exponent/tests.rs @@ -124,7 +124,7 @@ fn witness_gen_happens_correctly() { for (row_idx, this_row_exponent) in exponent.into_iter().enumerate() { assert_eq!( ::Scalar::MULTIPLICATIVE_GENERATOR - .pow([this_row_exponent as u64]), + .pow(this_row_exponent as u64), get_packed_slice(results, row_idx) ); } diff --git a/crates/core/src/protocols/gkr_int_mul/generator_exponent/verify.rs b/crates/core/src/protocols/gkr_int_mul/generator_exponent/verify.rs index 1024481ab..66df9910e 100644 --- a/crates/core/src/protocols/gkr_int_mul/generator_exponent/verify.rs +++ b/crates/core/src/protocols/gkr_int_mul/generator_exponent/verify.rs @@ -3,7 +3,6 @@ use std::array; use binius_field::{ExtensionField, TowerField}; -use binius_utils::bail; use super::{ super::error::Error, common::GeneratorExponentReductionOutput, utils::first_layer_inverse, @@ -13,7 +12,7 @@ use crate::{ polynomial::MultivariatePoly, protocols::{ gkr_gpa::LayerClaim, - gkr_int_mul::generator_exponent::compositions::MultiplyOrDont, + gkr_int_mul::{error::VerificationError, generator_exponent::compositions::MultiplyOrDont}, sumcheck::{self, zerocheck::ExtraProduct, CompositeSumClaim, SumcheckClaim}, }, transcript::VerifierTranscript, @@ -78,7 +77,7 @@ where EqIndPartialEval::new(log_size, sumcheck_query_point.clone())?.evaluate(&eval_point)?; if sumcheck_verification_output.multilinear_evals[0][2] != eq_eval { - bail!(Error::EqEvalDoesntVerify) + return Err(VerificationError::IncorrectEqIndEvaluation.into()); } eval_claims_on_bit_columns[exponent_bit_number] = LayerClaim { diff --git a/crates/core/src/protocols/gkr_int_mul/generator_exponent/witness.rs b/crates/core/src/protocols/gkr_int_mul/generator_exponent/witness.rs index 885c47f9e..569c8f491 100644 --- a/crates/core/src/protocols/gkr_int_mul/generator_exponent/witness.rs +++ b/crates/core/src/protocols/gkr_int_mul/generator_exponent/witness.rs @@ -3,8 +3,7 @@ use std::{array, cmp::min, slice}; use binius_field::{ - ext_base_op_par, BinaryField, BinaryField1b, ExtensionField, Field, PackedExtension, - PackedField, PackedFieldIndexable, + ext_base_op_par, BinaryField, BinaryField1b, ExtensionField, PackedExtension, PackedField, }; use binius_maybe_rayon::{ prelude::{IndexedParallelIterator, ParallelIterator}, @@ -28,7 +27,7 @@ pub struct GeneratorExponentWitness< fn copy_witness_into_vec(poly: &MultilinearWitness) -> Vec

where P: PackedField, - PE: PackedField + PackedExtension, + PE: PackedExtension, PE::Scalar: ExtensionField, { let mut input_layer: Vec

= zeroed_vec(1 << poly.n_vars().saturating_sub(P::LOG_WIDTH)); @@ -68,9 +67,8 @@ fn evaluate_single_bit_output_packed( ) -> Vec where PBits: PackedField, - PGenerator: - PackedField + PackedFieldIndexable + PackedExtension, - PGenerator::Scalar: ExtensionField + BinaryField, + PGenerator: PackedExtension, + PGenerator::Scalar: BinaryField, { debug_assert_eq!( PBits::WIDTH * exponent_bit.len(), @@ -98,9 +96,8 @@ fn evaluate_first_layer_output_packed( ) -> Vec where PBits: PackedField, - PGenerator: - PackedField + PackedFieldIndexable + PackedExtension, - PGenerator::Scalar: ExtensionField, + PGenerator: PackedExtension, + PGenerator::Scalar: BinaryField, { let mut result = vec![PGenerator::zero(); exponent_bit.len() * PGenerator::Scalar::DEGREE]; @@ -119,11 +116,10 @@ impl<'a, PBits, PGenerator, PChallenge, const EXPONENT_BIT_WIDTH: usize> GeneratorExponentWitness<'a, PBits, PGenerator, PChallenge, EXPONENT_BIT_WIDTH> where PBits: PackedField, - PGenerator: - PackedField + PackedFieldIndexable + PackedExtension, - PGenerator::Scalar: ExtensionField + BinaryField, - PChallenge: PackedField + PackedExtension, - PChallenge::Scalar: ExtensionField, + PGenerator: PackedExtension, + PGenerator::Scalar: BinaryField, + PChallenge: PackedExtension, + PChallenge::Scalar: BinaryField, { pub fn new( exponent: [MultilinearWitness<'a, PChallenge>; EXPONENT_BIT_WIDTH], @@ -140,12 +136,15 @@ where PGenerator::Scalar::MULTIPLICATIVE_GENERATOR, ); + let mut generator_power_constant = PGenerator::Scalar::MULTIPLICATIVE_GENERATOR.square(); + for layer_idx_from_left in 1..EXPONENT_BIT_WIDTH { single_bit_output_layers_data[layer_idx_from_left] = evaluate_single_bit_output_packed( &exponent_data[layer_idx_from_left], - PGenerator::Scalar::MULTIPLICATIVE_GENERATOR.pow([1 << layer_idx_from_left]), + generator_power_constant, &single_bit_output_layers_data[layer_idx_from_left - 1], - ) + ); + generator_power_constant = generator_power_constant.square(); } Ok(Self { diff --git a/crates/core/src/protocols/gkr_int_mul/mod.rs b/crates/core/src/protocols/gkr_int_mul/mod.rs index 1e8861bff..02f0dbf2b 100644 --- a/crates/core/src/protocols/gkr_int_mul/mod.rs +++ b/crates/core/src/protocols/gkr_int_mul/mod.rs @@ -1,4 +1,4 @@ // Copyright 2024-2025 Irreducible Inc. mod error; -//pub mod generator_exponent; +pub mod generator_exponent; From a3fed422e92d8a0d4a57fb7999430f3a412094fe Mon Sep 17 00:00:00 2001 From: Artem Storozhuk Date: Wed, 19 Feb 2025 17:02:10 +0200 Subject: [PATCH 37/50] feat: Blake3 G function gadget (#16) --- crates/circuits/src/blake3.rs | 209 +++++++++++++++++++++++++++ crates/circuits/src/lib.rs | 1 + crates/circuits/src/unconstrained.rs | 28 ++++ 3 files changed, 238 insertions(+) create mode 100644 crates/circuits/src/blake3.rs diff --git a/crates/circuits/src/blake3.rs b/crates/circuits/src/blake3.rs new file mode 100644 index 000000000..4d5325587 --- /dev/null +++ b/crates/circuits/src/blake3.rs @@ -0,0 +1,209 @@ +// Copyright 2024-2025 Irreducible Inc. + +use binius_core::oracle::{OracleId, ShiftVariant}; +use binius_field::{BinaryField1b, Field}; +use binius_utils::checked_arithmetics::checked_log_2; + +use crate::{ + arithmetic, + arithmetic::Flags, + builder::{types::F, ConstraintSystemBuilder}, +}; + +type F1 = BinaryField1b; +const LOG_U32_BITS: usize = checked_log_2(32); + +// Gadget that performs two u32 variables XOR and then rotates the result +fn xor_rotate_right( + builder: &mut ConstraintSystemBuilder, + name: impl ToString, + log_size: usize, + a: OracleId, + b: OracleId, + rotate_right_offset: u32, +) -> Result { + assert!(rotate_right_offset <= 32); + + builder.push_namespace(name); + + let xor = builder + .add_linear_combination("xor", log_size, [(a, F::ONE), (b, F::ONE)]) + .unwrap(); + + let rotate = builder.add_shifted( + "rotate", + xor, + 32 - rotate_right_offset as usize, + LOG_U32_BITS, + ShiftVariant::CircularLeft, + )?; + + if let Some(witness) = builder.witness() { + let a_value = witness.get::(a)?.as_slice::(); + let b_value = witness.get::(b)?.as_slice::(); + + let mut xor_witness = witness.new_column::(xor); + let xor_value = xor_witness.as_mut_slice::(); + + for (idx, v) in xor_value.iter_mut().enumerate() { + *v = a_value[idx] ^ b_value[idx]; + } + + let mut rotate_witness = witness.new_column::(rotate); + let rotate_value = rotate_witness.as_mut_slice::(); + for (idx, v) in rotate_value.iter_mut().enumerate() { + *v = xor_value[idx].rotate_right(rotate_right_offset); + } + } + + builder.pop_namespace(); + + Ok(rotate) +} + +#[allow(clippy::too_many_arguments)] +pub fn blake3_g( + builder: &mut ConstraintSystemBuilder, + name: impl ToString, + a_in: OracleId, + b_in: OracleId, + c_in: OracleId, + d_in: OracleId, + mx: OracleId, + my: OracleId, + log_size: usize, +) -> Result<[OracleId; 4], anyhow::Error> { + builder.push_namespace(name); + + let ab = arithmetic::u32::add(builder, "a_in + b_in", a_in, b_in, Flags::Unchecked)?; + let a1 = arithmetic::u32::add(builder, "a_in + b_in + mx", ab, mx, Flags::Unchecked)?; + + let d1 = xor_rotate_right(builder, "(d_in ^ a1).rotate_right(16)", log_size, d_in, a1, 16u32)?; + + let c1 = arithmetic::u32::add(builder, "c_in + d1", c_in, d1, Flags::Unchecked)?; + + let b1 = xor_rotate_right(builder, "(b_in ^ c1).rotate_right(12)", log_size, b_in, c1, 12u32)?; + + let a1b1 = arithmetic::u32::add(builder, "a1 + b1", a1, b1, Flags::Unchecked)?; + let a2 = arithmetic::u32::add(builder, "a1 + b1 + my_in", a1b1, my, Flags::Unchecked)?; + + let d2 = xor_rotate_right(builder, "(d1 ^ a2).rotate_right(8)", log_size, d1, a2, 8u32)?; + + let c2 = arithmetic::u32::add(builder, "c1 + d2", c1, d2, Flags::Unchecked)?; + + let b2 = xor_rotate_right(builder, "(b1 ^ c2).rotate_right(7)", log_size, b1, c2, 7u32)?; + + builder.pop_namespace(); + + Ok([a2, b2, c2, d2]) +} + +#[cfg(test)] +mod tests { + use binius_core::constraint_system::validate::validate_witness; + use binius_field::BinaryField1b; + use binius_maybe_rayon::prelude::*; + + use crate::{ + blake3::blake3_g, + builder::ConstraintSystemBuilder, + unconstrained::{fixed_u32, unconstrained}, + }; + + type F1 = BinaryField1b; + + const LOG_SIZE: usize = 5; + + // The Blake3 mixing function, G, which mixes either a column or a diagonal. + // https://github.com/BLAKE3-team/BLAKE3/blob/master/reference_impl/reference_impl.rs + const fn g( + a_in: u32, + b_in: u32, + c_in: u32, + d_in: u32, + mx: u32, + my: u32, + ) -> (u32, u32, u32, u32) { + let a1 = a_in.wrapping_add(b_in).wrapping_add(mx); + let d1 = (d_in ^ a1).rotate_right(16); + let c1 = c_in.wrapping_add(d1); + let b1 = (b_in ^ c1).rotate_right(12); + + let a2 = a1.wrapping_add(b1).wrapping_add(my); + let d2 = (d1 ^ a2).rotate_right(8); + let c2 = c1.wrapping_add(d2); + let b2 = (b1 ^ c2).rotate_right(7); + + (a2, b2, c2, d2) + } + + #[test] + fn test_vector() { + // Let's use some fixed data input to check that our in-circuit computation + // produces same output as out-of-circuit one + let a = 0xaaaaaaaau32; + let b = 0xbbbbbbbbu32; + let c = 0xccccccccu32; + let d = 0xddddddddu32; + let mx = 0xffff00ffu32; + let my = 0xff00ffffu32; + + let (expected_0, expected_1, expected_2, expected_3) = g(a, b, c, d, mx, my); + + let size = 1 << LOG_SIZE; + + let allocator = bumpalo::Bump::new(); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); + + let a_in = fixed_u32::(&mut builder, "a", LOG_SIZE, vec![a; size]).unwrap(); + let b_in = fixed_u32::(&mut builder, "b", LOG_SIZE, vec![b; size]).unwrap(); + let c_in = fixed_u32::(&mut builder, "c", LOG_SIZE, vec![c; size]).unwrap(); + let d_in = fixed_u32::(&mut builder, "d", LOG_SIZE, vec![d; size]).unwrap(); + let mx_in = fixed_u32::(&mut builder, "mx", LOG_SIZE, vec![mx; size]).unwrap(); + let my_in = fixed_u32::(&mut builder, "my", LOG_SIZE, vec![my; size]).unwrap(); + + let output = + blake3_g(&mut builder, "g", a_in, b_in, c_in, d_in, mx_in, my_in, LOG_SIZE).unwrap(); + + if let Some(witness) = builder.witness() { + ( + witness.get::(output[0]).unwrap().as_slice::(), + witness.get::(output[1]).unwrap().as_slice::(), + witness.get::(output[2]).unwrap().as_slice::(), + witness.get::(output[3]).unwrap().as_slice::(), + ) + .into_par_iter() + .for_each(|(actual_0, actual_1, actual_2, actual_3)| { + assert_eq!(*actual_0, expected_0); + assert_eq!(*actual_1, expected_1); + assert_eq!(*actual_2, expected_2); + assert_eq!(*actual_3, expected_3); + }); + } + + let witness = builder.take_witness().unwrap(); + let constraints_system = builder.build().unwrap(); + + validate_witness(&constraints_system, &[], &witness).unwrap(); + } + + #[test] + fn test_random_input() { + let allocator = bumpalo::Bump::new(); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); + + let a_in = unconstrained::(&mut builder, "a", LOG_SIZE).unwrap(); + let b_in = unconstrained::(&mut builder, "b", LOG_SIZE).unwrap(); + let c_in = unconstrained::(&mut builder, "c", LOG_SIZE).unwrap(); + let d_in = unconstrained::(&mut builder, "d", LOG_SIZE).unwrap(); + let mx_in = unconstrained::(&mut builder, "mx", LOG_SIZE).unwrap(); + let my_in = unconstrained::(&mut builder, "my", LOG_SIZE).unwrap(); + + blake3_g(&mut builder, "g", a_in, b_in, c_in, d_in, mx_in, my_in, LOG_SIZE).unwrap(); + + let witness = builder.take_witness().unwrap(); + let constraints_system = builder.build().unwrap(); + + validate_witness(&constraints_system, &[], &witness).unwrap(); + } +} diff --git a/crates/circuits/src/lib.rs b/crates/circuits/src/lib.rs index 06c0de27a..59ae8195d 100644 --- a/crates/circuits/src/lib.rs +++ b/crates/circuits/src/lib.rs @@ -11,6 +11,7 @@ pub mod arithmetic; pub mod bitwise; +pub mod blake3; pub mod builder; pub mod collatz; pub mod keccakf; diff --git a/crates/circuits/src/unconstrained.rs b/crates/circuits/src/unconstrained.rs index 1e4da1b06..f39f48cda 100644 --- a/crates/circuits/src/unconstrained.rs +++ b/crates/circuits/src/unconstrained.rs @@ -34,3 +34,31 @@ where Ok(rng) } + +// Same as 'unconstrained' but uses some pre-defined values instead of a random ones +pub fn fixed_u32( + builder: &mut ConstraintSystemBuilder, + name: impl ToString, + log_size: usize, + value: Vec, +) -> Result +where + U: PackScalar + Pod, + F: TowerField + ExtensionField, + FS: TowerField, +{ + let rng = builder.add_committed(name, log_size, FS::TOWER_LEVEL); + + if let Some(witness) = builder.witness() { + witness + .new_column::(rng) + .as_mut_slice::() + .into_par_iter() + .zip(value.into_par_iter()) + .for_each(|(data, value)| { + *data = value; + }); + } + + Ok(rng) +} From 5b0cb1a92e6c93e1a7ee236f4daa0b9203fe552a Mon Sep 17 00:00:00 2001 From: Tobias Bergkvist Date: Wed, 19 Feb 2025 18:14:27 +0100 Subject: [PATCH 38/50] [circuits] Add test_circuit helper (#27) --- crates/circuits/src/arithmetic/u32.rs | 74 ++-- crates/circuits/src/bitwise.rs | 26 +- crates/circuits/src/builder/mod.rs | 1 + crates/circuits/src/builder/test_utils.rs | 23 + crates/circuits/src/collatz.rs | 25 +- crates/circuits/src/keccakf.rs | 28 +- .../big_integer_ops/byte_sliced_test_utils.rs | 394 ++++++++---------- .../src/lasso/lookups/u8_arithmetic.rs | 201 ++++----- crates/circuits/src/lasso/sha256.rs | 90 ++-- crates/circuits/src/lasso/u32add.rs | 78 ++-- crates/circuits/src/sha256.rs | 90 ++-- crates/circuits/src/u32fib.rs | 19 +- crates/circuits/src/vision.rs | 25 +- crates/core/src/constraint_system/channel.rs | 4 +- 14 files changed, 485 insertions(+), 593 deletions(-) create mode 100644 crates/circuits/src/builder/test_utils.rs diff --git a/crates/circuits/src/arithmetic/u32.rs b/crates/circuits/src/arithmetic/u32.rs index 49037356d..0dcdd1ae7 100644 --- a/crates/circuits/src/arithmetic/u32.rs +++ b/crates/circuits/src/arithmetic/u32.rs @@ -330,63 +330,47 @@ pub fn constant( #[cfg(test)] mod tests { - use binius_core::constraint_system::validate::validate_witness; use binius_field::{BinaryField1b, TowerField}; - use crate::{arithmetic, builder::ConstraintSystemBuilder, unconstrained::unconstrained}; + use crate::{arithmetic, builder::test_utils::test_circuit, unconstrained::unconstrained}; #[test] fn test_mul_const() { - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); - - let a = builder.add_committed("a", 5, BinaryField1b::TOWER_LEVEL); - if let Some(witness) = builder.witness() { - witness - .new_column::(a) - .as_mut_slice::() - .iter_mut() - .for_each(|v| *v = 0b01000000_00000000_00000000_00000000u32); - } - - let _c = arithmetic::u32::mul_const(&mut builder, "mul3", a, 3, arithmetic::Flags::Checked) - .unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); + test_circuit(|builder| { + let a = builder.add_committed("a", 5, BinaryField1b::TOWER_LEVEL); + if let Some(witness) = builder.witness() { + witness + .new_column::(a) + .as_mut_slice::() + .iter_mut() + .for_each(|v| *v = 0b01000000_00000000_00000000_00000000u32); + } + let _c = arithmetic::u32::mul_const(builder, "mul3", a, 3, arithmetic::Flags::Checked)?; + Ok(vec![]) + }) + .unwrap(); } #[test] fn test_add() { - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); - let log_size = 14; - let a = unconstrained::(&mut builder, "a", log_size).unwrap(); - let b = unconstrained::(&mut builder, "b", log_size).unwrap(); - let _c = arithmetic::u32::add(&mut builder, "u32add", a, b, arithmetic::Flags::Unchecked) - .unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); + test_circuit(|builder| { + let log_size = 14; + let a = unconstrained::(builder, "a", log_size)?; + let b = unconstrained::(builder, "b", log_size)?; + let _c = arithmetic::u32::add(builder, "u32add", a, b, arithmetic::Flags::Unchecked)?; + Ok(vec![]) + }) + .unwrap(); } #[test] fn test_sub() { - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); - - let a = unconstrained::(&mut builder, "a", 7).unwrap(); - let b = unconstrained::(&mut builder, "a", 7).unwrap(); - let _c = - arithmetic::u32::sub(&mut builder, "c", a, b, arithmetic::Flags::Unchecked).unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); + test_circuit(|builder| { + let a = unconstrained::(builder, "a", 7).unwrap(); + let b = unconstrained::(builder, "a", 7).unwrap(); + let _c = arithmetic::u32::sub(builder, "c", a, b, arithmetic::Flags::Unchecked)?; + Ok(vec![]) + }) + .unwrap(); } } diff --git a/crates/circuits/src/bitwise.rs b/crates/circuits/src/bitwise.rs index 5a7113541..42d7cbd77 100644 --- a/crates/circuits/src/bitwise.rs +++ b/crates/circuits/src/bitwise.rs @@ -98,25 +98,21 @@ pub fn or( #[cfg(test)] mod tests { - use binius_core::constraint_system::validate::validate_witness; use binius_field::BinaryField1b; - use crate::{builder::ConstraintSystemBuilder, unconstrained::unconstrained}; + use crate::{builder::test_utils::test_circuit, unconstrained::unconstrained}; #[test] fn test_bitwise() { - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); - let log_size = 6; - let a = unconstrained::(&mut builder, "a", log_size).unwrap(); - let b = unconstrained::(&mut builder, "b", log_size).unwrap(); - let _and = super::and(&mut builder, "and", a, b).unwrap(); - let _xor = super::xor(&mut builder, "xor", a, b).unwrap(); - let _or = super::or(&mut builder, "or", a, b).unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); + test_circuit(|builder| { + let log_size = 6; + let a = unconstrained::(builder, "a", log_size)?; + let b = unconstrained::(builder, "b", log_size)?; + let _and = super::and(builder, "and", a, b)?; + let _xor = super::xor(builder, "xor", a, b)?; + let _or = super::or(builder, "or", a, b)?; + Ok(vec![]) + }) + .unwrap(); } } diff --git a/crates/circuits/src/builder/mod.rs b/crates/circuits/src/builder/mod.rs index 9c706fe3f..7ee629f3c 100644 --- a/crates/circuits/src/builder/mod.rs +++ b/crates/circuits/src/builder/mod.rs @@ -1,6 +1,7 @@ // Copyright 2024-2025 Irreducible Inc. pub mod constraint_system; +pub mod test_utils; pub mod types; pub mod witness; diff --git a/crates/circuits/src/builder/test_utils.rs b/crates/circuits/src/builder/test_utils.rs new file mode 100644 index 000000000..0aa80ad92 --- /dev/null +++ b/crates/circuits/src/builder/test_utils.rs @@ -0,0 +1,23 @@ +// Copyright 2025 Irreducible Inc. + +use binius_core::constraint_system::{channel::Boundary, validate::validate_witness}; + +use super::{types::F, ConstraintSystemBuilder}; + +pub fn test_circuit( + build_circuit: fn(&mut ConstraintSystemBuilder) -> Result>, anyhow::Error>, +) -> Result<(), anyhow::Error> { + let mut verifier_builder = ConstraintSystemBuilder::new(); + let verifier_boundaries = build_circuit(&mut verifier_builder)?; + let verifier_constraint_system = verifier_builder.build()?; + + let allocator = bumpalo::Bump::new(); + let mut prover_builder = ConstraintSystemBuilder::new_with_witness(&allocator); + let prover_boundaries = build_circuit(&mut prover_builder)?; + let prover_witness = prover_builder.take_witness()?; + let _prover_constraint_system = prover_builder.build()?; + + assert_eq!(verifier_boundaries, prover_boundaries); + validate_witness(&verifier_constraint_system, &verifier_boundaries, &prover_witness)?; + Ok(()) +} diff --git a/crates/circuits/src/collatz.rs b/crates/circuits/src/collatz.rs index 04f3b3c11..b76fdf2fd 100644 --- a/crates/circuits/src/collatz.rs +++ b/crates/circuits/src/collatz.rs @@ -193,24 +193,17 @@ pub fn ensure_odd( #[cfg(test)] mod tests { - use binius_core::constraint_system::validate::validate_witness; - - use crate::{builder::ConstraintSystemBuilder, collatz::Collatz}; + use crate::{builder::test_utils::test_circuit, collatz::Collatz}; #[test] fn test_collatz() { - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); - - let x0 = 9999999; - - let mut collatz = Collatz::new(x0); - let advice = collatz.init_prover(); - - let boundaries = collatz.build(&mut builder, advice).unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); + test_circuit(|builder| { + let x0 = 9999999; + let mut collatz = Collatz::new(x0); + let advice = collatz.init_prover(); + let boundaries = collatz.build(builder, advice)?; + Ok(boundaries) + }) + .unwrap(); } } diff --git a/crates/circuits/src/keccakf.rs b/crates/circuits/src/keccakf.rs index fee7025f6..89d11d560 100644 --- a/crates/circuits/src/keccakf.rs +++ b/crates/circuits/src/keccakf.rs @@ -508,28 +508,20 @@ const KECCAKF_RC: [u64; ROUNDS_PER_PERMUTATION] = [ #[cfg(test)] mod tests { - use binius_core::constraint_system::validate::validate_witness; use rand::{rngs::StdRng, Rng, SeedableRng}; - use super::KeccakfState; - use crate::builder::ConstraintSystemBuilder; + use super::{keccakf, KeccakfState}; + use crate::builder::test_utils::test_circuit; #[test] fn test_keccakf() { - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); - let log_size = 5; - - let mut rng = StdRng::seed_from_u64(0); - let input_states = vec![KeccakfState(rng.gen())]; - let _state_out = super::keccakf(&mut builder, &Some(input_states), log_size); - - let witness = builder.take_witness().unwrap(); - - let constraint_system = builder.build().unwrap(); - - let boundaries = vec![]; - - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); + test_circuit(|builder| { + let log_size = 5; + let mut rng = StdRng::seed_from_u64(0); + let input_states = vec![KeccakfState(rng.gen())]; + let _state_out = keccakf(builder, &Some(input_states), log_size)?; + Ok(vec![]) + }) + .unwrap(); } } diff --git a/crates/circuits/src/lasso/big_integer_ops/byte_sliced_test_utils.rs b/crates/circuits/src/lasso/big_integer_ops/byte_sliced_test_utils.rs index 2767cc9ea..7ce699621 100644 --- a/crates/circuits/src/lasso/big_integer_ops/byte_sliced_test_utils.rs +++ b/crates/circuits/src/lasso/big_integer_ops/byte_sliced_test_utils.rs @@ -3,18 +3,18 @@ use std::{array, fmt::Debug}; use alloy_primitives::U512; -use binius_core::{constraint_system::validate::validate_witness, oracle::OracleId}; +use binius_core::oracle::OracleId; use binius_field::{ tower_levels::TowerLevel, BinaryField1b, BinaryField32b, BinaryField8b, Field, TowerField, }; -use rand::{rngs::ThreadRng, thread_rng, Rng}; +use rand::{rngs::StdRng, thread_rng, Rng, SeedableRng}; use super::{ byte_sliced_add, byte_sliced_add_carryfree, byte_sliced_double_conditional_increment, byte_sliced_modular_mul, byte_sliced_mul, }; use crate::{ - builder::ConstraintSystemBuilder, + builder::test_utils::test_circuit, lasso::{ batch::LookupBatch, lookups::u8_arithmetic::{add_carryfree_lookup, add_lookup, dci_lookup, mul_lookup}, @@ -26,7 +26,7 @@ use crate::{ type B8 = BinaryField8b; type B32 = BinaryField32b; -pub fn random_u512(rng: &mut ThreadRng) -> U512 { +pub fn random_u512(rng: &mut impl Rng) -> U512 { let limbs = array::from_fn(|_| rng.gen()); U512::from_limbs(limbs) } @@ -35,156 +35,133 @@ pub fn test_bytesliced_add() where TL: TowerLevel, { - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); - let log_size = 14; - - let x_in = - array::from_fn(|_| unconstrained::(&mut builder, "x", log_size).unwrap()); - let y_in = - array::from_fn(|_| unconstrained::(&mut builder, "y", log_size).unwrap()); - let c_in = unconstrained::(&mut builder, "cin first", log_size).unwrap(); - - let lookup_t_add = add_lookup(&mut builder, "add table").unwrap(); - - let mut lookup_batch_add = LookupBatch::new([lookup_t_add]); - let _sum_and_cout = byte_sliced_add::( - &mut builder, - "lasso_bytesliced_add", - &x_in, - &y_in, - c_in, - log_size, - &mut lookup_batch_add, - ) + test_circuit(|builder| { + let log_size = 14; + let x_in = + array::from_fn(|_| unconstrained::(builder, "x", log_size).unwrap()); + let y_in = + array::from_fn(|_| unconstrained::(builder, "y", log_size).unwrap()); + let c_in = unconstrained::(builder, "cin first", log_size)?; + let lookup_t_add = add_lookup(builder, "add table")?; + let mut lookup_batch_add = LookupBatch::new([lookup_t_add]); + let _sum_and_cout = byte_sliced_add::( + builder, + "lasso_bytesliced_add", + &x_in, + &y_in, + c_in, + log_size, + &mut lookup_batch_add, + )?; + lookup_batch_add.execute::(builder)?; + Ok(vec![]) + }) .unwrap(); - - lookup_batch_add.execute::(&mut builder).unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); } pub fn test_bytesliced_add_carryfree() where TL: TowerLevel, { - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); - let log_size = 14; - let x_in = array::from_fn(|_| builder.add_committed("x", log_size, BinaryField8b::TOWER_LEVEL)); - let y_in = array::from_fn(|_| builder.add_committed("y", log_size, BinaryField8b::TOWER_LEVEL)); - let c_in = builder.add_committed("c", log_size, BinaryField1b::TOWER_LEVEL); - - if let Some(witness) = builder.witness() { - let mut x_in: [_; WIDTH] = - array::from_fn(|byte_idx| witness.new_column::(x_in[byte_idx])); - let mut y_in: [_; WIDTH] = - array::from_fn(|byte_idx| witness.new_column::(y_in[byte_idx])); - let mut c_in = witness.new_column::(c_in); - - let x_in_bytes_u8: [_; WIDTH] = x_in.each_mut().map(|col| col.as_mut_slice::()); - let y_in_bytes_u8: [_; WIDTH] = y_in.each_mut().map(|col| col.as_mut_slice::()); - let c_in_u8 = c_in.as_mut_slice::(); - - for row_idx in 0..1 << log_size { - let mut rng = thread_rng(); - let input_bitmask = (U512::from(1u8) << (8 * WIDTH)) - U512::from(1u8); - let mut x = random_u512(&mut rng); - x &= input_bitmask; - let mut y = random_u512(&mut rng); - y &= input_bitmask; - - let mut c: bool = rng.gen(); - - while (x + y + U512::from(c)) > input_bitmask { - x = random_u512(&mut rng); + test_circuit(|builder| { + let log_size = 14; + let x_in = + array::from_fn(|_| builder.add_committed("x", log_size, BinaryField8b::TOWER_LEVEL)); + let y_in = + array::from_fn(|_| builder.add_committed("y", log_size, BinaryField8b::TOWER_LEVEL)); + let c_in = builder.add_committed("c", log_size, BinaryField1b::TOWER_LEVEL); + + if let Some(witness) = builder.witness() { + let mut x_in: [_; WIDTH] = + array::from_fn(|byte_idx| witness.new_column::(x_in[byte_idx])); + let mut y_in: [_; WIDTH] = + array::from_fn(|byte_idx| witness.new_column::(y_in[byte_idx])); + let mut c_in = witness.new_column::(c_in); + + let x_in_bytes_u8: [_; WIDTH] = x_in.each_mut().map(|col| col.as_mut_slice::()); + let y_in_bytes_u8: [_; WIDTH] = y_in.each_mut().map(|col| col.as_mut_slice::()); + let c_in_u8 = c_in.as_mut_slice::(); + + for row_idx in 0..1 << log_size { + let mut rng = thread_rng(); + let input_bitmask = (U512::from(1u8) << (8 * WIDTH)) - U512::from(1u8); + let mut x = random_u512(&mut rng); x &= input_bitmask; - y = random_u512(&mut rng); + let mut y = random_u512(&mut rng); y &= input_bitmask; - c = rng.gen(); - } - for byte_idx in 0..WIDTH { - x_in_bytes_u8[byte_idx][row_idx] = x.byte(byte_idx); + let mut c: bool = rng.gen(); - y_in_bytes_u8[byte_idx][row_idx] = y.byte(byte_idx); - } + while (x + y + U512::from(c)) > input_bitmask { + x = random_u512(&mut rng); + x &= input_bitmask; + y = random_u512(&mut rng); + y &= input_bitmask; + c = rng.gen(); + } - c_in_u8[row_idx / 8] |= (c as u8) << (row_idx % 8); - } - } + for byte_idx in 0..WIDTH { + x_in_bytes_u8[byte_idx][row_idx] = x.byte(byte_idx); - let lookup_t_add = add_lookup(&mut builder, "add table").unwrap(); - let lookup_t_add_carryfree = add_carryfree_lookup(&mut builder, "add table").unwrap(); + y_in_bytes_u8[byte_idx][row_idx] = y.byte(byte_idx); + } - let mut lookup_batch_add = LookupBatch::new([lookup_t_add]); - let mut lookup_batch_add_carryfree = LookupBatch::new([lookup_t_add_carryfree]); + c_in_u8[row_idx / 8] |= (c as u8) << (row_idx % 8); + } + } - let _sum_and_cout = byte_sliced_add_carryfree::( - &mut builder, - "lasso_bytesliced_add_carryfree", - &x_in, - &y_in, - c_in, - log_size, - &mut lookup_batch_add, - &mut lookup_batch_add_carryfree, - ) + let lookup_t_add = add_lookup(builder, "add table")?; + let lookup_t_add_carryfree = add_carryfree_lookup(builder, "add table")?; + + let mut lookup_batch_add = LookupBatch::new([lookup_t_add]); + let mut lookup_batch_add_carryfree = LookupBatch::new([lookup_t_add_carryfree]); + + let _sum_and_cout = byte_sliced_add_carryfree::( + builder, + "lasso_bytesliced_add_carryfree", + &x_in, + &y_in, + c_in, + log_size, + &mut lookup_batch_add, + &mut lookup_batch_add_carryfree, + )?; + + lookup_batch_add.execute::(builder)?; + lookup_batch_add_carryfree.execute::(builder)?; + Ok(vec![]) + }) .unwrap(); - - lookup_batch_add.execute::(&mut builder).unwrap(); - lookup_batch_add_carryfree - .execute::(&mut builder) - .unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); } pub fn test_bytesliced_double_conditional_increment() where TL: TowerLevel, { - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); - let log_size = 14; - - let x_in = - array::from_fn(|_| unconstrained::(&mut builder, "x", log_size).unwrap()); - - let first_c_in = unconstrained::(&mut builder, "cin first", log_size).unwrap(); - - let second_c_in = unconstrained::(&mut builder, "cin second", log_size).unwrap(); - - let zero_oracle_carry = - transparent::constant(&mut builder, "zero carry", log_size, BinaryField1b::ZERO).unwrap(); - let lookup_t_dci = dci_lookup(&mut builder, "add table").unwrap(); - - let mut lookup_batch_dci = LookupBatch::new([lookup_t_dci]); - - let _sum_and_cout = byte_sliced_double_conditional_increment::( - &mut builder, - "lasso_bytesliced_DCI", - &x_in, - first_c_in, - second_c_in, - log_size, - zero_oracle_carry, - &mut lookup_batch_dci, - ) + test_circuit(|builder| { + let log_size = 14; + let x_in = + array::from_fn(|_| unconstrained::(builder, "x", log_size).unwrap()); + let first_c_in = unconstrained::(builder, "cin first", log_size)?; + let second_c_in = unconstrained::(builder, "cin second", log_size)?; + let zero_oracle_carry = + transparent::constant(builder, "zero carry", log_size, BinaryField1b::ZERO)?; + let lookup_t_dci = dci_lookup(builder, "add table")?; + let mut lookup_batch_dci = LookupBatch::new([lookup_t_dci]); + let _sum_and_cout = byte_sliced_double_conditional_increment::( + builder, + "lasso_bytesliced_DCI", + &x_in, + first_c_in, + second_c_in, + log_size, + zero_oracle_carry, + &mut lookup_batch_dci, + )?; + lookup_batch_dci.execute::(builder)?; + Ok(vec![]) + }) .unwrap(); - - lookup_batch_dci.execute::(&mut builder).unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); } pub fn test_bytesliced_mul() @@ -192,43 +169,34 @@ where TL: TowerLevel, TL::Base: TowerLevel, { - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); - let log_size = 14; - - let mult_a = - array::from_fn(|_| unconstrained::(&mut builder, "a", log_size).unwrap()); - let mult_b = - array::from_fn(|_| unconstrained::(&mut builder, "b", log_size).unwrap()); - - let zero_oracle_carry = - transparent::constant(&mut builder, "zero carry", log_size, BinaryField1b::ZERO).unwrap(); - - let lookup_t_mul = mul_lookup(&mut builder, "mul lookup").unwrap(); - let lookup_t_add = add_lookup(&mut builder, "add lookup").unwrap(); - let lookup_t_dci = dci_lookup(&mut builder, "dci lookup").unwrap(); - - let mut lookup_batch_mul = LookupBatch::new([lookup_t_mul]); - let mut lookup_batch_add = LookupBatch::new([lookup_t_add]); - let mut lookup_batch_dci = LookupBatch::new([lookup_t_dci]); - - let _sum_and_cout = byte_sliced_mul::( - &mut builder, - "lasso_bytesliced_mul", - &mult_a, - &mult_b, - log_size, - zero_oracle_carry, - &mut lookup_batch_mul, - &mut lookup_batch_add, - &mut lookup_batch_dci, - ) + test_circuit(|builder| { + let log_size = 14; + let mult_a = + array::from_fn(|_| unconstrained::(builder, "a", log_size).unwrap()); + let mult_b = + array::from_fn(|_| unconstrained::(builder, "b", log_size).unwrap()); + let zero_oracle_carry = + transparent::constant(builder, "zero carry", log_size, BinaryField1b::ZERO)?; + let lookup_t_mul = mul_lookup(builder, "mul lookup")?; + let lookup_t_add = add_lookup(builder, "add lookup")?; + let lookup_t_dci = dci_lookup(builder, "dci lookup")?; + let mut lookup_batch_mul = LookupBatch::new([lookup_t_mul]); + let mut lookup_batch_add = LookupBatch::new([lookup_t_add]); + let mut lookup_batch_dci = LookupBatch::new([lookup_t_dci]); + let _sum_and_cout = byte_sliced_mul::( + builder, + "lasso_bytesliced_mul", + &mult_a, + &mult_b, + log_size, + zero_oracle_carry, + &mut lookup_batch_mul, + &mut lookup_batch_add, + &mut lookup_batch_dci, + )?; + Ok(vec![]) + }) .unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); } pub fn test_bytesliced_modular_mul() @@ -237,66 +205,56 @@ where TL::Base: TowerLevel, >::Data: Debug, { - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); - let log_size = 14; - - let mut rng = thread_rng(); - - let mult_a = builder.add_committed_multiple::("a", log_size, B8::TOWER_LEVEL); - let mult_b = builder.add_committed_multiple::("b", log_size, B8::TOWER_LEVEL); - - let input_bitmask = (U512::from(1u8) << (8 * WIDTH)) - U512::from(1u8); - - let modulus = (random_u512(&mut rng) % input_bitmask) + U512::from(1u8); + test_circuit(|builder| { + let log_size = 14; + let mut rng = thread_rng(); + let mult_a = builder.add_committed_multiple::("a", log_size, B8::TOWER_LEVEL); + let mult_b = builder.add_committed_multiple::("b", log_size, B8::TOWER_LEVEL); + let input_bitmask = (U512::from(1u8) << (8 * WIDTH)) - U512::from(1u8); + let modulus = + (random_u512(&mut StdRng::from_seed([42; 32])) % input_bitmask) + U512::from(1u8); - if let Some(witness) = builder.witness() { - let mut mult_a: [_; WIDTH] = - array::from_fn(|byte_idx| witness.new_column::(mult_a[byte_idx])); + if let Some(witness) = builder.witness() { + let mut mult_a: [_; WIDTH] = + array::from_fn(|byte_idx| witness.new_column::(mult_a[byte_idx])); - let mult_a_u8 = mult_a.each_mut().map(|col| col.as_mut_slice::()); + let mult_a_u8 = mult_a.each_mut().map(|col| col.as_mut_slice::()); - let mut mult_b: [_; WIDTH] = - array::from_fn(|byte_idx| witness.new_column::(mult_b[byte_idx])); + let mut mult_b: [_; WIDTH] = + array::from_fn(|byte_idx| witness.new_column::(mult_b[byte_idx])); - let mult_b_u8 = mult_b.each_mut().map(|col| col.as_mut_slice::()); + let mult_b_u8 = mult_b.each_mut().map(|col| col.as_mut_slice::()); - for row_idx in 0..1 << log_size { - let mut a = random_u512(&mut rng); - let mut b = random_u512(&mut rng); + for row_idx in 0..1 << log_size { + let mut a = random_u512(&mut rng); + let mut b = random_u512(&mut rng); - a %= modulus; - b %= modulus; + a %= modulus; + b %= modulus; - for byte_idx in 0..WIDTH { - mult_a_u8[byte_idx][row_idx] = a.byte(byte_idx); - mult_b_u8[byte_idx][row_idx] = b.byte(byte_idx); + for byte_idx in 0..WIDTH { + mult_a_u8[byte_idx][row_idx] = a.byte(byte_idx); + mult_b_u8[byte_idx][row_idx] = b.byte(byte_idx); + } } } - } - let modulus_input: [_; WIDTH] = array::from_fn(|byte_idx| modulus.byte(byte_idx)); - - let zero_oracle_byte = - transparent::constant(&mut builder, "zero carry", log_size, BinaryField8b::ZERO).unwrap(); - - let zero_oracle_carry = - transparent::constant(&mut builder, "zero carry", log_size, BinaryField1b::ZERO).unwrap(); - - let _modded_product = byte_sliced_modular_mul::( - &mut builder, - "lasso_bytesliced_mul", - &mult_a, - &mult_b, - &modulus_input, - log_size, - zero_oracle_byte, - zero_oracle_carry, - ) + let modulus_input: [_; WIDTH] = array::from_fn(|byte_idx| modulus.byte(byte_idx)); + let zero_oracle_byte = + transparent::constant(builder, "zero carry", log_size, BinaryField8b::ZERO)?; + let zero_oracle_carry = + transparent::constant(builder, "zero carry", log_size, BinaryField1b::ZERO)?; + let _modded_product = byte_sliced_modular_mul::( + builder, + "lasso_bytesliced_mul", + &mult_a, + &mult_b, + &modulus_input, + log_size, + zero_oracle_byte, + zero_oracle_carry, + )?; + Ok(vec![]) + }) .unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); } diff --git a/crates/circuits/src/lasso/lookups/u8_arithmetic.rs b/crates/circuits/src/lasso/lookups/u8_arithmetic.rs index e1d25d834..bb9a4e6c3 100644 --- a/crates/circuits/src/lasso/lookups/u8_arithmetic.rs +++ b/crates/circuits/src/lasso/lookups/u8_arithmetic.rs @@ -149,11 +149,10 @@ pub fn dci_lookup( #[cfg(test)] mod tests { - use binius_core::constraint_system::validate::validate_witness; use binius_field::{BinaryField1b, BinaryField32b, BinaryField8b}; use crate::{ - builder::ConstraintSystemBuilder, + builder::test_utils::test_circuit, lasso::{self, batch::LookupBatch}, unconstrained::unconstrained, }; @@ -161,139 +160,111 @@ mod tests { #[test] fn test_lasso_u8add_carryfree_rejects_carry() { // TODO: Make this test 100% certain to pass instead of 2^14 bits of security from randomness - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); - let log_size = 14; - let x_in = unconstrained::(&mut builder, "x", log_size).unwrap(); - let y_in = unconstrained::(&mut builder, "y", log_size).unwrap(); - let c_in = unconstrained::(&mut builder, "c", log_size).unwrap(); - - let lookup_t = super::add_carryfree_lookup(&mut builder, "add cf table").unwrap(); - let mut lookup_batch = LookupBatch::new([lookup_t]); - let _sum_and_cout = lasso::u8add_carryfree( - &mut builder, - &mut lookup_batch, - "lasso_u8add", - x_in, - y_in, - c_in, - log_size, - ) - .unwrap(); - - lookup_batch - .execute::(&mut builder) - .unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness) - .expect_err("Rejected overflowing add"); + test_circuit(|builder| { + let log_size = 14; + let x_in = unconstrained::(builder, "x", log_size)?; + let y_in = unconstrained::(builder, "y", log_size)?; + let c_in = unconstrained::(builder, "c", log_size)?; + + let lookup_t = super::add_carryfree_lookup(builder, "add cf table")?; + let mut lookup_batch = LookupBatch::new([lookup_t]); + let _sum_and_cout = lasso::u8add_carryfree( + builder, + &mut lookup_batch, + "lasso_u8add", + x_in, + y_in, + c_in, + log_size, + )?; + lookup_batch.execute::(builder)?; + Ok(vec![]) + }) + .expect_err("Rejected overflowing add"); } #[test] fn test_lasso_u8mul() { - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); - let log_size = 10; - - let mult_a = unconstrained::(&mut builder, "mult_a", log_size).unwrap(); - let mult_b = unconstrained::(&mut builder, "mult_b", log_size).unwrap(); - - let mul_lookup_table = super::mul_lookup(&mut builder, "mul table").unwrap(); - - let mut lookup_batch = LookupBatch::new([mul_lookup_table]); - - let _product = lasso::u8mul( - &mut builder, - &mut lookup_batch, - "lasso_u8mul", - mult_a, - mult_b, - 1 << log_size, - ) - .unwrap(); + test_circuit(|builder| { + let log_size = 10; - lookup_batch - .execute::(&mut builder) - .unwrap(); + let mult_a = unconstrained::(builder, "mult_a", log_size)?; + let mult_b = unconstrained::(builder, "mult_b", log_size)?; - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); - } - - #[test] - fn test_lasso_batched_u8mul() { - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); - let log_size = 10; - let mul_lookup_table = super::mul_lookup(&mut builder, "mul table").unwrap(); - - let mut lookup_batch = LookupBatch::new([mul_lookup_table]); + let mul_lookup_table = super::mul_lookup(builder, "mul table")?; - for _ in 0..10 { - let mult_a = unconstrained::(&mut builder, "mult_a", log_size).unwrap(); - let mult_b = unconstrained::(&mut builder, "mult_b", log_size).unwrap(); + let mut lookup_batch = LookupBatch::new([mul_lookup_table]); let _product = lasso::u8mul( - &mut builder, + builder, &mut lookup_batch, "lasso_u8mul", mult_a, mult_b, 1 << log_size, - ) - .unwrap(); - } - - lookup_batch - .execute::(&mut builder) - .unwrap(); + )?; - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); + lookup_batch.execute::(builder)?; + Ok(vec![]) + }) + .unwrap(); } #[test] - fn test_lasso_batched_u8mul_rejects() { - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); - let log_size = 10; - - // We try to feed in the add table instead - let mul_lookup_table = super::add_lookup(&mut builder, "mul table").unwrap(); - - let mut lookup_batch = LookupBatch::new([mul_lookup_table]); - - // TODO?: Make this test fail 100% of the time, even though its almost impossible with rng - for _ in 0..10 { - let mult_a = unconstrained::(&mut builder, "mult_a", log_size).unwrap(); - let mult_b = unconstrained::(&mut builder, "mult_b", log_size).unwrap(); + fn test_lasso_batched_u8mul() { + test_circuit(|builder| { + let log_size = 10; + let mul_lookup_table = super::mul_lookup(builder, "mul table")?; + + let mut lookup_batch = LookupBatch::new([mul_lookup_table]); + + for _ in 0..10 { + let mult_a = unconstrained::(builder, "mult_a", log_size)?; + let mult_b = unconstrained::(builder, "mult_b", log_size)?; + + let _product = lasso::u8mul( + builder, + &mut lookup_batch, + "lasso_u8mul", + mult_a, + mult_b, + 1 << log_size, + )?; + } - let _product = lasso::u8mul( - &mut builder, - &mut lookup_batch, - "lasso_u8mul", - mult_a, - mult_b, - 1 << log_size, - ) - .unwrap(); - } + lookup_batch.execute::(builder)?; + Ok(vec![]) + }) + .unwrap(); + } - lookup_batch - .execute::(&mut builder) - .unwrap(); + #[test] + fn test_lasso_batched_u8mul_rejects() { + test_circuit(|builder| { + let log_size = 10; + + // We try to feed in the add table instead + let mul_lookup_table = super::add_lookup(builder, "mul table")?; + + let mut lookup_batch = LookupBatch::new([mul_lookup_table]); + + // TODO?: Make this test fail 100% of the time, even though its almost impossible with rng + for _ in 0..10 { + let mult_a = unconstrained::(builder, "mult_a", log_size)?; + let mult_b = unconstrained::(builder, "mult_b", log_size)?; + let _product = lasso::u8mul( + builder, + &mut lookup_batch, + "lasso_u8mul", + mult_a, + mult_b, + 1 << log_size, + )?; + } - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness) - .expect_err("Channels should be unbalanced"); + lookup_batch.execute::(builder)?; + Ok(vec![]) + }) + .expect_err("Channels should be unbalanced"); } } diff --git a/crates/circuits/src/lasso/sha256.rs b/crates/circuits/src/lasso/sha256.rs index cd51271bc..1b8f3c3dd 100644 --- a/crates/circuits/src/lasso/sha256.rs +++ b/crates/circuits/src/lasso/sha256.rs @@ -274,63 +274,61 @@ pub fn sha256( #[cfg(test)] mod tests { - use binius_core::{constraint_system::validate::validate_witness, oracle::OracleId}; + use binius_core::oracle::OracleId; use binius_field::{as_packed_field::PackedType, BinaryField1b, BinaryField8b, TowerField}; use sha2::{compress256, digest::generic_array::GenericArray}; use crate::{ - builder::{types::U, ConstraintSystemBuilder}, + builder::{test_utils::test_circuit, types::U}, unconstrained::unconstrained, }; #[test] fn test_sha256_lasso() { - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); - let log_size = PackedType::::LOG_WIDTH + BinaryField8b::TOWER_LEVEL; - let input: [OracleId; 16] = std::array::from_fn(|i| { - unconstrained::(&mut builder, i, log_size).unwrap() - }); - let state_output = super::sha256(&mut builder, input, log_size).unwrap(); - - let witness = builder.witness().unwrap(); - - let input_witneses: [_; 16] = std::array::from_fn(|i| { - witness - .get::(input[i]) - .unwrap() - .as_slice::() - }); - - let output_witneses: [_; 8] = std::array::from_fn(|i| { - witness - .get::(state_output[i]) - .unwrap() - .as_slice::() - }); - - let mut generic_array_input = GenericArray::::default(); - - let n_compressions = input_witneses[0].len(); - - for j in 0..n_compressions { - for i in 0..16 { - for z in 0..4 { - generic_array_input[i * 4 + z] = input_witneses[i][j].to_be_bytes()[z]; + test_circuit(|builder| { + let log_size = PackedType::::LOG_WIDTH + BinaryField8b::TOWER_LEVEL; + let input: [OracleId; 16] = std::array::from_fn(|i| { + unconstrained::(builder, i, log_size).unwrap() + }); + let state_output = super::sha256(builder, input, log_size).unwrap(); + + if let Some(witness) = builder.witness() { + let input_witneses: [_; 16] = std::array::from_fn(|i| { + witness + .get::(input[i]) + .unwrap() + .as_slice::() + }); + + let output_witneses: [_; 8] = std::array::from_fn(|i| { + witness + .get::(state_output[i]) + .unwrap() + .as_slice::() + }); + + let mut generic_array_input = GenericArray::::default(); + + let n_compressions = input_witneses[0].len(); + + for j in 0..n_compressions { + for i in 0..16 { + for z in 0..4 { + generic_array_input[i * 4 + z] = input_witneses[i][j].to_be_bytes()[z]; + } + } + + let mut output = crate::sha256::INIT; + compress256(&mut output, &[generic_array_input]); + + for i in 0..8 { + assert_eq!(output[i], output_witneses[i][j]); + } } } - let mut output = crate::sha256::INIT; - compress256(&mut output, &[generic_array_input]); - - for i in 0..8 { - assert_eq!(output[i], output_witneses[i][j]); - } - } - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); + Ok(vec![]) + }) + .unwrap(); } } diff --git a/crates/circuits/src/lasso/u32add.rs b/crates/circuits/src/lasso/u32add.rs index 271cb07ce..9b260b8c2 100644 --- a/crates/circuits/src/lasso/u32add.rs +++ b/crates/circuits/src/lasso/u32add.rs @@ -236,62 +236,48 @@ impl Drop for SeveralU32add { #[cfg(test)] mod tests { - use binius_core::constraint_system::validate::validate_witness; use binius_field::{BinaryField1b, BinaryField8b}; use super::SeveralU32add; - use crate::{builder::ConstraintSystemBuilder, unconstrained::unconstrained}; + use crate::{builder::test_utils::test_circuit, unconstrained::unconstrained}; #[test] fn test_several_lasso_u32add() { - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); - - let mut several_u32_add = SeveralU32add::new(&mut builder).unwrap(); - - for log_size in [11, 12, 13] { - // BinaryField8b is used here because we utilize an 8x8x1→8 table - let add_a_u8 = unconstrained::(&mut builder, "add_a", log_size).unwrap(); - let add_b_u8 = unconstrained::(&mut builder, "add_b", log_size).unwrap(); - let _sum = several_u32_add - .u32add::( - &mut builder, - "lasso_u32add", - add_a_u8, - add_b_u8, - ) - .unwrap(); - } - - several_u32_add - .finalize(&mut builder, "lasso_u32add") - .unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); + test_circuit(|builder| { + let mut several_u32_add = SeveralU32add::new(builder).unwrap(); + for log_size in [11, 12, 13] { + // BinaryField8b is used here because we utilize an 8x8x1→8 table + let add_a_u8 = unconstrained::(builder, "add_a", log_size).unwrap(); + let add_b_u8 = unconstrained::(builder, "add_b", log_size).unwrap(); + let _sum = several_u32_add + .u32add::( + builder, + "lasso_u32add", + add_a_u8, + add_b_u8, + ) + .unwrap(); + } + several_u32_add.finalize(builder, "lasso_u32add").unwrap(); + Ok(vec![]) + }) + .unwrap(); } #[test] fn test_lasso_u32add() { - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); - let log_size = 14; - - let add_a = unconstrained::(&mut builder, "add_a", log_size).unwrap(); - let add_b = unconstrained::(&mut builder, "add_b", log_size).unwrap(); - let _sum = super::u32add::( - &mut builder, - "lasso_u32add", - add_a, - add_b, - ) + test_circuit(|builder| { + let log_size = 14; + let add_a = unconstrained::(builder, "add_a", log_size)?; + let add_b = unconstrained::(builder, "add_b", log_size)?; + let _sum = super::u32add::( + builder, + "lasso_u32add", + add_a, + add_b, + )?; + Ok(vec![]) + }) .unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); } } diff --git a/crates/circuits/src/sha256.rs b/crates/circuits/src/sha256.rs index d28502d15..53b4a33b2 100644 --- a/crates/circuits/src/sha256.rs +++ b/crates/circuits/src/sha256.rs @@ -315,63 +315,61 @@ pub fn sha256( #[cfg(test)] mod tests { - use binius_core::{constraint_system::validate::validate_witness, oracle::OracleId}; + use binius_core::oracle::OracleId; use binius_field::{as_packed_field::PackedType, BinaryField1b}; use sha2::{compress256, digest::generic_array::GenericArray}; use crate::{ - builder::{types::U, ConstraintSystemBuilder}, + builder::{test_utils::test_circuit, types::U}, unconstrained::unconstrained, }; #[test] fn test_sha256() { - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); - let log_size = PackedType::::LOG_WIDTH; - let input: [OracleId; 16] = std::array::from_fn(|i| { - unconstrained::(&mut builder, i, log_size).unwrap() - }); - let state_output = super::sha256(&mut builder, input, log_size).unwrap(); - - let witness = builder.witness().unwrap(); - - let input_witneses: [_; 16] = std::array::from_fn(|i| { - witness - .get::(input[i]) - .unwrap() - .as_slice::() - }); - - let output_witneses: [_; 8] = std::array::from_fn(|i| { - witness - .get::(state_output[i]) - .unwrap() - .as_slice::() - }); - - let mut generic_array_input = GenericArray::::default(); - - let n_compressions = input_witneses[0].len(); - - for j in 0..n_compressions { - for i in 0..16 { - for z in 0..4 { - generic_array_input[i * 4 + z] = input_witneses[i][j].to_be_bytes()[z]; + test_circuit(|builder| { + let log_size = PackedType::::LOG_WIDTH; + let input: [OracleId; 16] = std::array::from_fn(|i| { + unconstrained::(builder, i, log_size).unwrap() + }); + let state_output = super::sha256(builder, input, log_size).unwrap(); + + if let Some(witness) = builder.witness() { + let input_witneses: [_; 16] = std::array::from_fn(|i| { + witness + .get::(input[i]) + .unwrap() + .as_slice::() + }); + + let output_witneses: [_; 8] = std::array::from_fn(|i| { + witness + .get::(state_output[i]) + .unwrap() + .as_slice::() + }); + + let mut generic_array_input = GenericArray::::default(); + + let n_compressions = input_witneses[0].len(); + + for j in 0..n_compressions { + for i in 0..16 { + for z in 0..4 { + generic_array_input[i * 4 + z] = input_witneses[i][j].to_be_bytes()[z]; + } + } + + let mut output = crate::sha256::INIT; + compress256(&mut output, &[generic_array_input]); + + for i in 0..8 { + assert_eq!(output[i], output_witneses[i][j]); + } } } - let mut output = crate::sha256::INIT; - compress256(&mut output, &[generic_array_input]); - - for i in 0..8 { - assert_eq!(output[i], output_witneses[i][j]); - } - } - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); + Ok(vec![]) + }) + .unwrap(); } } diff --git a/crates/circuits/src/u32fib.rs b/crates/circuits/src/u32fib.rs index 4c53c389d..59563fb7c 100644 --- a/crates/circuits/src/u32fib.rs +++ b/crates/circuits/src/u32fib.rs @@ -74,20 +74,15 @@ pub fn u32fib( #[cfg(test)] mod tests { - use binius_core::constraint_system::validate::validate_witness; - - use crate::builder::ConstraintSystemBuilder; + use crate::builder::test_utils::test_circuit; #[test] fn test_u32fib() { - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); - let log_size_1b = 14; - let _ = super::u32fib(&mut builder, "u32fib", log_size_1b).unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); + test_circuit(|builder| { + let log_size_1b = 14; + let _ = super::u32fib(builder, "u32fib", log_size_1b)?; + Ok(vec![]) + }) + .unwrap(); } } diff --git a/crates/circuits/src/vision.rs b/crates/circuits/src/vision.rs index e8294f244..182566db5 100644 --- a/crates/circuits/src/vision.rs +++ b/crates/circuits/src/vision.rs @@ -462,25 +462,22 @@ where { #[cfg(test)] mod tests { - use binius_core::{constraint_system::validate::validate_witness, oracle::OracleId}; + use binius_core::oracle::OracleId; use binius_field::BinaryField32b; use super::vision_permutation; - use crate::{builder::ConstraintSystemBuilder, unconstrained::unconstrained}; + use crate::{builder::test_utils::test_circuit, unconstrained::unconstrained}; #[test] fn test_vision32b() { - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); - let log_size = 8; - let state_in: [OracleId; 24] = std::array::from_fn(|i| { - unconstrained::(&mut builder, format!("p_in[{i}]"), log_size).unwrap() - }); - let _state_out = vision_permutation(&mut builder, log_size, state_in).unwrap(); - - let witness = builder.take_witness().unwrap(); - let constraint_system = builder.build().unwrap(); - let boundaries = vec![]; - validate_witness(&constraint_system, &boundaries, &witness).unwrap(); + test_circuit(|builder| { + let log_size = 8; + let state_in: [OracleId; 24] = std::array::from_fn(|i| { + unconstrained::(builder, format!("p_in[{i}]"), log_size).unwrap() + }); + let _state_out = vision_permutation(builder, log_size, state_in).unwrap(); + Ok(vec![]) + }) + .unwrap(); } } diff --git a/crates/core/src/constraint_system/channel.rs b/crates/core/src/constraint_system/channel.rs index 51bf5d5c3..01891f180 100644 --- a/crates/core/src/constraint_system/channel.rs +++ b/crates/core/src/constraint_system/channel.rs @@ -68,7 +68,7 @@ pub struct Flush { pub multiplicity: u64, } -#[derive(Debug, Clone, SerializeBytes, DeserializeBytes)] +#[derive(Debug, Clone, PartialEq, Eq, SerializeBytes, DeserializeBytes)] pub struct Boundary { pub values: Vec, pub channel_id: ChannelId, @@ -76,7 +76,7 @@ pub struct Boundary { pub multiplicity: u64, } -#[derive(Debug, Clone, Copy, SerializeBytes, DeserializeBytes)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, SerializeBytes, DeserializeBytes)] pub enum FlushDirection { Push, Pull, From 23b3eba244e0c7b3e6e844950b81aede96dd7b68 Mon Sep 17 00:00:00 2001 From: Dmytro Gordon Date: Mon, 24 Feb 2025 13:52:00 +0200 Subject: [PATCH 39/50] Leave only the object-safe version of the `CompositionPoly` trait (#43) --- crates/core/benches/composition_poly.rs | 4 +- crates/core/src/composition/index.rs | 6 +- .../src/composition/product_composition.rs | 4 +- crates/core/src/constraint_system/verify.rs | 12 +- crates/core/src/oracle/composite.rs | 10 +- crates/core/src/oracle/constraint.rs | 4 +- crates/core/src/polynomial/arith_circuit.rs | 88 +--- crates/core/src/polynomial/cached.rs | 272 ------------ crates/core/src/polynomial/mod.rs | 2 - crates/core/src/polynomial/multivariate.rs | 40 +- .../protocols/gkr_gpa/gpa_sumcheck/prove.rs | 12 +- .../generator_exponent/compositions.rs | 4 +- crates/core/src/protocols/sumcheck/common.rs | 4 +- .../src/protocols/sumcheck/front_loaded.rs | 4 +- .../protocols/sumcheck/prove/prover_state.rs | 4 +- .../sumcheck/prove/regular_sumcheck.rs | 12 +- .../protocols/sumcheck/prove/univariate.rs | 12 +- .../src/protocols/sumcheck/prove/zerocheck.rs | 20 +- crates/core/src/protocols/sumcheck/tests.rs | 13 +- .../core/src/protocols/sumcheck/univariate.rs | 26 +- .../sumcheck/univariate_zerocheck.rs | 4 +- crates/core/src/protocols/sumcheck/verify.rs | 10 +- .../core/src/protocols/sumcheck/zerocheck.rs | 14 +- crates/core/src/protocols/test_utils.rs | 8 +- crates/hal/src/backend.rs | 6 +- crates/hal/src/cpu.rs | 4 +- crates/hal/src/sumcheck_round_calculator.rs | 6 +- crates/macros/src/arith_circuit_poly.rs | 393 +----------------- crates/macros/src/composition_poly.rs | 41 +- crates/macros/src/lib.rs | 6 +- crates/macros/tests/arithmetic_circuit.rs | 2 +- crates/math/src/composition_poly.rs | 26 +- 32 files changed, 153 insertions(+), 920 deletions(-) delete mode 100644 crates/core/src/polynomial/cached.rs diff --git a/crates/core/benches/composition_poly.rs b/crates/core/benches/composition_poly.rs index 845c865b4..166dcfe51 100644 --- a/crates/core/benches/composition_poly.rs +++ b/crates/core/benches/composition_poly.rs @@ -8,7 +8,7 @@ use binius_field::{ PackedField, }; use binius_macros::{arith_circuit_poly, composition_poly}; -use binius_math::{ArithExpr as Expr, CompositionPolyOS}; +use binius_math::{ArithExpr as Expr, CompositionPoly}; use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput}; use rand::{thread_rng, RngCore}; @@ -28,7 +28,7 @@ fn generate_input_data(mut rng: impl RngCore) -> Vec> { fn evaluate_arith_circuit_poly( query: &[&[P]], - arith_circuit_poly: &impl CompositionPolyOS

, + arith_circuit_poly: &impl CompositionPoly

, ) { for i in 0..BATCH_SIZE { let result = arith_circuit_poly diff --git a/crates/core/src/composition/index.rs b/crates/core/src/composition/index.rs index 0eff19493..2a918439b 100644 --- a/crates/core/src/composition/index.rs +++ b/crates/core/src/composition/index.rs @@ -3,7 +3,7 @@ use std::fmt::Debug; use binius_field::{Field, PackedField}; -use binius_math::{ArithExpr, CompositionPolyOS}; +use binius_math::{ArithExpr, CompositionPoly}; use binius_utils::bail; use crate::polynomial::Error; @@ -34,7 +34,7 @@ impl IndexComposition { } } -impl, const N: usize> CompositionPolyOS

+impl, const N: usize> CompositionPoly

for IndexComposition { fn n_vars(&self) -> usize { @@ -159,7 +159,7 @@ mod tests { }; assert_eq!( - (&composition as &dyn CompositionPolyOS).expression(), + (&composition as &dyn CompositionPoly).expression(), ArithExpr::Add( Box::new(ArithExpr::Var(1)), Box::new(ArithExpr::Mul( diff --git a/crates/core/src/composition/product_composition.rs b/crates/core/src/composition/product_composition.rs index 245c28653..7fddf3312 100644 --- a/crates/core/src/composition/product_composition.rs +++ b/crates/core/src/composition/product_composition.rs @@ -1,7 +1,7 @@ // Copyright 2024-2025 Irreducible Inc. use binius_field::PackedField; -use binius_math::{ArithExpr, CompositionPolyOS}; +use binius_math::{ArithExpr, CompositionPoly}; use binius_utils::bail; #[derive(Debug, Default, Copy, Clone)] @@ -17,7 +17,7 @@ impl ProductComposition { } } -impl CompositionPolyOS

for ProductComposition { +impl CompositionPoly

for ProductComposition { fn n_vars(&self) -> usize { self.n_vars() } diff --git a/crates/core/src/constraint_system/verify.rs b/crates/core/src/constraint_system/verify.rs index 57d7822f6..415f909d3 100644 --- a/crates/core/src/constraint_system/verify.rs +++ b/crates/core/src/constraint_system/verify.rs @@ -4,7 +4,7 @@ use std::{cmp::Reverse, iter}; use binius_field::{BinaryField, PackedField, TowerField}; use binius_hash::PseudoCompressionFunction; -use binius_math::{ArithExpr, CompositionPolyOS}; +use binius_math::{ArithExpr, CompositionPoly}; use binius_utils::{bail, checked_arithmetics::log2_ceil_usize}; use digest::{core_api::BlockSizeUser, Digest, Output}; use itertools::{izip, multiunzip, Itertools}; @@ -310,7 +310,7 @@ pub fn max_n_vars_and_skip_rounds( ) -> (usize, usize) where F: TowerField, - Composition: CompositionPolyOS, + Composition: CompositionPoly, { let max_n_vars = max_n_vars(zerocheck_claims); @@ -334,7 +334,7 @@ where fn max_n_vars(zerocheck_claims: &[ZerocheckClaim]) -> usize where F: TowerField, - Composition: CompositionPolyOS, + Composition: CompositionPoly, { zerocheck_claims .iter() @@ -567,7 +567,7 @@ pub fn get_flush_dedup_sumcheck_metas( #[derive(Debug)] pub struct FlushSumcheckComposition; -impl CompositionPolyOS

for FlushSumcheckComposition { +impl CompositionPoly

for FlushSumcheckComposition { fn n_vars(&self) -> usize { 2 } @@ -639,7 +639,7 @@ pub fn get_post_flush_sumcheck_eval_claims_without_eq( Ok(evalcheck_claims) } -pub struct DedupSumcheckClaims> { +pub struct DedupSumcheckClaims> { sumcheck_claims: Vec>, gkr_eval_points: Vec>, flush_selectors_unique_by_claim: Vec>, @@ -649,7 +649,7 @@ pub struct DedupSumcheckClaims> #[allow(clippy::type_complexity)] pub fn get_flush_dedup_sumcheck_claims( flush_sumcheck_metas: Vec>, -) -> Result>, Error> { +) -> Result>, Error> { let n_claims = flush_sumcheck_metas.len(); let mut sumcheck_claims = Vec::with_capacity(n_claims); let mut gkr_eval_points = Vec::with_capacity(n_claims); diff --git a/crates/core/src/oracle/composite.rs b/crates/core/src/oracle/composite.rs index 7c904deaf..d90102a22 100644 --- a/crates/core/src/oracle/composite.rs +++ b/crates/core/src/oracle/composite.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use binius_field::TowerField; -use binius_math::CompositionPolyOS; +use binius_math::CompositionPoly; use binius_utils::bail; use crate::oracle::{Error, MultilinearPolyOracle, OracleId}; @@ -12,11 +12,11 @@ use crate::oracle::{Error, MultilinearPolyOracle, OracleId}; pub struct CompositePolyOracle { n_vars: usize, inner: Vec>, - composition: Arc>, + composition: Arc>, } impl CompositePolyOracle { - pub fn new + 'static>( + pub fn new + 'static>( n_vars: usize, inner: Vec>, composition: C, @@ -67,7 +67,7 @@ impl CompositePolyOracle { self.inner.clone() } - pub fn composition(&self) -> Arc> { + pub fn composition(&self) -> Arc> { self.composition.clone() } } @@ -82,7 +82,7 @@ mod tests { #[derive(Clone, Debug)] struct TestByteComposition; - impl CompositionPolyOS for TestByteComposition { + impl CompositionPoly for TestByteComposition { fn n_vars(&self) -> usize { 3 } diff --git a/crates/core/src/oracle/constraint.rs b/crates/core/src/oracle/constraint.rs index 9edd31f5d..df6833945 100644 --- a/crates/core/src/oracle/constraint.rs +++ b/crates/core/src/oracle/constraint.rs @@ -5,7 +5,7 @@ use std::sync::Arc; use binius_field::{Field, TowerField}; use binius_macros::{DeserializeBytes, SerializeBytes}; -use binius_math::{ArithExpr, CompositionPolyOS}; +use binius_math::{ArithExpr, CompositionPoly}; use binius_utils::bail; use itertools::Itertools; @@ -13,7 +13,7 @@ use super::{Error, MultilinearOracleSet, MultilinearPolyVariant, OracleId}; /// Composition trait object that can be used to create lists of compositions of differing /// concrete types. -pub type TypeErasedComposition

= Arc>; +pub type TypeErasedComposition

= Arc>; /// Constraint is a type erased composition along with a predicate on its values on the boolean hypercube #[derive(Debug, Clone, SerializeBytes, DeserializeBytes)] diff --git a/crates/core/src/polynomial/arith_circuit.rs b/crates/core/src/polynomial/arith_circuit.rs index b0fe14f6f..2bec2030e 100644 --- a/crates/core/src/polynomial/arith_circuit.rs +++ b/crates/core/src/polynomial/arith_circuit.rs @@ -3,14 +3,12 @@ use std::{fmt::Debug, mem::MaybeUninit, sync::Arc}; use binius_field::{ExtensionField, Field, PackedField, TowerField}; -use binius_math::{ArithExpr, CompositionPoly, CompositionPolyOS, Error}; +use binius_math::{ArithExpr, CompositionPoly, Error}; use stackalloc::{ helpers::{slice_assume_init, slice_assume_init_mut}, stackalloc_uninit, }; -use super::MultivariatePoly; - /// Convert the expression to a sequence of arithmetic operations that can be evaluated in sequence. fn circuit_steps_for_expr( expr: &ArithExpr, @@ -119,9 +117,9 @@ enum CircuitStep { /// Describes polynomial evaluations using a directed acyclic graph of expressions. /// -/// This is meant as an alternative to a hard-coded CompositionPolyOS. +/// This is meant as an alternative to a hard-coded CompositionPoly. /// -/// The advantage over a hard coded CompositionPolyOS is that this can be constructed and manipulated dynamically at runtime +/// The advantage over a hard coded CompositionPoly is that this can be constructed and manipulated dynamically at runtime /// and the object representing different polnomials can be stored in a homogeneous collection. #[derive(Debug, Clone)] pub struct ArithCircuitPoly { @@ -177,7 +175,9 @@ impl ArithCircuitPoly { } } -impl CompositionPoly for ArithCircuitPoly { +impl>> CompositionPoly

+ for ArithCircuitPoly +{ fn degree(&self) -> usize { self.degree } @@ -190,11 +190,11 @@ impl CompositionPoly for ArithCircuitPoly { self.tower_level } - fn expression>(&self) -> ArithExpr { + fn expression(&self) -> ArithExpr { self.expr.convert_field() } - fn evaluate>>(&self, query: &[P]) -> Result { + fn evaluate(&self, query: &[P]) -> Result { if query.len() != self.n_vars { return Err(Error::IncorrectQuerySize { expected: self.n_vars, @@ -258,11 +258,7 @@ impl CompositionPoly for ArithCircuitPoly { }) } - fn batch_evaluate>>( - &self, - batch_query: &[&[P]], - evals: &mut [P], - ) -> Result<(), Error> { + fn batch_evaluate(&self, batch_query: &[&[P]], evals: &mut [P]) -> Result<(), Error> { let row_len = evals.len(); if batch_query.iter().any(|row| row.len() != row_len) { return Err(Error::BatchEvaluateSizeMismatch); @@ -368,52 +364,6 @@ impl CompositionPoly for ArithCircuitPoly { } } -impl>> CompositionPolyOS

- for ArithCircuitPoly -{ - fn degree(&self) -> usize { - CompositionPoly::degree(self) - } - - fn n_vars(&self) -> usize { - CompositionPoly::n_vars(self) - } - - fn expression(&self) -> ArithExpr { - self.expr.convert_field() - } - - fn binary_tower_level(&self) -> usize { - CompositionPoly::binary_tower_level(self) - } - - fn evaluate(&self, query: &[P]) -> Result { - CompositionPoly::evaluate(self, query) - } - - fn batch_evaluate(&self, batch_query: &[&[P]], evals: &mut [P]) -> Result<(), Error> { - CompositionPoly::batch_evaluate(self, batch_query, evals) - } -} - -impl MultivariatePoly for ArithCircuitPoly { - fn degree(&self) -> usize { - CompositionPoly::degree(&self) - } - - fn n_vars(&self) -> usize { - CompositionPoly::n_vars(&self) - } - - fn binary_tower_level(&self) -> usize { - CompositionPoly::binary_tower_level(&self) - } - - fn evaluate(&self, query: &[F]) -> Result { - CompositionPoly::evaluate(&self, query).map_err(|e| e.into()) - } -} - /// Apply a binary operation to two arguments and store the result in `current_evals`. /// `op` must be a function that takes two arguments and initialized the result with the third argument. fn apply_binary_op>>( @@ -485,7 +435,7 @@ mod tests { use binius_field::{ BinaryField16b, BinaryField8b, PackedBinaryField8x16b, PackedField, TowerField, }; - use binius_math::CompositionPolyOS; + use binius_math::CompositionPoly; use binius_utils::felts; use super::*; @@ -498,7 +448,7 @@ mod tests { let expr = ArithExpr::Const(F::new(123)); let circuit = ArithCircuitPoly::::new(expr); - let typed_circuit: &dyn CompositionPolyOS

= &circuit; + let typed_circuit: &dyn CompositionPoly

= &circuit; assert_eq!(typed_circuit.binary_tower_level(), F::TOWER_LEVEL); assert_eq!(typed_circuit.degree(), 0); assert_eq!(typed_circuit.n_vars(), 0); @@ -519,7 +469,7 @@ mod tests { let expr = ArithExpr::Var(0); let circuit = ArithCircuitPoly::::new(expr); - let typed_circuit: &dyn CompositionPolyOS

= &circuit; + let typed_circuit: &dyn CompositionPoly

= &circuit; assert_eq!(typed_circuit.binary_tower_level(), 0); assert_eq!(typed_circuit.degree(), 1); assert_eq!(typed_circuit.n_vars(), 1); @@ -547,7 +497,7 @@ mod tests { let expr = ArithExpr::Const(F::new(123)) + ArithExpr::Var(0); let circuit = ArithCircuitPoly::::new(expr); - let typed_circuit: &dyn CompositionPolyOS

= &circuit; + let typed_circuit: &dyn CompositionPoly

= &circuit; assert_eq!(typed_circuit.binary_tower_level(), 3); assert_eq!(typed_circuit.degree(), 1); assert_eq!(typed_circuit.n_vars(), 1); @@ -567,7 +517,7 @@ mod tests { let expr = ArithExpr::Const(F::new(123)) * ArithExpr::Var(0); let circuit = ArithCircuitPoly::::new(expr); - let typed_circuit: &dyn CompositionPolyOS

= &circuit; + let typed_circuit: &dyn CompositionPoly

= &circuit; assert_eq!(typed_circuit.binary_tower_level(), 3); assert_eq!(typed_circuit.degree(), 1); assert_eq!(typed_circuit.n_vars(), 1); @@ -593,7 +543,7 @@ mod tests { let expr = ArithExpr::Var(0).pow(13); let circuit = ArithCircuitPoly::::new(expr); - let typed_circuit: &dyn CompositionPolyOS

= &circuit; + let typed_circuit: &dyn CompositionPoly

= &circuit; assert_eq!(typed_circuit.binary_tower_level(), 0); assert_eq!(typed_circuit.degree(), 13); assert_eq!(typed_circuit.n_vars(), 1); @@ -619,7 +569,7 @@ mod tests { let expr = ArithExpr::Var(0).pow(2) * (ArithExpr::Var(1) + ArithExpr::Const(F::new(123))); let circuit = ArithCircuitPoly::::new(expr); - let typed_circuit: &dyn CompositionPolyOS

= &circuit; + let typed_circuit: &dyn CompositionPoly

= &circuit; assert_eq!(typed_circuit.binary_tower_level(), 3); assert_eq!(typed_circuit.degree(), 3); assert_eq!(typed_circuit.n_vars(), 2); @@ -681,7 +631,7 @@ mod tests { let circuit = ArithCircuitPoly::::new(expr); assert_eq!(circuit.steps.len(), 2); - let typed_circuit: &dyn CompositionPolyOS

= &circuit; + let typed_circuit: &dyn CompositionPoly

= &circuit; assert_eq!(typed_circuit.binary_tower_level(), F::TOWER_LEVEL); assert_eq!(typed_circuit.degree(), 1); assert_eq!(typed_circuit.n_vars(), 2); @@ -740,7 +690,7 @@ mod tests { let circuit = ArithCircuitPoly::::new(expr); assert_eq!(circuit.steps.len(), 1); - let typed_circuit: &dyn CompositionPolyOS

= &circuit; + let typed_circuit: &dyn CompositionPoly

= &circuit; assert_eq!(typed_circuit.binary_tower_level(), 1); assert_eq!(typed_circuit.degree(), 1); assert_eq!(typed_circuit.n_vars(), 1); @@ -767,7 +717,7 @@ mod tests { let circuit = ArithCircuitPoly::::new(expr); assert_eq!(circuit.steps.len(), 5); - let typed_circuit: &dyn CompositionPolyOS

= &circuit; + let typed_circuit: &dyn CompositionPoly

= &circuit; assert_eq!(typed_circuit.binary_tower_level(), 0); assert_eq!(typed_circuit.degree(), 24); assert_eq!(typed_circuit.n_vars(), 1); diff --git a/crates/core/src/polynomial/cached.rs b/crates/core/src/polynomial/cached.rs deleted file mode 100644 index 181257e5f..000000000 --- a/crates/core/src/polynomial/cached.rs +++ /dev/null @@ -1,272 +0,0 @@ -// Copyright 2024-2025 Irreducible Inc. - -use std::{ - any::{Any, TypeId}, - collections::HashMap, - fmt::Debug, - marker::PhantomData, -}; - -use binius_field::{ExtensionField, Field, PackedField}; -use binius_math::{ArithExpr, CompositionPoly, CompositionPolyOS, Error}; - -/// Cached composition poly wrapper. -/// -/// It stores the efficient implementations of the composition poly for some known set of packed field types. -/// We are usually able to use this when the inner poly is constructed with a macro for the known field and packed field types. -#[derive(Default, Debug)] -pub struct CachedPoly> { - inner: Inner, - cache: PackedFieldCache, -} - -impl> CachedPoly { - /// Create a new cached polynomial with the given inner polynomial. - pub fn new(inner: Inner) -> Self { - Self { - inner, - cache: Default::default(), - } - } - - /// Register efficient implementations for the `P` packed field type in the cache. - pub fn register>>( - &mut self, - composition: impl CompositionPolyOS

+ 'static, - ) { - self.cache.register(composition); - } -} - -impl> CompositionPoly for CachedPoly { - fn n_vars(&self) -> usize { - self.inner.n_vars() - } - - fn degree(&self) -> usize { - self.inner.degree() - } - - fn binary_tower_level(&self) -> usize { - self.inner.binary_tower_level() - } - - fn expression>(&self) -> ArithExpr { - self.inner.expression() - } - - fn evaluate>>(&self, query: &[P]) -> Result { - if let Some(result) = self.cache.try_evaluate(query) { - result - } else { - self.inner.evaluate(query) - } - } - - fn batch_evaluate>>( - &self, - batch_query: &[&[P]], - evals: &mut [P], - ) -> Result<(), Error> { - if let Some(result) = self.cache.try_batch_evaluate(batch_query, evals) { - result - } else { - self.inner.batch_evaluate(batch_query, evals) - } - } -} - -impl, P: PackedField>> - CompositionPolyOS

for CachedPoly -{ - fn binary_tower_level(&self) -> usize { - CompositionPoly::binary_tower_level(&self) - } - - fn n_vars(&self) -> usize { - CompositionPoly::n_vars(&self) - } - - fn degree(&self) -> usize { - CompositionPoly::degree(&self) - } - - fn expression(&self) -> ArithExpr { - CompositionPoly::expression(&self) - } - - fn evaluate(&self, query: &[P]) -> Result { - CompositionPoly::evaluate(&self, query) - } - - fn batch_evaluate(&self, batch_query: &[&[P]], evals: &mut [P]) -> Result<(), Error> { - CompositionPoly::batch_evaluate(&self, batch_query, evals) - } -} - -#[derive(Default)] -struct PackedFieldCache { - /// Map from the packed field type 'P to the efficient implementation of the composition polynomial - /// with actual type `Box>`. - entries: HashMap>, - _pd: PhantomData, -} - -impl PackedFieldCache { - /// Register efficient implementations for the `P` packed field type in the cache. - fn register>>( - &mut self, - composition: impl CompositionPolyOS

+ 'static, - ) { - let boxed_composition = Box::new(composition) as Box>; - self.entries - .insert(TypeId::of::

(), Box::new(boxed_composition) as Box); - } - - /// Try to evaluate the expression using the efficient implementation for the `P` packed field type. - /// If no implementation is found, return None. - fn try_evaluate>>( - &self, - query: &[P], - ) -> Option> { - if let Some(entry) = self.entries.get(&TypeId::of::

()) { - let entry = entry - .downcast_ref::>>() - .expect("cast must succeed"); - Some(entry.evaluate(query)) - } else { - None - } - } - - /// Try to batch evaluate the expression using the efficient implementation for the `P` packed field type. - /// If no implementation is found, return None. - fn try_batch_evaluate>>( - &self, - batch_query: &[&[P]], - evals: &mut [P], - ) -> Option> { - if let Some(entry) = self.entries.get(&TypeId::of::

()) { - let entry = entry - .downcast_ref::>>() - .expect("cast must succeed"); - Some(entry.batch_evaluate(batch_query, evals)) - } else { - None - } - } -} - -impl Debug for PackedFieldCache { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("PackedFieldCache") - .field("cached_implementations", &self.entries.len()) - .finish() - } -} - -#[cfg(test)] -mod tests { - use std::iter::zip; - - use binius_field::{BinaryField8b, ExtensionField, PackedBinaryField16x8b, PackedField}; - use binius_math::{ArithExpr, CompositionPolyOS}; - - use super::*; - use crate::polynomial::{cached::CachedPoly, ArithCircuitPoly}; - - fn ensure_equal_batch_eval_results( - circuit_1: &impl CompositionPolyOS

, - circuit_2: &impl CompositionPolyOS

, - batch_query: &[&[P]], - ) { - for row in 0..batch_query[0].len() { - let query = batch_query.iter().map(|q| q[row]).collect::>(); - - assert_eq!(circuit_1.evaluate(&query).unwrap(), circuit_2.evaluate(&query).unwrap()); - } - - let result_1 = { - let mut uncached_evals = vec![P::zero(); batch_query[0].len()]; - circuit_1 - .batch_evaluate(batch_query, &mut uncached_evals) - .unwrap(); - uncached_evals - }; - - let result_2 = { - let mut cached_evals = vec![P::zero(); batch_query[0].len()]; - circuit_2 - .batch_evaluate(batch_query, &mut cached_evals) - .unwrap(); - cached_evals - }; - - assert_eq!(result_1, result_2); - } - - #[derive(Debug, Copy, Clone)] - struct AddComposition; - - impl>> CompositionPolyOS

- for AddComposition - { - fn binary_tower_level(&self) -> usize { - 0 - } - - fn n_vars(&self) -> usize { - 1 - } - - fn degree(&self) -> usize { - 1 - } - - fn expression(&self) -> ArithExpr { - ArithExpr::Const(BinaryField8b::new(123).into()) + ArithExpr::Var(0) - } - - fn evaluate(&self, query: &[P]) -> Result { - Ok(query[0] + P::broadcast(BinaryField8b::new(123).into())) - } - - fn batch_evaluate(&self, batch_query: &[&[P]], evals: &mut [P]) -> Result<(), Error> { - for (input, output) in zip(batch_query[0], evals) { - *output = *input + P::broadcast(BinaryField8b::new(123).into()); - } - - Ok(()) - } - } - - #[test] - fn test_cached_impl() { - let expr = ArithExpr::Const(BinaryField8b::new(123)) + ArithExpr::Var(0); - let circuit = ArithCircuitPoly::::new(expr); - - let composition = AddComposition; - - let mut cached_circuit = CachedPoly::new(circuit.clone()); - cached_circuit.register::(composition); - - let batch_query = [(0..255).map(BinaryField8b::new).collect::>()]; - let batch_query = batch_query.iter().map(|q| q.as_slice()).collect::>(); - ensure_equal_batch_eval_results(&circuit, &cached_circuit, &batch_query); - } - - #[test] - fn test_uncached_impl() { - let expr = ArithExpr::Const(BinaryField8b::new(123)) + ArithExpr::Var(0); - let circuit = ArithCircuitPoly::::new(expr); - - let composition = AddComposition; - - let mut cached_circuit = CachedPoly::new(circuit.clone()); - cached_circuit.register::(composition); - - let batch_query = [(0..255).map(BinaryField8b::new).collect::>()]; - let batch_query = batch_query.iter().map(|q| q.as_slice()).collect::>(); - ensure_equal_batch_eval_results(&circuit, &cached_circuit, &batch_query); - } -} diff --git a/crates/core/src/polynomial/mod.rs b/crates/core/src/polynomial/mod.rs index 1c21083ad..7f8e4d762 100644 --- a/crates/core/src/polynomial/mod.rs +++ b/crates/core/src/polynomial/mod.rs @@ -1,7 +1,6 @@ // Copyright 2024-2025 Irreducible Inc. mod arith_circuit; -mod cached; mod error; mod multivariate; #[allow(dead_code)] @@ -9,6 +8,5 @@ mod multivariate; pub mod test_utils; pub use arith_circuit::*; -pub use cached::*; pub use error::*; pub use multivariate::*; diff --git a/crates/core/src/polynomial/multivariate.rs b/crates/core/src/polynomial/multivariate.rs index 8a4ca9c4b..e670f7aca 100644 --- a/crates/core/src/polynomial/multivariate.rs +++ b/crates/core/src/polynomial/multivariate.rs @@ -4,7 +4,7 @@ use std::{borrow::Borrow, fmt::Debug, iter::repeat_with, marker::PhantomData, sy use binius_field::{Field, PackedField}; use binius_math::{ - ArithExpr, CompositionPolyOS, MLEDirectAdapter, MultilinearPoly, MultilinearQueryRef, + ArithExpr, CompositionPoly, MLEDirectAdapter, MultilinearPoly, MultilinearQueryRef, }; use binius_utils::{bail, SerializationError, SerializationMode}; use bytes::BufMut; @@ -15,8 +15,8 @@ use super::error::Error; /// A multivariate polynomial over a binary tower field. /// -/// The definition `MultivariatePoly` is nearly identical to that of [`CompositionPolyOS`], except that -/// `MultivariatePoly` is _object safe_, whereas `CompositionPolyOS` is not. +/// The definition `MultivariatePoly` is nearly identical to that of [`CompositionPoly`], except that +/// `MultivariatePoly` is _object safe_, whereas `CompositionPoly` is not. pub trait MultivariatePoly

: Debug + Send + Sync { /// The number of variables. fn n_vars(&self) -> usize; @@ -46,7 +46,7 @@ pub trait MultivariatePoly

: Debug + Send + Sync { #[derive(Clone, Debug)] pub struct IdentityCompositionPoly; -impl CompositionPolyOS

for IdentityCompositionPoly { +impl CompositionPoly

for IdentityCompositionPoly { fn n_vars(&self) -> usize { 1 } @@ -71,7 +71,7 @@ impl CompositionPolyOS

for IdentityCompositionPoly { } } -/// An adapter that constructs a [`CompositionPolyOS`] for a field from a [`CompositionPolyOS`] for a +/// An adapter that constructs a [`CompositionPoly`] for a field from a [`CompositionPoly`] for a /// packing of that field. /// /// This is not intended for use in performance-critical code sections. @@ -84,7 +84,7 @@ pub struct CompositionScalarAdapter { impl CompositionScalarAdapter where P: PackedField, - Composition: CompositionPolyOS

, + Composition: CompositionPoly

, { pub const fn new(composition: Composition) -> Self { Self { @@ -94,11 +94,11 @@ where } } -impl CompositionPolyOS for CompositionScalarAdapter +impl CompositionPoly for CompositionScalarAdapter where F: Field, P: PackedField, - Composition: CompositionPolyOS

, + Composition: CompositionPoly

, { fn n_vars(&self) -> usize { self.composition.n_vars() @@ -153,7 +153,7 @@ where impl MultilinearComposite where P: PackedField, - C: CompositionPolyOS

, + C: CompositionPoly

, M: MultilinearPoly

, { pub fn new(n_vars: usize, composition: C, multilinears: Vec) -> Result { @@ -219,12 +219,10 @@ where impl MultilinearComposite where P: PackedField, - C: CompositionPolyOS

+ 'static, + C: CompositionPoly

+ 'static, M: MultilinearPoly

, { - pub fn to_arc_dyn_composition( - self, - ) -> MultilinearComposite>, M> { + pub fn to_arc_dyn_composition(self) -> MultilinearComposite>, M> { MultilinearComposite { n_vars: self.n_vars, composition: Arc::new(self.composition), @@ -279,7 +277,7 @@ where /// for two distinct multivariate polynomials f and g. /// /// NOTE: THIS IS NOT ADVERSARIALLY COLLISION RESISTANT, COLLISIONS CAN BE MANUFACTURED EASILY -pub fn composition_hash>(composition: &C) -> P { +pub fn composition_hash>(composition: &C) -> P { let mut rng = StdRng::from_seed([0; 32]); let random_point = repeat_with(|| P::random(&mut rng)) @@ -293,7 +291,7 @@ pub fn composition_hash>(composition: &C #[cfg(test)] mod tests { - use binius_math::{ArithExpr, CompositionPolyOS}; + use binius_math::{ArithExpr, CompositionPoly}; #[test] fn test_fingerprint_same_32b() { @@ -303,7 +301,7 @@ mod tests { let expr = (ArithExpr::Var(0) + ArithExpr::Var(1)) * ArithExpr::Var(0) + ArithExpr::Var(0).pow(2); let circuit_poly = &crate::polynomial::ArithCircuitPoly::::new(expr) - as &dyn CompositionPolyOS; + as &dyn CompositionPoly; let product_composition = crate::composition::ProductComposition::<2> {}; @@ -320,7 +318,7 @@ mod tests { let expr = ArithExpr::Var(0) + ArithExpr::Var(1); let circuit_poly = &crate::polynomial::ArithCircuitPoly::::new(expr) - as &dyn CompositionPolyOS; + as &dyn CompositionPoly; let product_composition = crate::composition::ProductComposition::<2> {}; @@ -338,7 +336,7 @@ mod tests { let expr = (ArithExpr::Var(0) + ArithExpr::Var(1)) * ArithExpr::Var(0) + ArithExpr::Var(0).pow(2); let circuit_poly = &crate::polynomial::ArithCircuitPoly::::new(expr) - as &dyn CompositionPolyOS; + as &dyn CompositionPoly; let product_composition = crate::composition::ProductComposition::<2> {}; @@ -354,7 +352,7 @@ mod tests { let expr = ArithExpr::Var(0) + ArithExpr::Var(1); let circuit_poly = &crate::polynomial::ArithCircuitPoly::::new(expr) - as &dyn CompositionPolyOS; + as &dyn CompositionPoly; let product_composition = crate::composition::ProductComposition::<2> {}; @@ -372,7 +370,7 @@ mod tests { let expr = (ArithExpr::Var(0) + ArithExpr::Var(1)) * ArithExpr::Var(0) + ArithExpr::Var(0).pow(2); let circuit_poly = &crate::polynomial::ArithCircuitPoly::::new(expr) - as &dyn CompositionPolyOS; + as &dyn CompositionPoly; let product_composition = crate::composition::ProductComposition::<2> {}; @@ -388,7 +386,7 @@ mod tests { let expr = ArithExpr::Var(0) + ArithExpr::Var(1); let circuit_poly = &crate::polynomial::ArithCircuitPoly::::new(expr) - as &dyn CompositionPolyOS; + as &dyn CompositionPoly; let product_composition = crate::composition::ProductComposition::<2> {}; diff --git a/crates/core/src/protocols/gkr_gpa/gpa_sumcheck/prove.rs b/crates/core/src/protocols/gkr_gpa/gpa_sumcheck/prove.rs index b7997ef63..ca11ea7ff 100644 --- a/crates/core/src/protocols/gkr_gpa/gpa_sumcheck/prove.rs +++ b/crates/core/src/protocols/gkr_gpa/gpa_sumcheck/prove.rs @@ -4,9 +4,7 @@ use std::ops::Range; use binius_field::{util::eq, Field, PackedExtension, PackedField, PackedFieldIndexable}; use binius_hal::{ComputationBackend, SumcheckEvaluator}; -use binius_math::{ - CompositionPolyOS, EvaluationDomainFactory, InterpolationDomain, MultilinearPoly, -}; +use binius_math::{CompositionPoly, EvaluationDomainFactory, InterpolationDomain, MultilinearPoly}; use binius_maybe_rayon::prelude::*; use binius_utils::bail; use itertools::izip; @@ -49,7 +47,7 @@ where F: Field, FDomain: Field, P: PackedFieldIndexable + PackedExtension, - Composition: CompositionPolyOS

, + Composition: CompositionPoly

, M: MultilinearPoly

+ Send + Sync, Backend: ComputationBackend, { @@ -194,7 +192,7 @@ where P: PackedFieldIndexable + PackedExtension + PackedExtension, - Composition: CompositionPolyOS

, + Composition: CompositionPoly

, M: MultilinearPoly

+ Send + Sync, Backend: ComputationBackend, { @@ -289,7 +287,7 @@ where F: Field, P: PackedField + PackedExtension + PackedExtension, FDomain: Field, - Composition: CompositionPolyOS

, + Composition: CompositionPoly

, { fn eval_point_indices(&self) -> Range { // By definition of grand product GKR circuit, the composition evaluation is a multilinear @@ -343,7 +341,7 @@ where F: Field, P: PackedField + PackedExtension, FDomain: Field, - Composition: CompositionPolyOS

, + Composition: CompositionPoly

, { #[instrument( skip_all, diff --git a/crates/core/src/protocols/gkr_int_mul/generator_exponent/compositions.rs b/crates/core/src/protocols/gkr_int_mul/generator_exponent/compositions.rs index a15251587..35cbebf91 100644 --- a/crates/core/src/protocols/gkr_int_mul/generator_exponent/compositions.rs +++ b/crates/core/src/protocols/gkr_int_mul/generator_exponent/compositions.rs @@ -1,7 +1,7 @@ // Copyright 2024-2025 Irreducible Inc. use binius_field::{Field, PackedField}; -use binius_math::{ArithExpr, CompositionPolyOS}; +use binius_math::{ArithExpr, CompositionPoly}; use binius_utils::bail; #[derive(Debug)] @@ -12,7 +12,7 @@ where pub generator_power_constant: F, } -impl CompositionPolyOS

for MultiplyOrDont { +impl CompositionPoly

for MultiplyOrDont { fn n_vars(&self) -> usize { 2 } diff --git a/crates/core/src/protocols/sumcheck/common.rs b/crates/core/src/protocols/sumcheck/common.rs index c0d1e2cc7..8a9bb6f2d 100644 --- a/crates/core/src/protocols/sumcheck/common.rs +++ b/crates/core/src/protocols/sumcheck/common.rs @@ -6,7 +6,7 @@ use binius_field::{ util::{inner_product_unchecked, powers}, ExtensionField, Field, PackedField, }; -use binius_math::{CompositionPolyOS, MultilinearPoly}; +use binius_math::{CompositionPoly, MultilinearPoly}; use binius_utils::bail; use getset::{CopyGetters, Getters}; use tracing::instrument; @@ -45,7 +45,7 @@ pub struct SumcheckClaim { impl SumcheckClaim where - Composition: CompositionPolyOS, + Composition: CompositionPoly, { /// Constructs a new sumcheck claim. /// diff --git a/crates/core/src/protocols/sumcheck/front_loaded.rs b/crates/core/src/protocols/sumcheck/front_loaded.rs index bdba716ad..3b1c4c937 100644 --- a/crates/core/src/protocols/sumcheck/front_loaded.rs +++ b/crates/core/src/protocols/sumcheck/front_loaded.rs @@ -3,7 +3,7 @@ use std::{cmp, cmp::Ordering, collections::VecDeque, iter}; use binius_field::{Field, TowerField}; -use binius_math::{evaluate_univariate, CompositionPolyOS}; +use binius_math::{evaluate_univariate, CompositionPoly}; use binius_utils::sorting::is_sorted_ascending; use bytes::Buf; @@ -60,7 +60,7 @@ pub struct BatchVerifier { impl BatchVerifier where F: TowerField, - C: CompositionPolyOS + Clone, + C: CompositionPoly + Clone, { /// Constructs a new verifier for the front-loaded batched sumcheck. /// diff --git a/crates/core/src/protocols/sumcheck/prove/prover_state.rs b/crates/core/src/protocols/sumcheck/prove/prover_state.rs index 185d32f84..d88272f11 100644 --- a/crates/core/src/protocols/sumcheck/prove/prover_state.rs +++ b/crates/core/src/protocols/sumcheck/prove/prover_state.rs @@ -8,7 +8,7 @@ use std::{ use binius_field::{util::powers, Field, PackedExtension, PackedField}; use binius_hal::{ComputationBackend, RoundEvals, SumcheckEvaluator, SumcheckMultilinear}; use binius_math::{ - evaluate_univariate, CompositionPolyOS, MLEDirectAdapter, MultilinearPoly, MultilinearQuery, + evaluate_univariate, CompositionPoly, MLEDirectAdapter, MultilinearPoly, MultilinearQuery, }; use binius_maybe_rayon::prelude::*; use binius_utils::bail; @@ -253,7 +253,7 @@ where ) -> Result>, Error> where Evaluator: SumcheckEvaluator + Sync, - Composition: CompositionPolyOS

, + Composition: CompositionPoly

, { Ok(self.backend.sumcheck_compute_round_evals( self.n_vars, diff --git a/crates/core/src/protocols/sumcheck/prove/regular_sumcheck.rs b/crates/core/src/protocols/sumcheck/prove/regular_sumcheck.rs index 9ecb400a9..2a6ccb3f7 100644 --- a/crates/core/src/protocols/sumcheck/prove/regular_sumcheck.rs +++ b/crates/core/src/protocols/sumcheck/prove/regular_sumcheck.rs @@ -4,9 +4,7 @@ use std::{marker::PhantomData, ops::Range}; use binius_field::{Field, PackedExtension, PackedField}; use binius_hal::{ComputationBackend, SumcheckEvaluator}; -use binius_math::{ - CompositionPolyOS, EvaluationDomainFactory, InterpolationDomain, MultilinearPoly, -}; +use binius_math::{CompositionPoly, EvaluationDomainFactory, InterpolationDomain, MultilinearPoly}; use binius_maybe_rayon::prelude::*; use binius_utils::bail; use itertools::izip; @@ -31,7 +29,7 @@ where F: Field, P: PackedField, M: MultilinearPoly

+ Send + Sync, - Composition: CompositionPolyOS

+ 'a, + Composition: CompositionPoly

+ 'a, { let n_vars = multilinears .first() @@ -85,7 +83,7 @@ where F: Field, FDomain: Field, P: PackedField + PackedExtension + PackedExtension, - Composition: CompositionPolyOS

, + Composition: CompositionPoly

, M: MultilinearPoly

+ Send + Sync, Backend: ComputationBackend, { @@ -169,7 +167,7 @@ where F: Field, FDomain: Field, P: PackedField + PackedExtension + PackedExtension, - Composition: CompositionPolyOS

, + Composition: CompositionPoly

, M: MultilinearPoly

+ Send + Sync, Backend: ComputationBackend, { @@ -219,7 +217,7 @@ where F: Field, P: PackedField + PackedExtension + PackedExtension, FDomain: Field, - Composition: CompositionPolyOS

, + Composition: CompositionPoly

, { fn eval_point_indices(&self) -> Range { // NB: We skip evaluation of $r(X)$ at $X = 0$ as it is derivable from the diff --git a/crates/core/src/protocols/sumcheck/prove/univariate.rs b/crates/core/src/protocols/sumcheck/prove/univariate.rs index 91995a1b7..fbcf3efe3 100644 --- a/crates/core/src/protocols/sumcheck/prove/univariate.rs +++ b/crates/core/src/protocols/sumcheck/prove/univariate.rs @@ -9,7 +9,7 @@ use binius_field::{ }; use binius_hal::{ComputationBackend, ComputationBackendExt}; use binius_math::{ - CompositionPolyOS, Error as MathError, EvaluationDomainFactory, + CompositionPoly, Error as MathError, EvaluationDomainFactory, IsomorphicEvaluationDomainFactory, MLEDirectAdapter, MultilinearPoly, }; use binius_maybe_rayon::prelude::*; @@ -329,7 +329,7 @@ where P: PackedFieldIndexable + PackedExtension + PackedExtension, - Composition: CompositionPolyOS>, + Composition: CompositionPoly>, M: MultilinearPoly

+ Send + Sync, Backend: ComputationBackend, { @@ -744,7 +744,7 @@ mod tests { }; use binius_hal::make_portable_backend; use binius_math::{ - CompositionPolyOS, DefaultEvaluationDomainFactory, EvaluationDomainFactory, MultilinearPoly, + CompositionPoly, DefaultEvaluationDomainFactory, EvaluationDomainFactory, MultilinearPoly, }; use binius_ntt::SingleThreadedNTT; use rand::{prelude::StdRng, SeedableRng}; @@ -879,11 +879,11 @@ mod tests { let compositions = [ Arc::new(IndexComposition::new(9, [0, 1], ProductComposition::<2> {}).unwrap()) - as Arc>>, + as Arc>>, Arc::new(IndexComposition::new(9, [2, 3, 4], ProductComposition::<3> {}).unwrap()) - as Arc>>, + as Arc>>, Arc::new(IndexComposition::new(9, [5, 6, 7, 8], ProductComposition::<4> {}).unwrap()) - as Arc>>, + as Arc>>, ]; let backend = make_portable_backend(); diff --git a/crates/core/src/protocols/sumcheck/prove/zerocheck.rs b/crates/core/src/protocols/sumcheck/prove/zerocheck.rs index 0e11eb8d0..bb9867c61 100644 --- a/crates/core/src/protocols/sumcheck/prove/zerocheck.rs +++ b/crates/core/src/protocols/sumcheck/prove/zerocheck.rs @@ -9,7 +9,7 @@ use binius_field::{ }; use binius_hal::{ComputationBackend, SumcheckEvaluator}; use binius_math::{ - CompositionPolyOS, EvaluationDomainFactory, InterpolationDomain, MLEDirectAdapter, + CompositionPoly, EvaluationDomainFactory, InterpolationDomain, MLEDirectAdapter, MultilinearPoly, MultilinearQuery, }; use binius_maybe_rayon::prelude::*; @@ -47,7 +47,7 @@ where F: Field, P: PackedField, M: MultilinearPoly

+ Send + Sync, - Composition: CompositionPolyOS

+ 'a, + Composition: CompositionPoly

+ 'a, { let n_vars = multilinears .first() @@ -118,8 +118,8 @@ where + PackedExtension + PackedExtension + PackedExtension, - CompositionBase: CompositionPolyOS<

>::PackedSubfield>, - Composition: CompositionPolyOS

, + CompositionBase: CompositionPoly<

>::PackedSubfield>, + Composition: CompositionPoly

, M: MultilinearPoly

+ Send + Sync + 'm, Backend: ComputationBackend, { @@ -252,8 +252,8 @@ where + PackedExtension + PackedExtension + PackedExtension, - CompositionBase: CompositionPolyOS> + 'static, - Composition: CompositionPolyOS

+ 'static, + CompositionBase: CompositionPoly> + 'static, + Composition: CompositionPoly

+ 'static, M: MultilinearPoly

+ Send + Sync + 'm, Backend: ComputationBackend, { @@ -450,7 +450,7 @@ where F: Field, FDomain: Field, P: PackedFieldIndexable + PackedExtension, - Composition: CompositionPolyOS

, + Composition: CompositionPoly

, M: MultilinearPoly

+ Send + Sync, Backend: ComputationBackend, { @@ -540,7 +540,7 @@ where F: Field, FDomain: Field, P: PackedFieldIndexable + PackedExtension, - Composition: CompositionPolyOS

, + Composition: CompositionPoly

, M: MultilinearPoly

+ Send + Sync, Backend: ComputationBackend, { @@ -629,7 +629,7 @@ impl SumcheckEvaluator where P: PackedField>, FDomain: Field, - Composition: CompositionPolyOS

, + Composition: CompositionPoly

, { fn eval_point_indices(&self) -> Range { // In the first round of zerocheck we can uniquely determine the degree d @@ -717,7 +717,7 @@ impl SumcheckEvaluator where P: PackedField>, FDomain: Field, - Composition: CompositionPolyOS

, + Composition: CompositionPoly

, { fn eval_point_indices(&self) -> Range { // We can uniquely derive the degree d univariate round polynomial r from evaluations at diff --git a/crates/core/src/protocols/sumcheck/tests.rs b/crates/core/src/protocols/sumcheck/tests.rs index e798de0fb..a77209641 100644 --- a/crates/core/src/protocols/sumcheck/tests.rs +++ b/crates/core/src/protocols/sumcheck/tests.rs @@ -16,7 +16,7 @@ use binius_field::{ }; use binius_hal::{make_portable_backend, ComputationBackend, ComputationBackendExt}; use binius_math::{ - ArithExpr, CompositionPolyOS, EvaluationDomainFactory, IsomorphicEvaluationDomainFactory, + ArithExpr, CompositionPoly, EvaluationDomainFactory, IsomorphicEvaluationDomainFactory, MLEEmbeddingAdapter, MultilinearExtension, MultilinearPoly, MultilinearQuery, }; use binius_maybe_rayon::{current_num_threads, prelude::*}; @@ -50,7 +50,7 @@ struct PowerComposition { exponent: usize, } -impl CompositionPolyOS

for PowerComposition { +impl CompositionPoly

for PowerComposition { fn n_vars(&self) -> usize { 1 } @@ -103,7 +103,7 @@ fn compute_composite_sum( where P: PackedField, M: MultilinearPoly

+ Send + Sync, - Composition: CompositionPolyOS

, + Composition: CompositionPoly

, { let n_vars = multilinears .first() @@ -263,7 +263,7 @@ fn make_test_sumcheck<'a, F, FDomain, P, PExt, Backend>( backend: &'a Backend, ) -> ( Vec>, - SumcheckClaim + Clone + 'static>, + SumcheckClaim + Clone + 'static>, impl SumcheckProver + 'a, ) where @@ -287,10 +287,9 @@ where .map(MLEEmbeddingAdapter::<_, PExt, _>::from) .collect::>(); - let mut claim_composite_sums = - Vec::>>>::new(); + let mut claim_composite_sums = Vec::>>>::new(); let mut prover_composite_sums = - Vec::>>>::new(); + Vec::>>>::new(); if max_degree >= 1 { let identity_composition = diff --git a/crates/core/src/protocols/sumcheck/univariate.rs b/crates/core/src/protocols/sumcheck/univariate.rs index d433b5d68..df833eef4 100644 --- a/crates/core/src/protocols/sumcheck/univariate.rs +++ b/crates/core/src/protocols/sumcheck/univariate.rs @@ -230,7 +230,7 @@ mod tests { }; use binius_hal::ComputationBackend; use binius_math::{ - CompositionPolyOS, DefaultEvaluationDomainFactory, EvaluationDomainFactory, + CompositionPoly, DefaultEvaluationDomainFactory, EvaluationDomainFactory, IsomorphicEvaluationDomainFactory, MultilinearPoly, }; use groestl_crypto::Groestl256; @@ -437,31 +437,31 @@ mod tests { let prover_compositions = [ ( "pair".into(), - pair.clone() as Arc>>, - pair.clone() as Arc>>, + pair.clone() as Arc>>, + pair.clone() as Arc>>, ), ( "triple".into(), - triple.clone() as Arc>>, - triple.clone() as Arc>>, + triple.clone() as Arc>>, + triple.clone() as Arc>>, ), ( "quad".into(), - quad.clone() as Arc>>, - quad.clone() as Arc>>, + quad.clone() as Arc>>, + quad.clone() as Arc>>, ), ]; let prover_adapter_compositions = [ - CompositionScalarAdapter::new(pair.clone() as Arc>), - CompositionScalarAdapter::new(triple.clone() as Arc>), - CompositionScalarAdapter::new(quad.clone() as Arc>), + CompositionScalarAdapter::new(pair.clone() as Arc>), + CompositionScalarAdapter::new(triple.clone() as Arc>), + CompositionScalarAdapter::new(quad.clone() as Arc>), ]; let verifier_compositions = [ - pair as Arc>, - triple as Arc>, - quad as Arc>, + pair as Arc>, + triple as Arc>, + quad as Arc>, ]; for skip_rounds in 0..=max_n_vars { diff --git a/crates/core/src/protocols/sumcheck/univariate_zerocheck.rs b/crates/core/src/protocols/sumcheck/univariate_zerocheck.rs index eb14a3bba..75756851a 100644 --- a/crates/core/src/protocols/sumcheck/univariate_zerocheck.rs +++ b/crates/core/src/protocols/sumcheck/univariate_zerocheck.rs @@ -1,7 +1,7 @@ // Copyright 2024-2025 Irreducible Inc. use binius_field::{util::inner_product_unchecked, Field, TowerField}; -use binius_math::{CompositionPolyOS, EvaluationDomainFactory, IsomorphicEvaluationDomainFactory}; +use binius_math::{CompositionPoly, EvaluationDomainFactory, IsomorphicEvaluationDomainFactory}; use binius_utils::{bail, sorting::is_sorted_ascending}; use tracing::instrument; @@ -50,7 +50,7 @@ pub fn batch_verify_zerocheck_univariate_round( ) -> Result, Error> where F: TowerField, - Composition: CompositionPolyOS, + Composition: CompositionPoly, Challenger_: Challenger, { // Check that the claims are in descending order by n_vars diff --git a/crates/core/src/protocols/sumcheck/verify.rs b/crates/core/src/protocols/sumcheck/verify.rs index b06de94ac..fe778c1a2 100644 --- a/crates/core/src/protocols/sumcheck/verify.rs +++ b/crates/core/src/protocols/sumcheck/verify.rs @@ -1,7 +1,7 @@ // Copyright 2024-2025 Irreducible Inc. use binius_field::{Field, TowerField}; -use binius_math::{evaluate_univariate, CompositionPolyOS}; +use binius_math::{evaluate_univariate, CompositionPoly}; use binius_utils::{bail, sorting::is_sorted_ascending}; use itertools::izip; use tracing::instrument; @@ -34,7 +34,7 @@ pub fn batch_verify( ) -> Result, Error> where F: TowerField, - Composition: CompositionPolyOS, + Composition: CompositionPoly, Challenger_: Challenger, { let start = BatchVerifyStart { @@ -69,7 +69,7 @@ pub fn batch_verify_with_start( ) -> Result, Error> where F: TowerField, - Composition: CompositionPolyOS, + Composition: CompositionPoly, Challenger_: Challenger, { let BatchVerifyStart { @@ -177,7 +177,7 @@ pub fn compute_expected_batch_composite_evaluation_single_claim Result where - Composition: CompositionPolyOS, + Composition: CompositionPoly, { let composite_evals = claim .composite_sums() @@ -193,7 +193,7 @@ fn compute_expected_batch_composite_evaluation_multi_claim], ) -> Result where - Composition: CompositionPolyOS, + Composition: CompositionPoly, { izip!(batch_coeffs, claims, multilinear_evals.iter()) .map(|(batch_coeff, claim, multilinear_evals)| { diff --git a/crates/core/src/protocols/sumcheck/zerocheck.rs b/crates/core/src/protocols/sumcheck/zerocheck.rs index bdf4cd8e7..5d47e28bd 100644 --- a/crates/core/src/protocols/sumcheck/zerocheck.rs +++ b/crates/core/src/protocols/sumcheck/zerocheck.rs @@ -3,7 +3,7 @@ use std::marker::PhantomData; use binius_field::{util::eq, Field, PackedField}; -use binius_math::{ArithExpr, CompositionPolyOS}; +use binius_math::{ArithExpr, CompositionPoly}; use binius_utils::{bail, sorting::is_sorted_ascending}; use getset::CopyGetters; @@ -22,7 +22,7 @@ pub struct ZerocheckClaim { impl ZerocheckClaim where - Composition: CompositionPolyOS, + Composition: CompositionPoly, { pub fn new( n_vars: usize, @@ -60,7 +60,7 @@ where } /// Requirement: zerocheck challenges have been sampled before this is called -pub fn reduce_to_sumchecks>( +pub fn reduce_to_sumchecks>( claims: &[ZerocheckClaim], ) -> Result>>, Error> { // Check that the claims are in descending order by n_vars @@ -100,7 +100,7 @@ pub fn reduce_to_sumchecks>( /// /// Note that due to univariatization of some rounds the number of challenges may be less than /// the maximum number of variables among claims. -pub fn verify_sumcheck_outputs>( +pub fn verify_sumcheck_outputs>( claims: &[ZerocheckClaim], zerocheck_challenges: &[F], sumcheck_output: BatchSumcheckOutput, @@ -158,10 +158,10 @@ pub struct ExtraProduct { pub inner: Composition, } -impl CompositionPolyOS

for ExtraProduct +impl CompositionPoly

for ExtraProduct where P: PackedField, - Composition: CompositionPolyOS

, + Composition: CompositionPoly

, { fn n_vars(&self) -> usize { self.inner.n_vars() + 1 @@ -239,7 +239,7 @@ mod tests { F: Field, FDomain: Field, P: PackedFieldIndexable + PackedExtension + RepackedExtension

, - Composition: CompositionPolyOS

, + Composition: CompositionPoly

, M: MultilinearPoly

+ Send + Sync + 'static, Backend: ComputationBackend, { diff --git a/crates/core/src/protocols/test_utils.rs b/crates/core/src/protocols/test_utils.rs index 892a5c276..bfa4a0b47 100644 --- a/crates/core/src/protocols/test_utils.rs +++ b/crates/core/src/protocols/test_utils.rs @@ -3,7 +3,7 @@ use std::ops::Deref; use binius_field::{ExtensionField, Field, PackedField}; -use binius_math::{ArithExpr, CompositionPolyOS, MLEEmbeddingAdapter, MultilinearExtension}; +use binius_math::{ArithExpr, CompositionPoly, MLEEmbeddingAdapter, MultilinearExtension}; use rand::Rng; use crate::polynomial::Error as PolynomialError; @@ -19,10 +19,10 @@ impl AddOneComposition { } } -impl CompositionPolyOS

for AddOneComposition +impl CompositionPoly

for AddOneComposition where P: PackedField, - Inner: CompositionPolyOS

, + Inner: CompositionPoly

, { fn n_vars(&self) -> usize { self.inner.n_vars() @@ -56,7 +56,7 @@ impl TestProductComposition { } } -impl

CompositionPolyOS

for TestProductComposition +impl

CompositionPoly

for TestProductComposition where P: PackedField, { diff --git a/crates/hal/src/backend.rs b/crates/hal/src/backend.rs index b91dc9b35..8d9dac1d2 100644 --- a/crates/hal/src/backend.rs +++ b/crates/hal/src/backend.rs @@ -7,7 +7,7 @@ use std::{ use binius_field::{Field, PackedExtension, PackedField}; use binius_math::{ - CompositionPolyOS, MultilinearExtension, MultilinearPoly, MultilinearQuery, MultilinearQueryRef, + CompositionPoly, MultilinearExtension, MultilinearPoly, MultilinearQuery, MultilinearQueryRef, }; use binius_maybe_rayon::iter::FromParallelIterator; use tracing::instrument; @@ -56,7 +56,7 @@ pub trait ComputationBackend: Send + Sync + Debug { P: PackedExtension, M: MultilinearPoly

+ Send + Sync, Evaluator: SumcheckEvaluator + Sync, - Composition: CompositionPolyOS

; + Composition: CompositionPoly

; /// Partially evaluate the polynomial with assignment to the high-indexed variables. fn evaluate_partial_high( @@ -98,7 +98,7 @@ where P: PackedExtension, M: MultilinearPoly

+ Send + Sync, Evaluator: SumcheckEvaluator + Sync, - Composition: CompositionPolyOS

, + Composition: CompositionPoly

, { T::sumcheck_compute_round_evals( self, diff --git a/crates/hal/src/cpu.rs b/crates/hal/src/cpu.rs index ef19db96b..acd17b225 100644 --- a/crates/hal/src/cpu.rs +++ b/crates/hal/src/cpu.rs @@ -4,7 +4,7 @@ use std::fmt::Debug; use binius_field::{Field, PackedExtension, PackedField}; use binius_math::{ - eq_ind_partial_eval, CompositionPolyOS, MultilinearExtension, MultilinearPoly, + eq_ind_partial_eval, CompositionPoly, MultilinearExtension, MultilinearPoly, MultilinearQueryRef, }; use tracing::instrument; @@ -50,7 +50,7 @@ impl ComputationBackend for CpuBackend { P: PackedExtension, M: MultilinearPoly

+ Send + Sync, Evaluator: SumcheckEvaluator + Sync, - Composition: CompositionPolyOS

, + Composition: CompositionPoly

, { calculate_round_evals(n_vars, tensor_query, multilinears, evaluators, evaluation_points) } diff --git a/crates/hal/src/sumcheck_round_calculator.rs b/crates/hal/src/sumcheck_round_calculator.rs index 8e79713cf..d14b1594b 100644 --- a/crates/hal/src/sumcheck_round_calculator.rs +++ b/crates/hal/src/sumcheck_round_calculator.rs @@ -8,7 +8,7 @@ use std::iter; use binius_field::{Field, PackedExtension, PackedField, PackedSubfield}; use binius_math::{ - deinterleave, extrapolate_lines, CompositionPolyOS, MultilinearPoly, MultilinearQuery, + deinterleave, extrapolate_lines, CompositionPoly, MultilinearPoly, MultilinearQuery, MultilinearQueryRef, }; use binius_maybe_rayon::prelude::*; @@ -58,7 +58,7 @@ where P: PackedField + PackedExtension, M: MultilinearPoly

+ Send + Sync, Evaluator: SumcheckEvaluator + Sync, - Composition: CompositionPolyOS

, + Composition: CompositionPoly

, { let empty_query = MultilinearQuery::with_capacity(0); let tensor_query = tensor_query.unwrap_or_else(|| empty_query.to_ref()); @@ -86,7 +86,7 @@ where P: PackedField + PackedExtension, Evaluator: SumcheckEvaluator + Sync, Access: SumcheckMultilinearAccess

+ Sync, - Composition: CompositionPolyOS

, + Composition: CompositionPoly

, { let n_multilinears = multilinears.len(); let n_round_evals = evaluators diff --git a/crates/macros/src/arith_circuit_poly.rs b/crates/macros/src/arith_circuit_poly.rs index 0f8f1f6ad..93ba74a79 100644 --- a/crates/macros/src/arith_circuit_poly.rs +++ b/crates/macros/src/arith_circuit_poly.rs @@ -3,55 +3,22 @@ use quote::{quote, ToTokens}; use syn::{bracketed, parse::Parse, parse_quote, spanned::Spanned, Token}; -use crate::composition_poly::CompositionPolyItem; - #[derive(Debug)] pub(crate) struct ArithCircuitPolyItem { poly: syn::Expr, - /// We create a composition poly to cache the efficient evaluation implementations - /// for the known packed field types. - composition_poly: CompositionPolyItem, field_name: syn::Ident, } impl ToTokens for ArithCircuitPolyItem { fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { - let Self { - poly, - composition_poly, - field_name, - } = self; - - let mut register_cached_impls = proc_macro2::TokenStream::new(); - let packed_extensions = get_packed_extensions(field_name); - if packed_extensions.is_empty() { - register_cached_impls.extend(quote! { result }); - } else { - register_cached_impls.extend(quote! ( - let mut cached = binius_core::polynomial::CachedPoly::new(composition); - - )); - - for packed_extension in get_packed_extensions(field_name) { - register_cached_impls.extend(quote! { - cached.register::(composition.clone()); - }); - } - - register_cached_impls.extend(quote! { - cached - }); - } + let Self { poly, field_name } = self; tokens.extend(quote! { { use binius_field::Field; use binius_math::ArithExpr as Expr; - let mut result = binius_core::polynomial::ArithCircuitPoly::::new(#poly); - let composition = #composition_poly; - - #register_cached_impls + binius_core::polynomial::ArithCircuitPoly::::new(#poly) } }); } @@ -59,7 +26,6 @@ impl ToTokens for ArithCircuitPolyItem { impl Parse for ArithCircuitPolyItem { fn parse(input: syn::parse::ParseStream) -> syn::Result { - let original_tokens = input.fork(); let vars: Vec = { let content; bracketed!(content in input); @@ -73,14 +39,8 @@ impl Parse for ArithCircuitPolyItem { input.parse::()?; let field_name = input.parse()?; - // Here we assume that the `composition_poly` shares the expression syntax with the `arithmetic_circuit_poly`. - let composition_poly = CompositionPolyItem::parse(&original_tokens)?; - Ok(Self { - poly, - composition_poly, - field_name, - }) + Ok(Self { poly, field_name }) } } @@ -120,350 +80,3 @@ fn flatten_expr(expr: &syn::Expr, vars: &[syn::Ident]) -> Result Vec { - match ident.to_string().as_str() { - "BinaryField1b" => vec![ - parse_quote!(PackedBinaryField1x1b), - parse_quote!(PackedBinaryField2x1b), - parse_quote!(PackedBinaryField4x1b), - parse_quote!(PackedBinaryField8x1b), - parse_quote!(PackedBinaryField16x1b), - parse_quote!(PackedBinaryField32x1b), - parse_quote!(PackedBinaryField64x1b), - parse_quote!(PackedBinaryField128x1b), - parse_quote!(PackedBinaryField256x1b), - parse_quote!(PackedBinaryField512x1b), - parse_quote!(PackedBinaryField1x2b), - parse_quote!(PackedBinaryField2x2b), - parse_quote!(PackedBinaryField4x2b), - parse_quote!(PackedBinaryField8x2b), - parse_quote!(PackedBinaryField16x2b), - parse_quote!(PackedBinaryField32x2b), - parse_quote!(PackedBinaryField64x2b), - parse_quote!(PackedBinaryField128x2b), - parse_quote!(PackedBinaryField256x2b), - parse_quote!(PackedBinaryField1x4b), - parse_quote!(PackedBinaryField2x4b), - parse_quote!(PackedBinaryField4x4b), - parse_quote!(PackedBinaryField8x4b), - parse_quote!(PackedBinaryField16x4b), - parse_quote!(PackedBinaryField32x4b), - parse_quote!(PackedBinaryField64x4b), - parse_quote!(PackedBinaryField128x4b), - parse_quote!(PackedBinaryField1x8b), - parse_quote!(PackedBinaryField2x8b), - parse_quote!(PackedBinaryField4x8b), - parse_quote!(PackedBinaryField8x8b), - parse_quote!(PackedBinaryField16x8b), - parse_quote!(PackedBinaryField32x8b), - parse_quote!(PackedBinaryField64x8b), - parse_quote!(PackedBinaryField1x16b), - parse_quote!(PackedBinaryField2x16b), - parse_quote!(PackedBinaryField4x16b), - parse_quote!(PackedBinaryField8x16b), - parse_quote!(PackedBinaryField16x16b), - parse_quote!(PackedBinaryField32x16b), - parse_quote!(PackedBinaryField1x32b), - parse_quote!(PackedBinaryField2x32b), - parse_quote!(PackedBinaryField4x32b), - parse_quote!(PackedBinaryField8x32b), - parse_quote!(PackedBinaryField16x32b), - parse_quote!(PackedBinaryField1x64b), - parse_quote!(PackedBinaryField2x64b), - parse_quote!(PackedBinaryField4x64b), - parse_quote!(PackedBinaryField8x64b), - parse_quote!(PackedBinaryField1x128b), - parse_quote!(PackedBinaryField2x128b), - parse_quote!(PackedBinaryField4x128b), - parse_quote!(PackedAESBinaryField1x8b), - parse_quote!(PackedAESBinaryField2x8b), - parse_quote!(PackedAESBinaryField4x8b), - parse_quote!(PackedAESBinaryField8x8b), - parse_quote!(PackedAESBinaryField16x8b), - parse_quote!(PackedAESBinaryField32x8b), - parse_quote!(PackedAESBinaryField64x8b), - parse_quote!(PackedAESBinaryField1x16b), - parse_quote!(PackedAESBinaryField2x16b), - parse_quote!(PackedAESBinaryField4x16b), - parse_quote!(PackedAESBinaryField8x16b), - parse_quote!(PackedAESBinaryField16x16b), - parse_quote!(PackedAESBinaryField32x16b), - parse_quote!(PackedAESBinaryField1x32b), - parse_quote!(PackedAESBinaryField2x32b), - parse_quote!(PackedAESBinaryField4x32b), - parse_quote!(PackedAESBinaryField8x32b), - parse_quote!(PackedAESBinaryField16x32b), - parse_quote!(PackedAESBinaryField1x64b), - parse_quote!(PackedAESBinaryField2x64b), - parse_quote!(PackedAESBinaryField4x64b), - parse_quote!(PackedAESBinaryField8x64b), - parse_quote!(PackedAESBinaryField1x128b), - parse_quote!(PackedAESBinaryField2x128b), - parse_quote!(PackedAESBinaryField4x128b), - parse_quote!(PackedBinaryPolyval1x128b), - parse_quote!(PackedBinaryPolyval2x128b), - parse_quote!(PackedBinaryPolyval4x128b), - ], - "BinaryField2b" => { - vec![ - parse_quote!(PackedBinaryField1x2b), - parse_quote!(PackedBinaryField2x2b), - parse_quote!(PackedBinaryField4x2b), - parse_quote!(PackedBinaryField8x2b), - parse_quote!(PackedBinaryField16x2b), - parse_quote!(PackedBinaryField32x2b), - parse_quote!(PackedBinaryField64x2b), - parse_quote!(PackedBinaryField128x2b), - parse_quote!(PackedBinaryField256x2b), - parse_quote!(PackedBinaryField1x4b), - parse_quote!(PackedBinaryField2x4b), - parse_quote!(PackedBinaryField4x4b), - parse_quote!(PackedBinaryField8x4b), - parse_quote!(PackedBinaryField16x4b), - parse_quote!(PackedBinaryField32x4b), - parse_quote!(PackedBinaryField64x4b), - parse_quote!(PackedBinaryField128x4b), - parse_quote!(PackedBinaryField1x8b), - parse_quote!(PackedBinaryField2x8b), - parse_quote!(PackedBinaryField4x8b), - parse_quote!(PackedBinaryField8x8b), - parse_quote!(PackedBinaryField16x8b), - parse_quote!(PackedBinaryField32x8b), - parse_quote!(PackedBinaryField64x8b), - parse_quote!(PackedBinaryField1x16b), - parse_quote!(PackedBinaryField2x16b), - parse_quote!(PackedBinaryField4x16b), - parse_quote!(PackedBinaryField8x16b), - parse_quote!(PackedBinaryField16x16b), - parse_quote!(PackedBinaryField32x16b), - parse_quote!(PackedBinaryField1x32b), - parse_quote!(PackedBinaryField2x32b), - parse_quote!(PackedBinaryField4x32b), - parse_quote!(PackedBinaryField8x32b), - parse_quote!(PackedBinaryField16x32b), - parse_quote!(PackedBinaryField1x64b), - parse_quote!(PackedBinaryField2x64b), - parse_quote!(PackedBinaryField4x64b), - parse_quote!(PackedBinaryField8x64b), - parse_quote!(PackedBinaryField1x128b), - parse_quote!(PackedBinaryField2x128b), - parse_quote!(PackedBinaryField4x128b), - ] - } - "BinaryField4b" => { - vec![ - parse_quote!(PackedBinaryField1x4b), - parse_quote!(PackedBinaryField2x4b), - parse_quote!(PackedBinaryField4x4b), - parse_quote!(PackedBinaryField8x4b), - parse_quote!(PackedBinaryField16x4b), - parse_quote!(PackedBinaryField32x4b), - parse_quote!(PackedBinaryField64x4b), - parse_quote!(PackedBinaryField128x4b), - parse_quote!(PackedBinaryField1x8b), - parse_quote!(PackedBinaryField2x8b), - parse_quote!(PackedBinaryField4x8b), - parse_quote!(PackedBinaryField8x8b), - parse_quote!(PackedBinaryField16x8b), - parse_quote!(PackedBinaryField32x8b), - parse_quote!(PackedBinaryField64x8b), - parse_quote!(PackedBinaryField1x16b), - parse_quote!(PackedBinaryField2x16b), - parse_quote!(PackedBinaryField4x16b), - parse_quote!(PackedBinaryField8x16b), - parse_quote!(PackedBinaryField16x16b), - parse_quote!(PackedBinaryField32x16b), - parse_quote!(PackedBinaryField1x32b), - parse_quote!(PackedBinaryField2x32b), - parse_quote!(PackedBinaryField4x32b), - parse_quote!(PackedBinaryField8x32b), - parse_quote!(PackedBinaryField16x32b), - parse_quote!(PackedBinaryField1x64b), - parse_quote!(PackedBinaryField2x64b), - parse_quote!(PackedBinaryField4x64b), - parse_quote!(PackedBinaryField8x64b), - parse_quote!(PackedBinaryField1x128b), - parse_quote!(PackedBinaryField2x128b), - parse_quote!(PackedBinaryField4x128b), - ] - } - "BinaryField8b" => { - vec![ - parse_quote!(PackedBinaryField1x8b), - parse_quote!(PackedBinaryField2x8b), - parse_quote!(PackedBinaryField4x8b), - parse_quote!(PackedBinaryField8x8b), - parse_quote!(PackedBinaryField16x8b), - parse_quote!(PackedBinaryField32x8b), - parse_quote!(PackedBinaryField64x8b), - parse_quote!(PackedBinaryField1x16b), - parse_quote!(PackedBinaryField2x16b), - parse_quote!(PackedBinaryField4x16b), - parse_quote!(PackedBinaryField8x16b), - parse_quote!(PackedBinaryField16x16b), - parse_quote!(PackedBinaryField32x16b), - parse_quote!(PackedBinaryField1x32b), - parse_quote!(PackedBinaryField2x32b), - parse_quote!(PackedBinaryField4x32b), - parse_quote!(PackedBinaryField8x32b), - parse_quote!(PackedBinaryField16x32b), - parse_quote!(PackedBinaryField1x64b), - parse_quote!(PackedBinaryField2x64b), - parse_quote!(PackedBinaryField4x64b), - parse_quote!(PackedBinaryField8x64b), - parse_quote!(PackedBinaryField1x128b), - parse_quote!(PackedBinaryField2x128b), - parse_quote!(PackedBinaryField4x128b), - ] - } - "BinaryField16b" => { - vec![ - parse_quote!(PackedBinaryField1x16b), - parse_quote!(PackedBinaryField2x16b), - parse_quote!(PackedBinaryField4x16b), - parse_quote!(PackedBinaryField8x16b), - parse_quote!(PackedBinaryField16x16b), - parse_quote!(PackedBinaryField32x16b), - parse_quote!(PackedBinaryField1x32b), - parse_quote!(PackedBinaryField2x32b), - parse_quote!(PackedBinaryField4x32b), - parse_quote!(PackedBinaryField8x32b), - parse_quote!(PackedBinaryField16x32b), - parse_quote!(PackedBinaryField1x64b), - parse_quote!(PackedBinaryField2x64b), - parse_quote!(PackedBinaryField4x64b), - parse_quote!(PackedBinaryField8x64b), - parse_quote!(PackedBinaryField1x128b), - parse_quote!(PackedBinaryField2x128b), - parse_quote!(PackedBinaryField4x128b), - ] - } - "BinaryField32b" => { - vec![ - parse_quote!(PackedBinaryField1x32b), - parse_quote!(PackedBinaryField2x32b), - parse_quote!(PackedBinaryField4x32b), - parse_quote!(PackedBinaryField8x32b), - parse_quote!(PackedBinaryField16x32b), - parse_quote!(PackedBinaryField1x64b), - parse_quote!(PackedBinaryField2x64b), - parse_quote!(PackedBinaryField4x64b), - parse_quote!(PackedBinaryField8x64b), - parse_quote!(PackedBinaryField1x128b), - parse_quote!(PackedBinaryField2x128b), - parse_quote!(PackedBinaryField4x128b), - ] - } - "BinaryField64b" => { - vec![ - parse_quote!(PackedBinaryField1x64b), - parse_quote!(PackedBinaryField2x64b), - parse_quote!(PackedBinaryField4x64b), - parse_quote!(PackedBinaryField8x64b), - parse_quote!(PackedBinaryField1x128b), - parse_quote!(PackedBinaryField2x128b), - parse_quote!(PackedBinaryField4x128b), - ] - } - - "BinaryField128b" => { - vec![ - parse_quote!(PackedBinaryField1x128b), - parse_quote!(PackedBinaryField2x128b), - parse_quote!(PackedBinaryField4x128b), - ] - } - - "AESTowerField8b" => { - vec![ - parse_quote!(PackedAESBinaryField1x8b), - parse_quote!(PackedAESBinaryField2x8b), - parse_quote!(PackedAESBinaryField4x8b), - parse_quote!(PackedAESBinaryField8x8b), - parse_quote!(PackedAESBinaryField16x8b), - parse_quote!(PackedAESBinaryField32x8b), - parse_quote!(PackedAESBinaryField64x8b), - parse_quote!(PackedAESBinaryField1x16b), - parse_quote!(PackedAESBinaryField2x16b), - parse_quote!(PackedAESBinaryField4x16b), - parse_quote!(PackedAESBinaryField8x16b), - parse_quote!(PackedAESBinaryField16x16b), - parse_quote!(PackedAESBinaryField32x16b), - parse_quote!(PackedAESBinaryField1x32b), - parse_quote!(PackedAESBinaryField2x32b), - parse_quote!(PackedAESBinaryField4x32b), - parse_quote!(PackedAESBinaryField8x32b), - parse_quote!(PackedAESBinaryField16x32b), - parse_quote!(PackedAESBinaryField1x64b), - parse_quote!(PackedAESBinaryField2x64b), - parse_quote!(PackedAESBinaryField4x64b), - parse_quote!(PackedAESBinaryField8x64b), - parse_quote!(PackedAESBinaryField1x128b), - parse_quote!(PackedAESBinaryField2x128b), - parse_quote!(PackedAESBinaryField4x128b), - parse_quote!(ByteSlicedAES32x128b), - ] - } - "AESTowerField16b" => { - vec![ - parse_quote!(PackedAESBinaryField1x16b), - parse_quote!(PackedAESBinaryField2x16b), - parse_quote!(PackedAESBinaryField4x16b), - parse_quote!(PackedAESBinaryField8x16b), - parse_quote!(PackedAESBinaryField16x16b), - parse_quote!(PackedAESBinaryField32x16b), - parse_quote!(PackedAESBinaryField1x32b), - parse_quote!(PackedAESBinaryField2x32b), - parse_quote!(PackedAESBinaryField4x32b), - parse_quote!(PackedAESBinaryField8x32b), - parse_quote!(PackedAESBinaryField16x32b), - parse_quote!(PackedAESBinaryField1x64b), - parse_quote!(PackedAESBinaryField2x64b), - parse_quote!(PackedAESBinaryField4x64b), - parse_quote!(PackedAESBinaryField8x64b), - parse_quote!(PackedAESBinaryField1x128b), - parse_quote!(PackedAESBinaryField2x128b), - parse_quote!(PackedAESBinaryField4x128b), - ] - } - "AESTowerField32b" => { - vec![ - parse_quote!(PackedAESBinaryField1x32b), - parse_quote!(PackedAESBinaryField2x32b), - parse_quote!(PackedAESBinaryField4x32b), - parse_quote!(PackedAESBinaryField8x32b), - parse_quote!(PackedAESBinaryField16x32b), - parse_quote!(PackedAESBinaryField1x64b), - parse_quote!(PackedAESBinaryField2x64b), - parse_quote!(PackedAESBinaryField4x64b), - parse_quote!(PackedAESBinaryField8x64b), - parse_quote!(PackedAESBinaryField1x128b), - parse_quote!(PackedAESBinaryField2x128b), - parse_quote!(PackedAESBinaryField4x128b), - ] - } - "AESTowerField64b" => { - vec![ - parse_quote!(PackedAESBinaryField1x64b), - parse_quote!(PackedAESBinaryField2x64b), - parse_quote!(PackedAESBinaryField4x64b), - parse_quote!(PackedAESBinaryField8x64b), - parse_quote!(PackedAESBinaryField1x128b), - parse_quote!(PackedAESBinaryField2x128b), - parse_quote!(PackedAESBinaryField4x128b), - ] - } - "AESTowerField128b" => { - vec![ - parse_quote!(PackedAESBinaryField1x128b), - parse_quote!(PackedAESBinaryField2x128b), - parse_quote!(PackedAESBinaryField4x128b), - ] - } - - _ => vec![], - } -} diff --git a/crates/macros/src/composition_poly.rs b/crates/macros/src/composition_poly.rs index 36200a762..0472fc328 100644 --- a/crates/macros/src/composition_poly.rs +++ b/crates/macros/src/composition_poly.rs @@ -43,7 +43,10 @@ impl ToTokens for CompositionPolyItem { #[derive(Debug, Clone, Copy)] struct #name; - impl binius_math::CompositionPoly<#scalar_type> for #name { + impl

binius_math::CompositionPoly

for #name + where + P: binius_field::PackedField>, + { fn n_vars(&self) -> usize { #n_vars } @@ -56,18 +59,18 @@ impl ToTokens for CompositionPolyItem { 0 } - fn expression>(&self) -> binius_math::ArithExpr { + fn expression(&self) -> binius_math::ArithExpr { (#expr).convert_field() } - fn evaluate>>(&self, query: &[P]) -> Result { + fn evaluate(&self, query: &[P]) -> Result { if query.len() != #n_vars { return Err(binius_math::Error::IncorrectQuerySize { expected: #n_vars }); } Ok(#eval_single) } - fn batch_evaluate>>( + fn batch_evaluate( &self, batch_query: &[&[P]], evals: &mut [P], @@ -89,36 +92,6 @@ impl ToTokens for CompositionPolyItem { Ok(()) } } - - impl

binius_math::CompositionPolyOS

for #name - where - P: binius_field::PackedField>, - { - fn n_vars(&self) -> usize { - >::n_vars(self) - } - - fn degree(&self) -> usize { - >::degree(self) - } - - fn binary_tower_level(&self) -> usize { - >::binary_tower_level(self) - } - - fn expression(&self) -> binius_math::ArithExpr { - >::expression(self) - } - - fn evaluate(&self, query: &[P]) -> Result { - >::evaluate(self, query) - } - - fn batch_evaluate(&self, batch_query: &[&[P]], evals: &mut [P]) -> Result<(), binius_math::Error> { - >::batch_evaluate(self, batch_query, evals) - } - } - }; if *is_anonymous { diff --git a/crates/macros/src/lib.rs b/crates/macros/src/lib.rs index f717405e1..08fa219a2 100644 --- a/crates/macros/src/lib.rs +++ b/crates/macros/src/lib.rs @@ -14,15 +14,15 @@ use crate::{ composition_poly::CompositionPolyItem, }; -/// Useful for concisely creating structs that implement CompositionPolyOS. +/// Useful for concisely creating structs that implement CompositionPoly. /// This currently only supports creating composition polynomials of tower level 0. /// /// ``` /// use binius_macros::composition_poly; -/// use binius_math::CompositionPolyOS; +/// use binius_math::CompositionPoly; /// use binius_field::{Field, BinaryField1b as F}; /// -/// // Defines named struct without any fields that implements CompositionPolyOS +/// // Defines named struct without any fields that implements CompositionPoly /// composition_poly!(MyComposition[x, y, z] = x + y * z); /// assert_eq!( /// MyComposition.evaluate(&[F::ONE, F::ONE, F::ONE]).unwrap(), diff --git a/crates/macros/tests/arithmetic_circuit.rs b/crates/macros/tests/arithmetic_circuit.rs index 895f6003b..c840e3b96 100644 --- a/crates/macros/tests/arithmetic_circuit.rs +++ b/crates/macros/tests/arithmetic_circuit.rs @@ -2,7 +2,7 @@ use binius_field::*; use binius_macros::arith_circuit_poly; -use binius_math::CompositionPolyOS; +use binius_math::CompositionPoly; use paste::paste; use rand::{rngs::StdRng, SeedableRng}; diff --git a/crates/math/src/composition_poly.rs b/crates/math/src/composition_poly.rs index 100c9d31c..8fc883f06 100644 --- a/crates/math/src/composition_poly.rs +++ b/crates/math/src/composition_poly.rs @@ -3,16 +3,14 @@ use std::fmt::Debug; use auto_impl::auto_impl; -use binius_field::{ExtensionField, Field, PackedField}; +use binius_field::PackedField; use stackalloc::stackalloc_with_default; use crate::{ArithExpr, Error}; /// A multivariate polynomial that is used as a composition of several multilinear polynomials. -/// -/// This is an object-safe version of the [`CompositionPoly`] trait. #[auto_impl(Arc, &)] -pub trait CompositionPolyOS

: Debug + Send + Sync +pub trait CompositionPoly

: Debug + Send + Sync where P: PackedField, { @@ -65,23 +63,3 @@ where }) } } - -/// A generic version of the `CompositionPolyOS` trait that is not object-safe. -#[auto_impl(&)] -pub trait CompositionPoly: Debug + Send + Sync { - fn n_vars(&self) -> usize; - - fn degree(&self) -> usize; - - fn binary_tower_level(&self) -> usize; - - fn expression>(&self) -> ArithExpr; - - fn evaluate>>(&self, query: &[P]) -> Result; - - fn batch_evaluate>>( - &self, - batch_query: &[&[P]], - evals: &mut [P], - ) -> Result<(), Error>; -} From e9991cebb742f6f2774dac1cdf55ea9255361bc5 Mon Sep 17 00:00:00 2001 From: Dmytro Gordon Date: Mon, 24 Feb 2025 17:02:41 +0200 Subject: [PATCH 40/50] ]field] Byte-sliced fields changes (#21) * Refactor a bit TowerLevels to remove packed field parameter from the TowerLevel to the Data associated type. This also makes generic bounds a bit more clean, since TowerLevel itself doesn't depend on a concrete packed field type. * Add support of byte-sliced fields with arbitrary register size, i.e. 128b, 256b, 512b. * Add shifts and unpack low/high within 128-bit lanes to UnderlierWithBitOps. This allows implementing transposition in an efficient way. * Add the transparent implementation of UnderlierWithBitOps for PackedScaledUnderlier as we need it to re-use PackedScaledField. --- .../lasso/big_integer_ops/byte_sliced_add.rs | 8 +- .../byte_sliced_add_carryfree.rs | 8 +- ...yte_sliced_double_conditional_increment.rs | 6 +- .../byte_sliced_modular_mul.rs | 11 +- .../lasso/big_integer_ops/byte_sliced_mul.rs | 11 +- .../big_integer_ops/byte_sliced_test_utils.rs | 31 +-- .../benches/packed_field_element_access.rs | 54 +++- crates/field/benches/packed_field_init.rs | 54 +++- crates/field/benches/packed_field_utils.rs | 12 + crates/field/src/arch/aarch64/m128.rs | 38 ++- .../src/arch/portable/byte_sliced/invert.rs | 45 ++- .../src/arch/portable/byte_sliced/mod.rs | 28 +- .../src/arch/portable/byte_sliced/multiply.rs | 55 ++-- .../byte_sliced/packed_byte_sliced.rs | 105 ++++--- .../src/arch/portable/byte_sliced/square.rs | 35 +-- .../field/src/arch/portable/packed_scaled.rs | 51 ++-- crates/field/src/arch/x86_64/m128.rs | 216 +++++++-------- crates/field/src/arch/x86_64/m256.rs | 106 ++++++- crates/field/src/arch/x86_64/m512.rs | 118 +++++++- crates/field/src/tower_levels.rs | 262 ++++++++---------- crates/field/src/underlier/scaled.rs | 193 ++++++++++++- crates/field/src/underlier/small_uint.rs | 8 + crates/field/src/underlier/underlier_impls.rs | 10 + .../src/underlier/underlier_with_bit_ops.rs | 81 ++++++ 24 files changed, 1072 insertions(+), 474 deletions(-) diff --git a/crates/circuits/src/lasso/big_integer_ops/byte_sliced_add.rs b/crates/circuits/src/lasso/big_integer_ops/byte_sliced_add.rs index 0e274f402..302d2384b 100644 --- a/crates/circuits/src/lasso/big_integer_ops/byte_sliced_add.rs +++ b/crates/circuits/src/lasso/big_integer_ops/byte_sliced_add.rs @@ -13,15 +13,15 @@ use crate::{ type B1 = BinaryField1b; type B8 = BinaryField8b; -pub fn byte_sliced_add>( +pub fn byte_sliced_add: Sized>>( builder: &mut ConstraintSystemBuilder, name: impl ToString + Clone, - x_in: &Level::Data, - y_in: &Level::Data, + x_in: &Level::Data, + y_in: &Level::Data, carry_in: OracleId, log_size: usize, lookup_batch_add: &mut LookupBatch, -) -> Result<(OracleId, Level::Data), anyhow::Error> { +) -> Result<(OracleId, Level::Data), anyhow::Error> { if Level::WIDTH == 1 { let (carry_out, sum) = u8add(builder, lookup_batch_add, name, x_in[0], y_in[0], carry_in, log_size)?; diff --git a/crates/circuits/src/lasso/big_integer_ops/byte_sliced_add_carryfree.rs b/crates/circuits/src/lasso/big_integer_ops/byte_sliced_add_carryfree.rs index 38e2703dd..ab7b864c2 100644 --- a/crates/circuits/src/lasso/big_integer_ops/byte_sliced_add_carryfree.rs +++ b/crates/circuits/src/lasso/big_integer_ops/byte_sliced_add_carryfree.rs @@ -15,16 +15,16 @@ type B1 = BinaryField1b; type B8 = BinaryField8b; #[allow(clippy::too_many_arguments)] -pub fn byte_sliced_add_carryfree>( +pub fn byte_sliced_add_carryfree: Sized>>( builder: &mut ConstraintSystemBuilder, name: impl ToString, - x_in: &Level::Data, - y_in: &Level::Data, + x_in: &Level::Data, + y_in: &Level::Data, carry_in: OracleId, log_size: usize, lookup_batch_add: &mut LookupBatch, lookup_batch_add_carryfree: &mut LookupBatch, -) -> Result { +) -> Result, anyhow::Error> { if Level::WIDTH == 1 { let sum = u8add_carryfree( builder, diff --git a/crates/circuits/src/lasso/big_integer_ops/byte_sliced_double_conditional_increment.rs b/crates/circuits/src/lasso/big_integer_ops/byte_sliced_double_conditional_increment.rs index ed697b8eb..9000c4572 100644 --- a/crates/circuits/src/lasso/big_integer_ops/byte_sliced_double_conditional_increment.rs +++ b/crates/circuits/src/lasso/big_integer_ops/byte_sliced_double_conditional_increment.rs @@ -14,16 +14,16 @@ type B1 = BinaryField1b; type B8 = BinaryField8b; #[allow(clippy::too_many_arguments)] -pub fn byte_sliced_double_conditional_increment>( +pub fn byte_sliced_double_conditional_increment: Sized>>( builder: &mut ConstraintSystemBuilder, name: impl ToString, - x_in: &Level::Data, + x_in: &Level::Data, first_carry_in: OracleId, second_carry_in: OracleId, log_size: usize, zero_oracle_carry: usize, lookup_batch_dci: &mut LookupBatch, -) -> Result<(OracleId, Level::Data), anyhow::Error> { +) -> Result<(OracleId, Level::Data), anyhow::Error> { if Level::WIDTH == 1 { let (carry_out, sum) = u8_double_conditional_increment( builder, diff --git a/crates/circuits/src/lasso/big_integer_ops/byte_sliced_modular_mul.rs b/crates/circuits/src/lasso/big_integer_ops/byte_sliced_modular_mul.rs index 9677a3e4b..d480de5c2 100644 --- a/crates/circuits/src/lasso/big_integer_ops/byte_sliced_modular_mul.rs +++ b/crates/circuits/src/lasso/big_integer_ops/byte_sliced_modular_mul.rs @@ -20,19 +20,16 @@ use crate::{ type B8 = BinaryField8b; #[allow(clippy::too_many_arguments)] -pub fn byte_sliced_modular_mul< - LevelIn: TowerLevel, - LevelOut: TowerLevel, ->( +pub fn byte_sliced_modular_mul>( builder: &mut ConstraintSystemBuilder, name: impl ToString, - mult_a: &LevelIn::Data, - mult_b: &LevelIn::Data, + mult_a: &LevelIn::Data, + mult_b: &LevelIn::Data, modulus_input: &[u8], log_size: usize, zero_byte_oracle: OracleId, zero_carry_oracle: OracleId, -) -> Result { +) -> Result, anyhow::Error> { builder.push_namespace(name); let lookup_t_mul = mul_lookup(builder, "mul table")?; diff --git a/crates/circuits/src/lasso/big_integer_ops/byte_sliced_mul.rs b/crates/circuits/src/lasso/big_integer_ops/byte_sliced_mul.rs index 2236045fb..de386d7ad 100644 --- a/crates/circuits/src/lasso/big_integer_ops/byte_sliced_mul.rs +++ b/crates/circuits/src/lasso/big_integer_ops/byte_sliced_mul.rs @@ -14,20 +14,17 @@ use crate::{ type B8 = BinaryField8b; #[allow(clippy::too_many_arguments)] -pub fn byte_sliced_mul< - LevelIn: TowerLevel, - LevelOut: TowerLevel, ->( +pub fn byte_sliced_mul>( builder: &mut ConstraintSystemBuilder, name: impl ToString, - mult_a: &LevelIn::Data, - mult_b: &LevelIn::Data, + mult_a: &LevelIn::Data, + mult_b: &LevelIn::Data, log_size: usize, zero_carry_oracle: OracleId, lookup_batch_mul: &mut LookupBatch, lookup_batch_add: &mut LookupBatch, lookup_batch_dci: &mut LookupBatch, -) -> Result { +) -> Result, anyhow::Error> { if LevelIn::WIDTH == 1 { let result_of_u8mul = u8mul_bytesliced( builder, diff --git a/crates/circuits/src/lasso/big_integer_ops/byte_sliced_test_utils.rs b/crates/circuits/src/lasso/big_integer_ops/byte_sliced_test_utils.rs index 7ce699621..411a5c151 100644 --- a/crates/circuits/src/lasso/big_integer_ops/byte_sliced_test_utils.rs +++ b/crates/circuits/src/lasso/big_integer_ops/byte_sliced_test_utils.rs @@ -33,14 +33,12 @@ pub fn random_u512(rng: &mut impl Rng) -> U512 { pub fn test_bytesliced_add() where - TL: TowerLevel, + TL: TowerLevel, { test_circuit(|builder| { let log_size = 14; - let x_in = - array::from_fn(|_| unconstrained::(builder, "x", log_size).unwrap()); - let y_in = - array::from_fn(|_| unconstrained::(builder, "y", log_size).unwrap()); + let x_in = TL::from_fn(|_| unconstrained::(builder, "x", log_size).unwrap()); + let y_in = TL::from_fn(|_| unconstrained::(builder, "y", log_size).unwrap()); let c_in = unconstrained::(builder, "cin first", log_size)?; let lookup_t_add = add_lookup(builder, "add table")?; let mut lookup_batch_add = LookupBatch::new([lookup_t_add]); @@ -61,14 +59,14 @@ where pub fn test_bytesliced_add_carryfree() where - TL: TowerLevel, + TL: TowerLevel, { test_circuit(|builder| { let log_size = 14; let x_in = - array::from_fn(|_| builder.add_committed("x", log_size, BinaryField8b::TOWER_LEVEL)); + TL::from_fn(|_| builder.add_committed("x", log_size, BinaryField8b::TOWER_LEVEL)); let y_in = - array::from_fn(|_| builder.add_committed("y", log_size, BinaryField8b::TOWER_LEVEL)); + TL::from_fn(|_| builder.add_committed("y", log_size, BinaryField8b::TOWER_LEVEL)); let c_in = builder.add_committed("c", log_size, BinaryField1b::TOWER_LEVEL); if let Some(witness) = builder.witness() { @@ -136,12 +134,11 @@ where pub fn test_bytesliced_double_conditional_increment() where - TL: TowerLevel, + TL: TowerLevel, { test_circuit(|builder| { let log_size = 14; - let x_in = - array::from_fn(|_| unconstrained::(builder, "x", log_size).unwrap()); + let x_in = TL::from_fn(|_| unconstrained::(builder, "x", log_size).unwrap()); let first_c_in = unconstrained::(builder, "cin first", log_size)?; let second_c_in = unconstrained::(builder, "cin second", log_size)?; let zero_oracle_carry = @@ -166,15 +163,14 @@ where pub fn test_bytesliced_mul() where - TL: TowerLevel, - TL::Base: TowerLevel, + TL: TowerLevel, { test_circuit(|builder| { let log_size = 14; let mult_a = - array::from_fn(|_| unconstrained::(builder, "a", log_size).unwrap()); + TL::Base::from_fn(|_| unconstrained::(builder, "a", log_size).unwrap()); let mult_b = - array::from_fn(|_| unconstrained::(builder, "b", log_size).unwrap()); + TL::Base::from_fn(|_| unconstrained::(builder, "b", log_size).unwrap()); let zero_oracle_carry = transparent::constant(builder, "zero carry", log_size, BinaryField1b::ZERO)?; let lookup_t_mul = mul_lookup(builder, "mul lookup")?; @@ -201,9 +197,8 @@ where pub fn test_bytesliced_modular_mul() where - TL: TowerLevel, - TL::Base: TowerLevel, - >::Data: Debug, + TL: TowerLevel: Debug>, + TL::Base: TowerLevel = [OracleId; WIDTH]>, { test_circuit(|builder| { let log_size = 14; diff --git a/crates/field/benches/packed_field_element_access.rs b/crates/field/benches/packed_field_element_access.rs index d8f4d0ee4..834f280ad 100644 --- a/crates/field/benches/packed_field_element_access.rs +++ b/crates/field/benches/packed_field_element_access.rs @@ -3,11 +3,15 @@ use std::array; use binius_field::{ - PackedBinaryField128x1b, PackedBinaryField16x32b, PackedBinaryField16x8b, - PackedBinaryField1x128b, PackedBinaryField256x1b, PackedBinaryField2x128b, - PackedBinaryField2x64b, PackedBinaryField32x8b, PackedBinaryField4x128b, - PackedBinaryField4x32b, PackedBinaryField4x64b, PackedBinaryField512x1b, - PackedBinaryField64x8b, PackedBinaryField8x32b, PackedBinaryField8x64b, PackedField, + ByteSlicedAES16x128b, ByteSlicedAES16x16b, ByteSlicedAES16x32b, ByteSlicedAES16x64b, + ByteSlicedAES16x8b, ByteSlicedAES32x128b, ByteSlicedAES32x16b, ByteSlicedAES32x32b, + ByteSlicedAES32x64b, ByteSlicedAES32x8b, ByteSlicedAES64x128b, ByteSlicedAES64x16b, + ByteSlicedAES64x32b, ByteSlicedAES64x64b, ByteSlicedAES64x8b, PackedBinaryField128x1b, + PackedBinaryField16x32b, PackedBinaryField16x8b, PackedBinaryField1x128b, + PackedBinaryField256x1b, PackedBinaryField2x128b, PackedBinaryField2x64b, + PackedBinaryField32x8b, PackedBinaryField4x128b, PackedBinaryField4x32b, + PackedBinaryField4x64b, PackedBinaryField512x1b, PackedBinaryField64x8b, + PackedBinaryField8x32b, PackedBinaryField8x64b, PackedField, }; use criterion::{ criterion_group, criterion_main, measurement::WallTime, BenchmarkGroup, Criterion, Throughput, @@ -86,5 +90,43 @@ fn packed_512(c: &mut Criterion) { benchmark_get_set!(PackedBinaryField4x128b, group); } -criterion_group!(get_set, packed_128, packed_256, packed_512); +fn byte_sliced_128(c: &mut Criterion) { + let mut group = c.benchmark_group("bytes_sliced_128"); + + benchmark_get_set!(ByteSlicedAES16x8b, group); + benchmark_get_set!(ByteSlicedAES16x16b, group); + benchmark_get_set!(ByteSlicedAES16x32b, group); + benchmark_get_set!(ByteSlicedAES16x64b, group); + benchmark_get_set!(ByteSlicedAES16x128b, group); +} + +fn byte_sliced_256(c: &mut Criterion) { + let mut group = c.benchmark_group("bytes_sliced_256"); + + benchmark_get_set!(ByteSlicedAES32x8b, group); + benchmark_get_set!(ByteSlicedAES32x16b, group); + benchmark_get_set!(ByteSlicedAES32x32b, group); + benchmark_get_set!(ByteSlicedAES32x64b, group); + benchmark_get_set!(ByteSlicedAES32x128b, group); +} + +fn byte_sliced_512(c: &mut Criterion) { + let mut group = c.benchmark_group("bytes_sliced_512"); + + benchmark_get_set!(ByteSlicedAES64x8b, group); + benchmark_get_set!(ByteSlicedAES64x16b, group); + benchmark_get_set!(ByteSlicedAES64x32b, group); + benchmark_get_set!(ByteSlicedAES64x64b, group); + benchmark_get_set!(ByteSlicedAES64x128b, group); +} + +criterion_group!( + get_set, + packed_128, + packed_256, + packed_512, + byte_sliced_128, + byte_sliced_256, + byte_sliced_512 +); criterion_main!(get_set); diff --git a/crates/field/benches/packed_field_init.rs b/crates/field/benches/packed_field_init.rs index b2a3feb54..43c644e93 100644 --- a/crates/field/benches/packed_field_init.rs +++ b/crates/field/benches/packed_field_init.rs @@ -3,11 +3,15 @@ use std::array; use binius_field::{ - PackedBinaryField128x1b, PackedBinaryField16x32b, PackedBinaryField16x8b, - PackedBinaryField1x128b, PackedBinaryField256x1b, PackedBinaryField2x128b, - PackedBinaryField2x64b, PackedBinaryField32x8b, PackedBinaryField4x128b, - PackedBinaryField4x32b, PackedBinaryField4x64b, PackedBinaryField512x1b, - PackedBinaryField64x8b, PackedBinaryField8x32b, PackedBinaryField8x64b, PackedField, + ByteSlicedAES16x128b, ByteSlicedAES16x16b, ByteSlicedAES16x32b, ByteSlicedAES16x64b, + ByteSlicedAES16x8b, ByteSlicedAES32x128b, ByteSlicedAES32x16b, ByteSlicedAES32x32b, + ByteSlicedAES32x64b, ByteSlicedAES32x8b, ByteSlicedAES64x128b, ByteSlicedAES64x16b, + ByteSlicedAES64x32b, ByteSlicedAES64x64b, ByteSlicedAES64x8b, PackedBinaryField128x1b, + PackedBinaryField16x32b, PackedBinaryField16x8b, PackedBinaryField1x128b, + PackedBinaryField256x1b, PackedBinaryField2x128b, PackedBinaryField2x64b, + PackedBinaryField32x8b, PackedBinaryField4x128b, PackedBinaryField4x32b, + PackedBinaryField4x64b, PackedBinaryField512x1b, PackedBinaryField64x8b, + PackedBinaryField8x32b, PackedBinaryField8x64b, PackedField, }; use criterion::{ criterion_group, criterion_main, measurement::WallTime, BenchmarkGroup, Criterion, Throughput, @@ -71,5 +75,43 @@ fn packed_512(c: &mut Criterion) { benchmark_from_fn!(PackedBinaryField4x128b, group); } -criterion_group!(initialization, packed_128, packed_256, packed_512); +fn byte_sliced_128(c: &mut Criterion) { + let mut group = c.benchmark_group("bytes_sliced_128"); + + benchmark_from_fn!(ByteSlicedAES16x8b, group); + benchmark_from_fn!(ByteSlicedAES16x16b, group); + benchmark_from_fn!(ByteSlicedAES16x32b, group); + benchmark_from_fn!(ByteSlicedAES16x64b, group); + benchmark_from_fn!(ByteSlicedAES16x128b, group); +} + +fn byte_sliced_256(c: &mut Criterion) { + let mut group = c.benchmark_group("bytes_sliced_256"); + + benchmark_from_fn!(ByteSlicedAES32x8b, group); + benchmark_from_fn!(ByteSlicedAES32x16b, group); + benchmark_from_fn!(ByteSlicedAES32x32b, group); + benchmark_from_fn!(ByteSlicedAES32x64b, group); + benchmark_from_fn!(ByteSlicedAES32x128b, group); +} + +fn byte_sliced_512(c: &mut Criterion) { + let mut group = c.benchmark_group("bytes_sliced_512"); + + benchmark_from_fn!(ByteSlicedAES64x8b, group); + benchmark_from_fn!(ByteSlicedAES64x16b, group); + benchmark_from_fn!(ByteSlicedAES64x32b, group); + benchmark_from_fn!(ByteSlicedAES64x64b, group); + benchmark_from_fn!(ByteSlicedAES64x128b, group); +} + +criterion_group!( + initialization, + packed_128, + packed_256, + packed_512, + byte_sliced_128, + byte_sliced_256, + byte_sliced_512 +); criterion_main!(initialization); diff --git a/crates/field/benches/packed_field_utils.rs b/crates/field/benches/packed_field_utils.rs index 516af5e6e..622a49afa 100644 --- a/crates/field/benches/packed_field_utils.rs +++ b/crates/field/benches/packed_field_utils.rs @@ -274,11 +274,23 @@ macro_rules! benchmark_packed_operation { PackedBinaryPolyval4x128b // Byte sliced AES fields + ByteSlicedAES16x8b + ByteSlicedAES16x16b + ByteSlicedAES16x32b + ByteSlicedAES16x64b + ByteSlicedAES16x128b + ByteSlicedAES32x8b ByteSlicedAES32x16b ByteSlicedAES32x32b ByteSlicedAES32x64b ByteSlicedAES32x128b + + ByteSlicedAES64x8b + ByteSlicedAES64x16b + ByteSlicedAES64x32b + ByteSlicedAES64x64b + ByteSlicedAES64x128b ]); }; } diff --git a/crates/field/src/arch/aarch64/m128.rs b/crates/field/src/arch/aarch64/m128.rs index 8155d6034..70c496606 100644 --- a/crates/field/src/arch/aarch64/m128.rs +++ b/crates/field/src/arch/aarch64/m128.rs @@ -19,8 +19,8 @@ use crate::{ arch::binary_utils::{as_array_mut, as_array_ref}, arithmetic_traits::Broadcast, underlier::{ - impl_divisible, impl_iteration, NumCast, Random, SmallU, UnderlierType, - UnderlierWithBitOps, WithUnderlier, U1, U2, U4, + impl_divisible, impl_iteration, unpack_lo_128b_fallback, NumCast, Random, SmallU, + UnderlierType, UnderlierWithBitOps, WithUnderlier, U1, U2, U4, }, BinaryField, }; @@ -337,6 +337,40 @@ impl UnderlierWithBitOps for M128 { _ => panic!("unsupported bit count"), } } + + #[inline(always)] + fn shl_128b_lanes(self, rhs: usize) -> Self { + Self(self.0 << rhs) + } + + #[inline(always)] + fn shr_128b_lanes(self, rhs: usize) -> Self { + Self(self.0 >> rhs) + } + + #[inline(always)] + fn unpack_lo_128b_lanes(self, rhs: Self, log_block_len: usize) -> Self { + match log_block_len { + 0..3 => unpack_lo_128b_fallback(self, rhs, log_block_len), + 3 => unsafe { vzip1q_u8(self.into(), rhs.into()).into() }, + 4 => unsafe { vzip1q_u16(self.into(), rhs.into()).into() }, + 5 => unsafe { vzip1q_u32(self.into(), rhs.into()).into() }, + 6 => unsafe { vzip1q_u64(self.into(), rhs.into()).into() }, + _ => panic!("Unsupported block length"), + } + } + + #[inline(always)] + fn unpack_hi_128b_lanes(self, rhs: Self, log_block_len: usize) -> Self { + match log_block_len { + 0..3 => unpack_lo_128b_fallback(self, rhs, log_block_len), + 3 => unsafe { vzip2q_u8(self.into(), rhs.into()).into() }, + 4 => unsafe { vzip2q_u16(self.into(), rhs.into()).into() }, + 5 => unsafe { vzip2q_u32(self.into(), rhs.into()).into() }, + 6 => unsafe { vzip2q_u64(self.into(), rhs.into()).into() }, + _ => panic!("Unsupported block length"), + } + } } impl UnderlierWithBitConstants for M128 { diff --git a/crates/field/src/arch/portable/byte_sliced/invert.rs b/crates/field/src/arch/portable/byte_sliced/invert.rs index 3544cbbdf..8581e5669 100644 --- a/crates/field/src/arch/portable/byte_sliced/invert.rs +++ b/crates/field/src/arch/portable/byte_sliced/invert.rs @@ -6,25 +6,24 @@ use super::{ use crate::{ tower_levels::{TowerLevel, TowerLevelWithArithOps}, underlier::WithUnderlier, - AESTowerField8b, PackedAESBinaryField32x8b, PackedField, + AESTowerField8b, PackedField, }; #[inline(always)] -pub fn invert_or_zero>( - field_element: &Level::Data, - destination: &mut Level::Data, +pub fn invert_or_zero, Level: TowerLevel>( + field_element: &Level::Data

, + destination: &mut Level::Data

, ) { - let base_alpha = - PackedAESBinaryField32x8b::from_scalars([AESTowerField8b::from_underlier(0xd3); 32]); + let base_alpha = P::broadcast(AESTowerField8b::from_underlier(0xd3)); - inv_main::(field_element, destination, base_alpha); + inv_main::(field_element, destination, base_alpha); } #[inline(always)] -fn inv_main>( - field_element: &Level::Data, - destination: &mut Level::Data, - base_alpha: PackedAESBinaryField32x8b, +fn inv_main, Level: TowerLevel>( + field_element: &Level::Data

, + destination: &mut Level::Data

, + base_alpha: P, ) { if Level::WIDTH == 1 { destination.as_mut()[0] = field_element.as_ref()[0].invert_or_zero(); @@ -35,36 +34,30 @@ fn inv_main>( let (result0, result1) = Level::split_mut(destination); - let mut intermediate = <>::Base as TowerLevel< - PackedAESBinaryField32x8b, - >>::default(); + let mut intermediate = <::Base as TowerLevel>::default(); // intermediate = subfield_alpha*a1 - mul_alpha::(a1, &mut intermediate, base_alpha); + mul_alpha::(a1, &mut intermediate, base_alpha); // intermediate = a0 + subfield_alpha*a1 Level::Base::add_into(a0, &mut intermediate); - let mut delta = <>::Base as TowerLevel< - PackedAESBinaryField32x8b, - >>::default(); + let mut delta = <::Base as TowerLevel>::default(); // delta = intermediate * a0 - mul_main::(&intermediate, a0, &mut delta, base_alpha); + mul_main::(&intermediate, a0, &mut delta, base_alpha); // delta = intermediate * a0 + a1^2 - square_main::(a1, &mut delta, base_alpha); + square_main::(a1, &mut delta, base_alpha); - let mut delta_inv = <>::Base as TowerLevel< - PackedAESBinaryField32x8b, - >>::default(); + let mut delta_inv = <::Base as TowerLevel>::default(); // delta_inv = 1/delta - inv_main::(&delta, &mut delta_inv, base_alpha); + inv_main::(&delta, &mut delta_inv, base_alpha); // result0 = delta_inv*intermediate - mul_main::(&delta_inv, &intermediate, result0, base_alpha); + mul_main::(&delta_inv, &intermediate, result0, base_alpha); // result1 = delta_inv*intermediate - mul_main::(&delta_inv, a1, result1, base_alpha); + mul_main::(&delta_inv, a1, result1, base_alpha); } diff --git a/crates/field/src/arch/portable/byte_sliced/mod.rs b/crates/field/src/arch/portable/byte_sliced/mod.rs index 0c42539b9..0f29290a6 100644 --- a/crates/field/src/arch/portable/byte_sliced/mod.rs +++ b/crates/field/src/arch/portable/byte_sliced/mod.rs @@ -15,8 +15,8 @@ pub mod tests { use proptest::prelude::*; use crate::{$scalar_type, underlier::WithUnderlier, packed::PackedField, arch::byte_sliced::$name}; - fn scalar_array_strategy() -> impl Strategy { - any::<[<$scalar_type as WithUnderlier>::Underlier; 32]>().prop_map(|arr| arr.map(<$scalar_type>::from_underlier)) + fn scalar_array_strategy() -> impl Strategy::WIDTH]> { + any::<[<$scalar_type as WithUnderlier>::Underlier; <$name>::WIDTH]>().prop_map(|arr| arr.map(<$scalar_type>::from_underlier)) } proptest! { @@ -27,7 +27,7 @@ pub mod tests { let bytesliced_result = bytesliced_a + bytesliced_b; - for i in 0..32 { + for i in 0..<$name>::WIDTH { assert_eq!(scalar_elems_a[i] + scalar_elems_b[i], bytesliced_result.get(i)); } } @@ -39,7 +39,7 @@ pub mod tests { bytesliced_a += bytesliced_b; - for i in 0..32 { + for i in 0..<$name>::WIDTH { assert_eq!(scalar_elems_a[i] + scalar_elems_b[i], bytesliced_a.get(i)); } } @@ -51,7 +51,7 @@ pub mod tests { let bytesliced_result = bytesliced_a - bytesliced_b; - for i in 0..32 { + for i in 0..<$name>::WIDTH { assert_eq!(scalar_elems_a[i] - scalar_elems_b[i], bytesliced_result.get(i)); } } @@ -63,7 +63,7 @@ pub mod tests { bytesliced_a -= bytesliced_b; - for i in 0..32 { + for i in 0..<$name>::WIDTH { assert_eq!(scalar_elems_a[i] - scalar_elems_b[i], bytesliced_a.get(i)); } } @@ -75,7 +75,7 @@ pub mod tests { let bytesliced_result = bytesliced_a * bytesliced_b; - for i in 0..32 { + for i in 0..<$name>::WIDTH { assert_eq!(scalar_elems_a[i] * scalar_elems_b[i], bytesliced_result.get(i)); } } @@ -87,7 +87,7 @@ pub mod tests { bytesliced_a *= bytesliced_b; - for i in 0..32 { + for i in 0..<$name>::WIDTH { assert_eq!(scalar_elems_a[i] * scalar_elems_b[i], bytesliced_a.get(i)); } } @@ -118,9 +118,21 @@ pub mod tests { }; } + define_byte_sliced_test!(tests_16x128, ByteSlicedAES16x128b, AESTowerField128b); + define_byte_sliced_test!(tests_16x64, ByteSlicedAES16x64b, AESTowerField64b); + define_byte_sliced_test!(tests_16x32, ByteSlicedAES16x32b, AESTowerField32b); + define_byte_sliced_test!(tests_16x16, ByteSlicedAES16x16b, AESTowerField16b); + define_byte_sliced_test!(tests_16x8, ByteSlicedAES16x8b, AESTowerField8b); + define_byte_sliced_test!(tests_32x128, ByteSlicedAES32x128b, AESTowerField128b); define_byte_sliced_test!(tests_32x64, ByteSlicedAES32x64b, AESTowerField64b); define_byte_sliced_test!(tests_32x32, ByteSlicedAES32x32b, AESTowerField32b); define_byte_sliced_test!(tests_32x16, ByteSlicedAES32x16b, AESTowerField16b); define_byte_sliced_test!(tests_32x8, ByteSlicedAES32x8b, AESTowerField8b); + + define_byte_sliced_test!(tests_64x128, ByteSlicedAES64x128b, AESTowerField128b); + define_byte_sliced_test!(tests_64x64, ByteSlicedAES64x64b, AESTowerField64b); + define_byte_sliced_test!(tests_64x32, ByteSlicedAES64x32b, AESTowerField32b); + define_byte_sliced_test!(tests_64x16, ByteSlicedAES64x16b, AESTowerField16b); + define_byte_sliced_test!(tests_64x8, ByteSlicedAES64x8b, AESTowerField8b); } diff --git a/crates/field/src/arch/portable/byte_sliced/multiply.rs b/crates/field/src/arch/portable/byte_sliced/multiply.rs index cfaf7339b..c2037703a 100644 --- a/crates/field/src/arch/portable/byte_sliced/multiply.rs +++ b/crates/field/src/arch/portable/byte_sliced/multiply.rs @@ -2,25 +2,28 @@ use crate::{ tower_levels::{TowerLevel, TowerLevelWithArithOps}, underlier::WithUnderlier, - AESTowerField8b, PackedAESBinaryField32x8b, PackedField, + AESTowerField8b, PackedField, }; #[inline(always)] -pub fn mul>( - field_element_a: &Level::Data, - field_element_b: &Level::Data, - destination: &mut Level::Data, +pub fn mul, Level: TowerLevel>( + field_element_a: &Level::Data

, + field_element_b: &Level::Data

, + destination: &mut Level::Data

, ) { - let base_alpha = - PackedAESBinaryField32x8b::from_scalars([AESTowerField8b::from_underlier(0xd3); 32]); - mul_main::(field_element_a, field_element_b, destination, base_alpha); + let base_alpha = P::broadcast(AESTowerField8b::from_underlier(0xd3)); + mul_main::(field_element_a, field_element_b, destination, base_alpha); } #[inline(always)] -pub fn mul_alpha>( - field_element: &Level::Data, - destination: &mut Level::Data, - base_alpha: PackedAESBinaryField32x8b, +pub fn mul_alpha< + const WRITING_TO_ZEROS: bool, + P: PackedField, + Level: TowerLevel, +>( + field_element: &Level::Data

, + destination: &mut Level::Data

, + base_alpha: P, ) { if Level::WIDTH == 1 { if WRITING_TO_ZEROS { @@ -49,15 +52,19 @@ pub fn mul_alpha(a1, result1, base_alpha); + mul_alpha::(a1, result1, base_alpha); } #[inline(always)] -pub fn mul_main>( - field_element_a: &Level::Data, - field_element_b: &Level::Data, - destination: &mut Level::Data, - base_alpha: PackedAESBinaryField32x8b, +pub fn mul_main< + const WRITING_TO_ZEROS: bool, + P: PackedField, + Level: TowerLevel, +>( + field_element_a: &Level::Data

, + field_element_b: &Level::Data

, + destination: &mut Level::Data

, + base_alpha: P, ) { if Level::WIDTH == 1 { if WRITING_TO_ZEROS { @@ -78,21 +85,19 @@ pub fn mul_main>::Base as TowerLevel< - PackedAESBinaryField32x8b, - >>::default(); + let mut z2_z0 = <::Base as TowerLevel>::default(); // z2_z0 = z2 - mul_main::(a1, b1, &mut z2_z0, base_alpha); + mul_main::(a1, b1, &mut z2_z0, base_alpha); // result1 = z2 * alpha - mul_alpha::(&z2_z0, result1, base_alpha); + mul_alpha::(&z2_z0, result1, base_alpha); // z2_z0 = z2 + z0 - mul_main::(a0, b0, &mut z2_z0, base_alpha); + mul_main::(a0, b0, &mut z2_z0, base_alpha); // result1 = z1 + z2 * alpha - mul_main::(&xored_halves_a, &xored_halves_b, result1, base_alpha); + mul_main::(&xored_halves_a, &xored_halves_b, result1, base_alpha); // result1 = z2+ z0+ z1 + z2 * alpha Level::Base::add_into(&z2_z0, result1); diff --git a/crates/field/src/arch/portable/byte_sliced/packed_byte_sliced.rs b/crates/field/src/arch/portable/byte_sliced/packed_byte_sliced.rs index fd4c9b84d..40e9648aa 100644 --- a/crates/field/src/arch/portable/byte_sliced/packed_byte_sliced.rs +++ b/crates/field/src/arch/portable/byte_sliced/packed_byte_sliced.rs @@ -7,7 +7,7 @@ use std::{ ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign}, }; -use bytemuck::Zeroable; +use bytemuck::{Pod, Zeroable}; use super::{invert::invert_or_zero, multiply::mul, square::square}; use crate::{ @@ -15,7 +15,7 @@ use crate::{ tower_levels::*, underlier::{UnderlierWithBitOps, WithUnderlier}, AESTowerField128b, AESTowerField16b, AESTowerField32b, AESTowerField64b, AESTowerField8b, - PackedField, + PackedAESBinaryField16x8b, PackedAESBinaryField64x8b, PackedField, }; /// Represents 32 AES Tower Field elements in byte-sliced form backed by Packed 32x8b AES fields. @@ -24,16 +24,15 @@ use crate::{ /// multiplication circuit on GFNI machines, since multiplication of two 32x8b field elements is /// handled in one instruction. macro_rules! define_byte_sliced { - ($name:ident, $scalar_type:ty, $tower_level: ty) => { - #[derive(Default, Clone, Debug, Copy, PartialEq, Eq, Zeroable)] + ($name:ident, $scalar_type:ty, $packed_storage:ty, $tower_level: ty) => { + #[derive(Default, Clone, Debug, Copy, PartialEq, Eq, Pod, Zeroable)] + #[repr(transparent)] pub struct $name { - pub(super) data: [PackedAESBinaryField32x8b; - <$tower_level as TowerLevel>::WIDTH], + pub(super) data: [$packed_storage; <$tower_level as TowerLevel>::WIDTH], } impl $name { - pub const BYTES: usize = PackedAESBinaryField32x8b::WIDTH - * <$tower_level as TowerLevel>::WIDTH; + pub const BYTES: usize = <$packed_storage>::WIDTH * <$tower_level as TowerLevel>::WIDTH; /// Get the byte at the given index. /// @@ -41,11 +40,8 @@ macro_rules! define_byte_sliced { /// The caller must ensure that `byte_index` is less than `BYTES`. #[allow(clippy::modulo_one)] pub unsafe fn get_byte_unchecked(&self, byte_index: usize) -> u8 { - self.data - [byte_index % <$tower_level as TowerLevel>::WIDTH] - .get( - byte_index / <$tower_level as TowerLevel>::WIDTH, - ) + self.data[byte_index % <$tower_level as TowerLevel>::WIDTH] + .get(byte_index / <$tower_level as TowerLevel>::WIDTH) .to_underlier() } } @@ -53,20 +49,17 @@ macro_rules! define_byte_sliced { impl PackedField for $name { type Scalar = $scalar_type; - const LOG_WIDTH: usize = 5; + const LOG_WIDTH: usize = <$packed_storage>::LOG_WIDTH; #[inline(always)] unsafe fn get_unchecked(&self, i: usize) -> Self::Scalar { - let mut result_underlier = 0; - for (byte_index, val) in self.data.iter().enumerate() { - // Safety: - // - `byte_index` is less than 16 - // - `i` must be less than 32 due to safety conditions of this method - unsafe { - result_underlier - .set_subvalue(byte_index, val.get_unchecked(i).to_underlier()) - } - } + let result_underlier = + ::Underlier::from_fn(|byte_index| unsafe { + self.data + .get_unchecked(byte_index) + .get_unchecked(i) + .to_underlier() + }); Self::Scalar::from_underlier(result_underlier) } @@ -75,8 +68,7 @@ macro_rules! define_byte_sliced { unsafe fn set_unchecked(&mut self, i: usize, scalar: Self::Scalar) { let underlier = scalar.to_underlier(); - for byte_index in 0..<$tower_level as TowerLevel>::WIDTH - { + for byte_index in 0..<$tower_level as TowerLevel>::WIDTH { self.data[byte_index].set_unchecked( i, AESTowerField8b::from_underlier(underlier.get_subvalue(byte_index)), @@ -92,9 +84,9 @@ macro_rules! define_byte_sliced { fn broadcast(scalar: Self::Scalar) -> Self { Self { data: array::from_fn(|byte_index| { - PackedAESBinaryField32x8b::broadcast(AESTowerField8b::from_underlier( - unsafe { scalar.to_underlier().get_subvalue(byte_index) }, - )) + <$packed_storage>::broadcast(AESTowerField8b::from_underlier(unsafe { + scalar.to_underlier().get_subvalue(byte_index) + })) }), } } @@ -115,7 +107,7 @@ macro_rules! define_byte_sliced { fn square(self) -> Self { let mut result = Self::default(); - square::<$tower_level>(&self.data, &mut result.data); + square::<$packed_storage, $tower_level>(&self.data, &mut result.data); result } @@ -123,7 +115,7 @@ macro_rules! define_byte_sliced { #[inline] fn invert_or_zero(self) -> Self { let mut result = Self::default(); - invert_or_zero::<$tower_level>(&self.data, &mut result.data); + invert_or_zero::<$packed_storage, $tower_level>(&self.data, &mut result.data); result } @@ -132,7 +124,7 @@ macro_rules! define_byte_sliced { let mut result1 = Self::default(); let mut result2 = Self::default(); - for byte_num in 0..<$tower_level as TowerLevel>::WIDTH { + for byte_num in 0..<$tower_level as TowerLevel>::WIDTH { (result1.data[byte_num], result2.data[byte_num]) = self.data[byte_num].interleave(other.data[byte_num], log_block_len); } @@ -145,7 +137,7 @@ macro_rules! define_byte_sliced { let mut result1 = Self::default(); let mut result2 = Self::default(); - for byte_num in 0..<$tower_level as TowerLevel>::WIDTH { + for byte_num in 0..<$tower_level as TowerLevel>::WIDTH { (result1.data[byte_num], result2.data[byte_num]) = self.data[byte_num].unzip(other.data[byte_num], log_block_len); } @@ -220,12 +212,9 @@ macro_rules! define_byte_sliced { type Output = Self; fn mul(self, rhs: Self) -> Self { - let mut result = $name { - data: [PackedAESBinaryField32x8b::default(); - <$tower_level as TowerLevel>::WIDTH], - }; + let mut result = Self::default(); - mul::<$tower_level>(&self.data, &rhs.data, &mut result.data); + mul::<$packed_storage, $tower_level>(&self.data, &rhs.data, &mut result.data); result } @@ -284,8 +273,38 @@ macro_rules! define_byte_sliced { }; } -define_byte_sliced!(ByteSlicedAES32x128b, AESTowerField128b, TowerLevel16); -define_byte_sliced!(ByteSlicedAES32x64b, AESTowerField64b, TowerLevel8); -define_byte_sliced!(ByteSlicedAES32x32b, AESTowerField32b, TowerLevel4); -define_byte_sliced!(ByteSlicedAES32x16b, AESTowerField16b, TowerLevel2); -define_byte_sliced!(ByteSlicedAES32x8b, AESTowerField8b, TowerLevel1); +// 128 bit +define_byte_sliced!( + ByteSlicedAES16x128b, + AESTowerField128b, + PackedAESBinaryField16x8b, + TowerLevel16 +); +define_byte_sliced!(ByteSlicedAES16x64b, AESTowerField64b, PackedAESBinaryField16x8b, TowerLevel8); +define_byte_sliced!(ByteSlicedAES16x32b, AESTowerField32b, PackedAESBinaryField16x8b, TowerLevel4); +define_byte_sliced!(ByteSlicedAES16x16b, AESTowerField16b, PackedAESBinaryField16x8b, TowerLevel2); +define_byte_sliced!(ByteSlicedAES16x8b, AESTowerField8b, PackedAESBinaryField16x8b, TowerLevel1); + +// 256 bit +define_byte_sliced!( + ByteSlicedAES32x128b, + AESTowerField128b, + PackedAESBinaryField32x8b, + TowerLevel16 +); +define_byte_sliced!(ByteSlicedAES32x64b, AESTowerField64b, PackedAESBinaryField32x8b, TowerLevel8); +define_byte_sliced!(ByteSlicedAES32x32b, AESTowerField32b, PackedAESBinaryField32x8b, TowerLevel4); +define_byte_sliced!(ByteSlicedAES32x16b, AESTowerField16b, PackedAESBinaryField32x8b, TowerLevel2); +define_byte_sliced!(ByteSlicedAES32x8b, AESTowerField8b, PackedAESBinaryField32x8b, TowerLevel1); + +// 512 bit +define_byte_sliced!( + ByteSlicedAES64x128b, + AESTowerField128b, + PackedAESBinaryField64x8b, + TowerLevel16 +); +define_byte_sliced!(ByteSlicedAES64x64b, AESTowerField64b, PackedAESBinaryField64x8b, TowerLevel8); +define_byte_sliced!(ByteSlicedAES64x32b, AESTowerField32b, PackedAESBinaryField64x8b, TowerLevel4); +define_byte_sliced!(ByteSlicedAES64x16b, AESTowerField16b, PackedAESBinaryField64x8b, TowerLevel2); +define_byte_sliced!(ByteSlicedAES64x8b, AESTowerField8b, PackedAESBinaryField64x8b, TowerLevel1); diff --git a/crates/field/src/arch/portable/byte_sliced/square.rs b/crates/field/src/arch/portable/byte_sliced/square.rs index bcd9514aa..0c0b6ab70 100644 --- a/crates/field/src/arch/portable/byte_sliced/square.rs +++ b/crates/field/src/arch/portable/byte_sliced/square.rs @@ -3,24 +3,27 @@ use super::multiply::mul_alpha; use crate::{ tower_levels::{TowerLevel, TowerLevelWithArithOps}, underlier::WithUnderlier, - AESTowerField8b, PackedAESBinaryField32x8b, PackedField, + AESTowerField8b, PackedField, }; #[inline(always)] -pub fn square>( - field_element: &Level::Data, - destination: &mut Level::Data, +pub fn square, Level: TowerLevel>( + field_element: &Level::Data

, + destination: &mut Level::Data

, ) { - let base_alpha = - PackedAESBinaryField32x8b::from_scalars([AESTowerField8b::from_underlier(0xd3); 32]); - square_main::(field_element, destination, base_alpha); + let base_alpha = P::broadcast(AESTowerField8b::from_underlier(0xd3)); + square_main::(field_element, destination, base_alpha); } #[inline(always)] -pub fn square_main>( - field_element: &Level::Data, - destination: &mut Level::Data, - base_alpha: PackedAESBinaryField32x8b, +pub fn square_main< + const WRITING_TO_ZEROS: bool, + P: PackedField, + Level: TowerLevel, +>( + field_element: &Level::Data

, + destination: &mut Level::Data

, + base_alpha: P, ) { if Level::WIDTH == 1 { if WRITING_TO_ZEROS { @@ -34,15 +37,13 @@ pub fn square_main>::Base as TowerLevel< - PackedAESBinaryField32x8b, - >>::default(); + let mut a1_squared = <::Base as TowerLevel>::default(); - square_main::(a1, &mut a1_squared, base_alpha); + square_main::(a1, &mut a1_squared, base_alpha); - mul_alpha::(&a1_squared, result1, base_alpha); + mul_alpha::(&a1_squared, result1, base_alpha); - square_main::(a0, result0, base_alpha); + square_main::(a0, result0, base_alpha); Level::Base::add_into(&a1_squared, result0); } diff --git a/crates/field/src/arch/portable/packed_scaled.rs b/crates/field/src/arch/portable/packed_scaled.rs index 93f693bcd..8417939e5 100644 --- a/crates/field/src/arch/portable/packed_scaled.rs +++ b/crates/field/src/arch/portable/packed_scaled.rs @@ -397,20 +397,23 @@ macro_rules! packed_scaled_field { impl std::ops::Add<<$inner as $crate::packed::PackedField>::Scalar> for $name { type Output = Self; - fn add(self, rhs: <$inner as $crate::packed::PackedField>::Scalar) -> Self { - let mut result = Self::default(); - for i in 0..Self::WIDTH_IN_PT { - result.0[i] = self.0[i] + rhs; + #[inline] + fn add(mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) -> Self { + let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs); + for v in self.0.iter_mut() { + *v += broadcast; } - result + self } } impl std::ops::AddAssign<<$inner as $crate::packed::PackedField>::Scalar> for $name { + #[inline] fn add_assign(&mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) { - for i in 0..Self::WIDTH_IN_PT { - self.0[i] += rhs; + let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs); + for v in self.0.iter_mut() { + *v += broadcast; } } } @@ -418,20 +421,23 @@ macro_rules! packed_scaled_field { impl std::ops::Sub<<$inner as $crate::packed::PackedField>::Scalar> for $name { type Output = Self; - fn sub(self, rhs: <$inner as $crate::packed::PackedField>::Scalar) -> Self { - let mut result = Self::default(); - for i in 0..Self::WIDTH_IN_PT { - result.0[i] = self.0[i] - rhs; + #[inline] + fn sub(mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) -> Self { + let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs); + for v in self.0.iter_mut() { + *v -= broadcast; } - result + self } } impl std::ops::SubAssign<<$inner as $crate::packed::PackedField>::Scalar> for $name { + #[inline] fn sub_assign(&mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) { - for i in 0..Self::WIDTH_IN_PT { - self.0[i] -= rhs; + let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs); + for v in self.0.iter_mut() { + *v -= broadcast; } } } @@ -439,20 +445,23 @@ macro_rules! packed_scaled_field { impl std::ops::Mul<<$inner as $crate::packed::PackedField>::Scalar> for $name { type Output = Self; - fn mul(self, rhs: <$inner as $crate::packed::PackedField>::Scalar) -> Self { - let mut result = Self::default(); - for i in 0..Self::WIDTH_IN_PT { - result.0[i] = self.0[i] * rhs; + #[inline] + fn mul(mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) -> Self { + let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs); + for v in self.0.iter_mut() { + *v *= broadcast; } - result + self } } impl std::ops::MulAssign<<$inner as $crate::packed::PackedField>::Scalar> for $name { + #[inline] fn mul_assign(&mut self, rhs: <$inner as $crate::packed::PackedField>::Scalar) { - for i in 0..Self::WIDTH_IN_PT { - self.0[i] *= rhs; + let broadcast = <$inner as $crate::packed::PackedField>::broadcast(rhs); + for v in self.0.iter_mut() { + *v *= broadcast; } } } diff --git a/crates/field/src/arch/x86_64/m128.rs b/crates/field/src/arch/x86_64/m128.rs index c2827ae76..d1d715f6d 100644 --- a/crates/field/src/arch/x86_64/m128.rs +++ b/crates/field/src/arch/x86_64/m128.rs @@ -23,8 +23,9 @@ use crate::{ }, arithmetic_traits::Broadcast, underlier::{ - impl_divisible, impl_iteration, spread_fallback, NumCast, Random, SmallU, SpreadToByte, - UnderlierType, UnderlierWithBitOps, WithUnderlier, U1, U2, U4, + impl_divisible, impl_iteration, spread_fallback, unpack_hi_128b_fallback, + unpack_lo_128b_fallback, NumCast, Random, SmallU, SpreadToByte, UnderlierType, + UnderlierWithBitOps, WithUnderlier, U1, U2, U4, }, BinaryField, }; @@ -181,7 +182,7 @@ impl Not for M128 { } /// `std::cmp::max` isn't const, so we need our own implementation -const fn max_i32(left: i32, right: i32) -> i32 { +pub(crate) const fn max_i32(left: i32, right: i32) -> i32 { if left > right { left } else { @@ -193,22 +194,37 @@ const fn max_i32(left: i32, right: i32) -> i32 { /// We have to use macro because parameter `count` in _mm_slli_epi64/_mm_srli_epi64 should be passed as constant /// and Rust currently doesn't allow passing expressions (`count - 64`) where variable is a generic constant parameter. /// Source: https://stackoverflow.com/questions/34478328/the-best-way-to-shift-a-m128i/34482688#34482688 -macro_rules! bitshift_right { - ($val:expr, $count:literal) => { +macro_rules! bitshift_128b { + ($val:expr, $shift:ident, $byte_shift:ident, $bit_shift_64:ident, $bit_shift_64_opposite:ident, $or:ident) => { unsafe { - let carry = _mm_bsrli_si128($val, 8); - if $count >= 64 { - _mm_srli_epi64(carry, max_i32($count - 64, 0)) - } else { - let carry = _mm_slli_epi64(carry, max_i32(64 - $count, 0)); - - let val = _mm_srli_epi64($val, $count); - _mm_or_si128(val, carry) - } + let carry = $byte_shift($val, 8); + seq!(N in 64..128 { + if $shift == N { + return $bit_shift_64( + carry, + crate::arch::x86_64::m128::max_i32((N - 64) as i32, 0) as _, + ).into(); + } + }); + seq!(N in 0..64 { + if $shift == N { + let carry = $bit_shift_64_opposite( + carry, + crate::arch::x86_64::m128::max_i32((64 - N) as i32, 0) as _, + ); + + let val = $bit_shift_64($val, N); + return $or(val, carry).into(); + } + }); + + return Default::default() } }; } +pub(crate) use bitshift_128b; + impl Shr for M128 { type Output = Self; @@ -216,32 +232,10 @@ impl Shr for M128 { fn shr(self, rhs: usize) -> Self::Output { // This implementation is effective when `rhs` is known at compile-time. // In our code this is always the case. - seq!(N in 0..128 { - if rhs == N { - return Self(bitshift_right!(self.0, N)); - } - }); - - Self::default() + bitshift_128b!(self.0, rhs, _mm_bsrli_si128, _mm_srli_epi64, _mm_slli_epi64, _mm_or_si128) } } -macro_rules! bitshift_left { - ($val:expr, $count:literal) => { - unsafe { - let carry = _mm_bslli_si128($val, 8); - if $count >= 64 { - _mm_slli_epi64(carry, max_i32($count - 64, 0)) - } else { - let carry = _mm_srli_epi64(carry, max_i32(64 - $count, 0)); - - let val = _mm_slli_epi64($val, $count); - _mm_or_si128(val, carry) - } - } - }; -} - impl Shl for M128 { type Output = Self; @@ -249,13 +243,7 @@ impl Shl for M128 { fn shl(self, rhs: usize) -> Self::Output { // This implementation is effective when `rhs` is known at compile-time. // In our code this is always the case. - seq!(N in 0..128 { - if rhs == N { - return Self(bitshift_left!(self.0, N)); - } - }); - - Self::default() + bitshift_128b!(self.0, rhs, _mm_bslli_si128, _mm_slli_epi64, _mm_srli_epi64, _mm_or_si128); } } @@ -707,6 +695,40 @@ impl UnderlierWithBitOps for M128 { _ => panic!("unsupported bit length"), } } + + #[inline] + fn shl_128b_lanes(self, shift: usize) -> Self { + self << shift + } + + #[inline] + fn shr_128b_lanes(self, shift: usize) -> Self { + self >> shift + } + + #[inline] + fn unpack_lo_128b_lanes(self, other: Self, log_block_len: usize) -> Self { + match log_block_len { + 0..3 => unpack_lo_128b_fallback(self, other, log_block_len), + 3 => unsafe { _mm_unpacklo_epi8(self.0, other.0).into() }, + 4 => unsafe { _mm_unpacklo_epi16(self.0, other.0).into() }, + 5 => unsafe { _mm_unpacklo_epi32(self.0, other.0).into() }, + 6 => unsafe { _mm_unpacklo_epi64(self.0, other.0).into() }, + _ => panic!("unsupported block length"), + } + } + + #[inline] + fn unpack_hi_128b_lanes(self, other: Self, log_block_len: usize) -> Self { + match log_block_len { + 0..3 => unpack_hi_128b_fallback(self, other, log_block_len), + 3 => unsafe { _mm_unpackhi_epi8(self.0, other.0).into() }, + 4 => unsafe { _mm_unpackhi_epi16(self.0, other.0).into() }, + 5 => unsafe { _mm_unpackhi_epi32(self.0, other.0).into() }, + 6 => unsafe { _mm_unpackhi_epi64(self.0, other.0).into() }, + _ => panic!("unsupported block length"), + } + } } unsafe impl Zeroable for M128 {} @@ -781,27 +803,14 @@ impl UnderlierWithBitConstants for M128 { #[inline(always)] fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) { - let (c, d) = unsafe { - interleave_bits( - Into::::into(self).into(), - Into::::into(other).into(), - log_block_len, - ) - }; - (Self::from(c), Self::from(d)) - } - - #[inline(always)] - fn transpose(self, other: Self, log_block_len: usize) -> (Self, Self) { - let (c, d) = unsafe { - transpose_bits( + unsafe { + let (c, d) = interleave_bits( Into::::into(self).into(), Into::::into(other).into(), log_block_len, - ) - }; - - (Self::from(c), Self::from(d)) + ); + (Self::from(c), Self::from(d)) + } } } @@ -897,46 +906,6 @@ unsafe fn interleave_bits(a: __m128i, b: __m128i, log_block_len: usize) -> (__m1 } } -#[inline] -unsafe fn transpose_bits(a: __m128i, b: __m128i, log_block_len: usize) -> (__m128i, __m128i) { - match log_block_len { - 0..=3 => { - let shuffle = _mm_set_epi8(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0); - let (mut a, mut b) = transpose_with_shuffle(a, b, shuffle); - for log_block_len in (log_block_len..3).rev() { - (a, b) = interleave_bits(a, b, log_block_len); - } - - (a, b) - } - 4 => { - let shuffle = _mm_set_epi8(15, 14, 11, 10, 7, 6, 3, 2, 13, 12, 9, 8, 5, 4, 1, 0); - transpose_with_shuffle(a, b, shuffle) - } - 5 => { - let shuffle = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0); - transpose_with_shuffle(a, b, shuffle) - } - 6 => { - let c = _mm_unpacklo_epi64(a, b); - let d = _mm_unpackhi_epi64(a, b); - (c, d) - } - _ => panic!("unsupported block length"), - } -} - -#[inline(always)] -unsafe fn transpose_with_shuffle( - a: __m128i, - b: __m128i, - shuffle_mask: __m128i, -) -> (__m128i, __m128i) { - let a = _mm_shuffle_epi8(a, shuffle_mask); - let b = _mm_shuffle_epi8(b, shuffle_mask); - (_mm_unpacklo_epi64(a, b), _mm_unpackhi_epi64(a, b)) -} - #[inline] unsafe fn interleave_bits_imm( a: __m128i, @@ -977,6 +946,10 @@ mod tests { assert_eq!(M128::from(1u128), M128::ONE); } + fn get(value: M128, log_block_len: usize, index: usize) -> M128 { + (value >> (index << log_block_len)) & single_element_mask_bits::(1 << log_block_len) + } + proptest! { #[test] fn test_conversion(a in any::()) { @@ -1010,15 +983,36 @@ mod tests { let (c, d) = unsafe {interleave_bits(a.0, b.0, height)}; let (c, d) = (M128::from(c), M128::from(d)); - let block_len = 1usize << height; - let get = |v, i| { - u128::num_cast_from((v >> (i * block_len)) & single_element_mask_bits::(1 << height)) - }; - for i in (0..128/block_len).step_by(2) { - assert_eq!(get(c, i), get(a, i)); - assert_eq!(get(c, i+1), get(b, i)); - assert_eq!(get(d, i), get(a, i+1)); - assert_eq!(get(d, i+1), get(b, i+1)); + for i in (0..128>>height).step_by(2) { + assert_eq!(get(c, height, i), get(a, height, i)); + assert_eq!(get(c, height, i+1), get(b, height, i)); + assert_eq!(get(d, height, i), get(a, height, i+1)); + assert_eq!(get(d, height, i+1), get(b, height, i+1)); + } + } + + #[test] + fn test_unpack_lo(a in any::(), b in any::(), height in 1usize..7) { + let a = M128::from(a); + let b = M128::from(b); + + let result = a.unpack_lo_128b_lanes(b, height); + for i in 0..128>>(height + 1) { + assert_eq!(get(result, height, 2*i), get(a, height, i)); + assert_eq!(get(result, height, 2*i+1), get(b, height, i)); + } + } + + #[test] + fn test_unpack_hi(a in any::(), b in any::(), height in 1usize..7) { + let a = M128::from(a); + let b = M128::from(b); + + let result = a.unpack_hi_128b_lanes(b, height); + let half_block_count = 128>>(height + 1); + for i in 0..half_block_count { + assert_eq!(get(result, height, 2*i), get(a, height, i + half_block_count)); + assert_eq!(get(result, height, 2*i+1), get(b, height, i + half_block_count)); } } } diff --git a/crates/field/src/arch/x86_64/m256.rs b/crates/field/src/arch/x86_64/m256.rs index c548cc2f3..a2fc71ab8 100644 --- a/crates/field/src/arch/x86_64/m256.rs +++ b/crates/field/src/arch/x86_64/m256.rs @@ -9,6 +9,7 @@ use std::{ use bytemuck::{must_cast, Pod, Zeroable}; use cfg_if::cfg_if; use rand::{Rng, RngCore}; +use seq_macro::seq; use subtle::{Choice, ConditionallySelectable, ConstantTimeEq}; use crate::{ @@ -20,11 +21,13 @@ use crate::{ interleave_mask_even, interleave_mask_odd, UnderlierWithBitConstants, }, }, + x86_64::m128::bitshift_128b, }, arithmetic_traits::Broadcast, underlier::{ get_block_values, get_spread_bytes, impl_divisible, impl_iteration, spread_fallback, - NumCast, Random, SmallU, UnderlierType, UnderlierWithBitOps, WithUnderlier, U1, U2, U4, + unpack_hi_128b_fallback, unpack_lo_128b_fallback, NumCast, Random, SmallU, UnderlierType, + UnderlierWithBitOps, WithUnderlier, U1, U2, U4, }, BinaryField, }; @@ -323,7 +326,7 @@ impl UnderlierType for M256 { impl UnderlierWithBitOps for M256 { const ZERO: Self = { Self(m256_from_u128s!(0, 0,)) }; - const ONE: Self = { Self(m256_from_u128s!(0, 1,)) }; + const ONE: Self = { Self(m256_from_u128s!(1, 0,)) }; const ONES: Self = { Self(m256_from_u128s!(u128::MAX, u128::MAX,)) }; #[inline] @@ -834,6 +837,58 @@ impl UnderlierWithBitOps for M256 { _ => spread_fallback(self, log_block_len, block_idx), } } + + #[inline] + fn shr_128b_lanes(self, rhs: usize) -> Self { + // This implementation is effective when `rhs` is known at compile-time. + // In our code this is always the case. + bitshift_128b!( + self.0, + rhs, + _mm256_bsrli_epi128, + _mm256_srli_epi64, + _mm256_slli_epi64, + _mm256_or_si256 + ) + } + + #[inline] + fn shl_128b_lanes(self, rhs: usize) -> Self { + // This implementation is effective when `rhs` is known at compile-time. + // In our code this is always the case. + bitshift_128b!( + self.0, + rhs, + _mm256_bslli_epi128, + _mm256_slli_epi64, + _mm256_srli_epi64, + _mm256_or_si256 + ); + } + + #[inline] + fn unpack_lo_128b_lanes(self, other: Self, log_block_len: usize) -> Self { + match log_block_len { + 0..3 => unpack_lo_128b_fallback(self, other, log_block_len), + 3 => unsafe { _mm256_unpacklo_epi8(self.0, other.0).into() }, + 4 => unsafe { _mm256_unpacklo_epi16(self.0, other.0).into() }, + 5 => unsafe { _mm256_unpacklo_epi32(self.0, other.0).into() }, + 6 => unsafe { _mm256_unpacklo_epi64(self.0, other.0).into() }, + _ => panic!("unsupported block length"), + } + } + + #[inline] + fn unpack_hi_128b_lanes(self, other: Self, log_block_len: usize) -> Self { + match log_block_len { + 0..3 => unpack_hi_128b_fallback(self, other, log_block_len), + 3 => unsafe { _mm256_unpackhi_epi8(self.0, other.0).into() }, + 4 => unsafe { _mm256_unpackhi_epi16(self.0, other.0).into() }, + 5 => unsafe { _mm256_unpackhi_epi32(self.0, other.0).into() }, + 6 => unsafe { _mm256_unpackhi_epi64(self.0, other.0).into() }, + _ => panic!("unsupported block length"), + } + } } unsafe impl Zeroable for M256 {} @@ -1132,7 +1187,7 @@ mod tests { fn test_constants() { assert_eq!(M256::default(), M256::ZERO); assert_eq!(M256::from(0u128), M256::ZERO); - assert_eq!(M256::from([0u128, 1u128]), M256::ONE); + assert_eq!(M256::from([1u128, 0u128]), M256::ONE); } #[derive(Default)] @@ -1196,6 +1251,10 @@ mod tests { } } + fn get(value: M256, log_block_len: usize, index: usize) -> M256 { + (value >> (index << log_block_len)) & single_element_mask_bits::(1 << log_block_len) + } + proptest! { #[allow(clippy::tuple_array_conversions)] // false positive #[test] @@ -1232,14 +1291,41 @@ mod tests { let (c, d) = (M256::from(c), M256::from(d)); let block_len = 1usize << height; - let get = |v, i| { - u128::num_cast_from((v >> (i * block_len)) & single_element_mask_bits::(1 << height)) - }; for i in (0..256/block_len).step_by(2) { - assert_eq!(get(c, i), get(a, i)); - assert_eq!(get(c, i+1), get(b, i)); - assert_eq!(get(d, i), get(a, i+1)); - assert_eq!(get(d, i+1), get(b, i+1)); + assert_eq!(get(c, height, i), get(a, height, i)); + assert_eq!(get(c, height, i+1), get(b, height, i)); + assert_eq!(get(d, height, i), get(a, height, i+1)); + assert_eq!(get(d, height, i+1), get(b, height, i+1)); + } + } + + #[test] + fn test_unpack_lo(a in any::<[u128; 2]>(), b in any::<[u128; 2]>(), height in 0usize..7) { + let a = M256::from(a); + let b = M256::from(b); + + let result = a.unpack_lo_128b_lanes(b, height); + let half_block_count = 128>>(height + 1); + for i in 0..half_block_count { + assert_eq!(get(result, height, 2*i), get(a, height, i)); + assert_eq!(get(result, height, 2*i+1), get(b, height, i)); + assert_eq!(get(result, height, 2*(i + half_block_count)), get(a, height, 2 * half_block_count + i)); + assert_eq!(get(result, height, 2*(i + half_block_count)+1), get(b, height, 2 * half_block_count + i)); + } + } + + #[test] + fn test_unpack_hi(a in any::<[u128; 2]>(), b in any::<[u128; 2]>(), height in 0usize..7) { + let a = M256::from(a); + let b = M256::from(b); + + let result = a.unpack_hi_128b_lanes(b, height); + let half_block_count = 128>>(height + 1); + for i in 0..half_block_count { + assert_eq!(get(result, height, 2*i), get(a, height, i + half_block_count)); + assert_eq!(get(result, height, 2*i+1), get(b, height, i + half_block_count)); + assert_eq!(get(result, height, 2*(half_block_count + i)), get(a, height, 3*half_block_count + i)); + assert_eq!(get(result, height, 2*(half_block_count + i) +1), get(b, height, 3*half_block_count + i)); } } } diff --git a/crates/field/src/arch/x86_64/m512.rs b/crates/field/src/arch/x86_64/m512.rs index a7de540e5..8afc28251 100644 --- a/crates/field/src/arch/x86_64/m512.rs +++ b/crates/field/src/arch/x86_64/m512.rs @@ -8,6 +8,7 @@ use std::{ use bytemuck::{must_cast, Pod, Zeroable}; use rand::{Rng, RngCore}; +use seq_macro::seq; use subtle::{Choice, ConditionallySelectable, ConstantTimeEq}; use crate::{ @@ -19,12 +20,16 @@ use crate::{ interleave_mask_even, interleave_mask_odd, UnderlierWithBitConstants, }, }, - x86_64::{m128::M128, m256::M256}, + x86_64::{ + m128::{bitshift_128b, M128}, + m256::M256, + }, }, arithmetic_traits::Broadcast, underlier::{ get_block_values, get_spread_bytes, impl_divisible, impl_iteration, spread_fallback, - NumCast, Random, SmallU, UnderlierType, UnderlierWithBitOps, WithUnderlier, U1, U2, U4, + unpack_hi_128b_fallback, unpack_lo_128b_fallback, NumCast, Random, SmallU, UnderlierType, + UnderlierWithBitOps, WithUnderlier, U1, U2, U4, }, BinaryField, }; @@ -370,7 +375,7 @@ impl UnderlierType for M512 { impl UnderlierWithBitOps for M512 { const ZERO: Self = { Self(m512_from_u128s!(0, 0, 0, 0,)) }; - const ONE: Self = { Self(m512_from_u128s!(0, 0, 0, 1,)) }; + const ONE: Self = { Self(m512_from_u128s!(1, 0, 0, 0,)) }; const ONES: Self = { Self(m512_from_u128s!(u128::MAX, u128::MAX, u128::MAX, u128::MAX,)) }; #[inline(always)] @@ -865,6 +870,58 @@ impl UnderlierWithBitOps for M512 { _ => spread_fallback(self, log_block_len, block_idx), } } + + #[inline] + fn shr_128b_lanes(self, rhs: usize) -> Self { + // This implementation is effective when `rhs` is known at compile-time. + // In our code this is always the case. + bitshift_128b!( + self.0, + rhs, + _mm512_bsrli_epi128, + _mm512_srli_epi64, + _mm512_slli_epi64, + _mm512_or_si512 + ); + } + + #[inline] + fn shl_128b_lanes(self, rhs: usize) -> Self { + // This implementation is effective when `rhs` is known at compile-time. + // In our code this is always the case. + bitshift_128b!( + self.0, + rhs, + _mm512_bslli_epi128, + _mm512_slli_epi64, + _mm512_srli_epi64, + _mm512_or_si512 + ); + } + + #[inline] + fn unpack_lo_128b_lanes(self, other: Self, log_block_len: usize) -> Self { + match log_block_len { + 0..3 => unpack_lo_128b_fallback(self, other, log_block_len), + 3 => unsafe { _mm512_unpacklo_epi8(self.0, other.0).into() }, + 4 => unsafe { _mm512_unpacklo_epi16(self.0, other.0).into() }, + 5 => unsafe { _mm512_unpacklo_epi32(self.0, other.0).into() }, + 6 => unsafe { _mm512_unpacklo_epi64(self.0, other.0).into() }, + _ => panic!("unsupported block length"), + } + } + + #[inline] + fn unpack_hi_128b_lanes(self, other: Self, log_block_len: usize) -> Self { + match log_block_len { + 0..3 => unpack_hi_128b_fallback(self, other, log_block_len), + 3 => unsafe { _mm512_unpackhi_epi8(self.0, other.0).into() }, + 4 => unsafe { _mm512_unpackhi_epi16(self.0, other.0).into() }, + 5 => unsafe { _mm512_unpackhi_epi32(self.0, other.0).into() }, + 6 => unsafe { _mm512_unpackhi_epi64(self.0, other.0).into() }, + _ => panic!("unsupported block length"), + } + } } unsafe impl Zeroable for M512 {} @@ -1233,7 +1290,7 @@ mod tests { fn test_constants() { assert_eq!(M512::default(), M512::ZERO); assert_eq!(M512::from(0u128), M512::ZERO); - assert_eq!(M512::from([0u128, 0u128, 0u128, 1u128]), M512::ONE); + assert_eq!(M512::from([1u128, 0u128, 0u128, 0u128]), M512::ONE); } #[derive(Default)] @@ -1297,6 +1354,10 @@ mod tests { } } + fn get(value: M512, log_block_len: usize, index: usize) -> M512 { + (value >> (index << log_block_len)) & single_element_mask_bits::(1 << log_block_len) + } + proptest! { #[test] fn test_conversion(a in any::<[u128; 4]>()) { @@ -1330,14 +1391,49 @@ mod tests { let (c, d) = (M512::from(c), M512::from(d)); let block_len = 1usize << height; - let get = |v, i| { - u128::num_cast_from((v >> (i * block_len)) & single_element_mask_bits::(1 << height)) - }; for i in (0..512/block_len).step_by(2) { - assert_eq!(get(c, i), get(a, i)); - assert_eq!(get(c, i+1), get(b, i)); - assert_eq!(get(d, i), get(a, i+1)); - assert_eq!(get(d, i+1), get(b, i+1)); + assert_eq!(get(c, height, i), get(a, height, i)); + assert_eq!(get(c, height, i+1), get(b, height, i)); + assert_eq!(get(d, height, i), get(a, height, i+1)); + assert_eq!(get(d, height, i+1), get(b, height, i+1)); + } + } + + #[test] + fn test_unpack_lo(a in any::<[u128; 4]>(), b in any::<[u128; 4]>(), height in 0usize..7) { + let a = M512::from(a); + let b = M512::from(b); + + let result = a.unpack_lo_128b_lanes(b, height); + let half_block_count = 128>>(height + 1); + for i in 0..half_block_count { + assert_eq!(get(result, height, 2*i), get(a, height, i)); + assert_eq!(get(result, height, 2*i+1), get(b, height, i)); + assert_eq!(get(result, height, 2*(i + half_block_count)), get(a, height, 2 * half_block_count + i)); + assert_eq!(get(result, height, 2*(i + half_block_count)+1), get(b, height, 2 * half_block_count + i)); + assert_eq!(get(result, height, 2*(i + 2*half_block_count)), get(a, height, 4 * half_block_count + i)); + assert_eq!(get(result, height, 2*(i + 2*half_block_count)+1), get(b, height, 4 * half_block_count + i)); + assert_eq!(get(result, height, 2*(i + 3*half_block_count)), get(a, height, 6 * half_block_count + i)); + assert_eq!(get(result, height, 2*(i + 3*half_block_count)+1), get(b, height, 6 * half_block_count + i)); + } + } + + #[test] + fn test_unpack_hi(a in any::<[u128; 4]>(), b in any::<[u128; 4]>(), height in 0usize..7) { + let a = M512::from(a); + let b = M512::from(b); + + let result = a.unpack_hi_128b_lanes(b, height); + let half_block_count = 128>>(height + 1); + for i in 0..half_block_count { + assert_eq!(get(result, height, 2*i), get(a, height, i + half_block_count)); + assert_eq!(get(result, height, 2*i+1), get(b, height, i + half_block_count)); + assert_eq!(get(result, height, 2*(half_block_count + i)), get(a, height, 3*half_block_count + i)); + assert_eq!(get(result, height, 2*(half_block_count + i) +1), get(b, height, 3*half_block_count + i)); + assert_eq!(get(result, height, 2*(2*half_block_count + i)), get(a, height, 5*half_block_count + i)); + assert_eq!(get(result, height, 2*(2*half_block_count + i) +1), get(b, height, 5*half_block_count + i)); + assert_eq!(get(result, height, 2*(3*half_block_count + i)), get(a, height, 7*half_block_count + i)); + assert_eq!(get(result, height, 2*(3*half_block_count + i) +1), get(b, height, 7*half_block_count + i)); } } } diff --git a/crates/field/src/tower_levels.rs b/crates/field/src/tower_levels.rs index bdac60d06..bfddced06 100644 --- a/crates/field/src/tower_levels.rs +++ b/crates/field/src/tower_levels.rs @@ -16,110 +16,104 @@ use std::{ /// These separate implementations are necessary to overcome the limitations of const generics in Rust. /// These implementations eliminate costly bounds checking that would otherwise be imposed by the compiler /// and allow easy inlining of recursive functions. -pub trait TowerLevel -where - T: Default + Copy, -{ +pub trait TowerLevel { // WIDTH is ALWAYS a power of 2 const WIDTH: usize; // The underlying Data should ALWAYS be a fixed-width array of T's - type Data: AsMut<[T]> + type Data: AsMut<[T]> + AsRef<[T]> + Sized + Index + IndexMut; - type Base: TowerLevel; + type Base: TowerLevel; - // Split something of type Self::Data into two equal halves + // Split something of type Self::Datainto two equal halves #[allow(clippy::type_complexity)] - fn split( - data: &Self::Data, - ) -> (&>::Data, &>::Data); + fn split( + data: &Self::Data, + ) -> (&::Data, &::Data); - // Split something of type Self::Data into two equal mutable halves + // Split something of type Self::Datainto two equal mutable halves #[allow(clippy::type_complexity)] - fn split_mut( - data: &mut Self::Data, - ) -> (&mut >::Data, &mut >::Data); + fn split_mut( + data: &mut Self::Data, + ) -> (&mut ::Data, &mut ::Data); // Join two equal-length arrays (the reverse of split) #[allow(clippy::type_complexity)] - fn join( - first: &>::Data, - second: &>::Data, - ) -> Self::Data; + fn join( + first: &::Data, + second: &::Data, + ) -> Self::Data; // Fills an array of T's containing WIDTH elements - fn from_fn(f: impl Fn(usize) -> T) -> Self::Data; + fn from_fn(f: impl FnMut(usize) -> T) -> Self::Data; // Fills an array of T's containing WIDTH elements with T::default() - fn default() -> Self::Data { + fn default() -> Self::Data { Self::from_fn(|_| T::default()) } } -pub trait TowerLevelWithArithOps: TowerLevel -where - T: Default + Add + AddAssign + Copy, -{ +pub trait TowerLevelWithArithOps: TowerLevel { #[inline(always)] - fn add_into(field_element: &Self::Data, destination: &mut Self::Data) { + fn add_into( + field_element: &Self::Data, + destination: &mut Self::Data, + ) { for i in 0..Self::WIDTH { destination[i] += field_element[i]; } } #[inline(always)] - fn copy_into(field_element: &Self::Data, destination: &mut Self::Data) { + fn copy_into(field_element: &Self::Data, destination: &mut Self::Data) { for i in 0..Self::WIDTH { destination[i] = field_element[i]; } } #[inline(always)] - fn sum(field_element_a: &Self::Data, field_element_b: &Self::Data) -> Self::Data { + fn sum>( + field_element_a: &Self::Data, + field_element_b: &Self::Data, + ) -> Self::Data { Self::from_fn(|i| field_element_a[i] + field_element_b[i]) } } -impl> TowerLevelWithArithOps for U where - T: Default + Add + AddAssign + Copy -{ -} +impl TowerLevelWithArithOps for T {} pub struct TowerLevel64; -impl TowerLevel for TowerLevel64 -where - T: Default + Copy, -{ +impl TowerLevel for TowerLevel64 { const WIDTH: usize = 64; - type Data = [T; 64]; + type Data = [T; 64]; type Base = TowerLevel32; #[inline(always)] - fn split( - data: &Self::Data, - ) -> (&>::Data, &>::Data) { + fn split( + data: &Self::Data, + ) -> (&::Data, &::Data) { ((data[0..32].try_into().unwrap()), (data[32..64].try_into().unwrap())) } #[inline(always)] - fn split_mut( - data: &mut Self::Data, - ) -> (&mut >::Data, &mut >::Data) { + fn split_mut( + data: &mut Self::Data, + ) -> (&mut ::Data, &mut ::Data) { let (chunk_1, chunk_2) = data.split_at_mut(32); ((chunk_1.try_into().unwrap()), (chunk_2.try_into().unwrap())) } #[inline(always)] - fn join<'a>( - left: &>::Data, - right: &>::Data, - ) -> Self::Data { + fn join( + left: &::Data, + right: &::Data, + ) -> Self::Data { let mut result = [T::default(); 64]; result[..32].copy_from_slice(left); result[32..].copy_from_slice(right); @@ -127,43 +121,40 @@ where } #[inline(always)] - fn from_fn(f: impl Fn(usize) -> T) -> Self::Data { + fn from_fn(f: impl FnMut(usize) -> T) -> Self::Data { array::from_fn(f) } } pub struct TowerLevel32; -impl TowerLevel for TowerLevel32 -where - T: Default + Copy, -{ +impl TowerLevel for TowerLevel32 { const WIDTH: usize = 32; - type Data = [T; 32]; + type Data = [T; 32]; type Base = TowerLevel16; #[inline(always)] - fn split( - data: &Self::Data, - ) -> (&>::Data, &>::Data) { + fn split( + data: &Self::Data, + ) -> (&::Data, &::Data) { ((data[0..16].try_into().unwrap()), (data[16..32].try_into().unwrap())) } #[inline(always)] - fn split_mut( - data: &mut Self::Data, - ) -> (&mut >::Data, &mut >::Data) { + fn split_mut( + data: &mut Self::Data, + ) -> (&mut ::Data, &mut ::Data) { let (chunk_1, chunk_2) = data.split_at_mut(16); ((chunk_1.try_into().unwrap()), (chunk_2.try_into().unwrap())) } #[inline(always)] - fn join<'a>( - left: &>::Data, - right: &>::Data, - ) -> Self::Data { + fn join( + left: &::Data, + right: &::Data, + ) -> Self::Data { let mut result = [T::default(); 32]; result[..16].copy_from_slice(left); result[16..].copy_from_slice(right); @@ -171,43 +162,40 @@ where } #[inline(always)] - fn from_fn(f: impl Fn(usize) -> T) -> Self::Data { + fn from_fn(f: impl FnMut(usize) -> T) -> Self::Data { array::from_fn(f) } } pub struct TowerLevel16; -impl TowerLevel for TowerLevel16 -where - T: Default + Copy, -{ +impl TowerLevel for TowerLevel16 { const WIDTH: usize = 16; - type Data = [T; 16]; + type Data = [T; 16]; type Base = TowerLevel8; #[inline(always)] - fn split( - data: &Self::Data, - ) -> (&>::Data, &>::Data) { + fn split( + data: &Self::Data, + ) -> (&::Data, &::Data) { ((data[0..8].try_into().unwrap()), (data[8..16].try_into().unwrap())) } #[inline(always)] - fn split_mut( - data: &mut Self::Data, - ) -> (&mut >::Data, &mut >::Data) { + fn split_mut( + data: &mut Self::Data, + ) -> (&mut ::Data, &mut ::Data) { let (chunk_1, chunk_2) = data.split_at_mut(8); ((chunk_1.try_into().unwrap()), (chunk_2.try_into().unwrap())) } #[inline(always)] - fn join<'a>( - left: &>::Data, - right: &>::Data, - ) -> Self::Data { + fn join( + left: &::Data, + right: &::Data, + ) -> Self::Data { let mut result = [T::default(); 16]; result[..8].copy_from_slice(left); result[8..].copy_from_slice(right); @@ -215,43 +203,40 @@ where } #[inline(always)] - fn from_fn(f: impl Fn(usize) -> T) -> Self::Data { + fn from_fn(f: impl FnMut(usize) -> T) -> Self::Data { array::from_fn(f) } } pub struct TowerLevel8; -impl TowerLevel for TowerLevel8 -where - T: Default + Copy, -{ +impl TowerLevel for TowerLevel8 { const WIDTH: usize = 8; - type Data = [T; 8]; + type Data = [T; 8]; type Base = TowerLevel4; #[inline(always)] - fn split( - data: &Self::Data, - ) -> (&>::Data, &>::Data) { + fn split( + data: &Self::Data, + ) -> (&::Data, &::Data) { ((data[0..4].try_into().unwrap()), (data[4..8].try_into().unwrap())) } #[inline(always)] - fn split_mut( - data: &mut Self::Data, - ) -> (&mut >::Data, &mut >::Data) { + fn split_mut( + data: &mut Self::Data, + ) -> (&mut ::Data, &mut ::Data) { let (chunk_1, chunk_2) = data.split_at_mut(4); ((chunk_1.try_into().unwrap()), (chunk_2.try_into().unwrap())) } #[inline(always)] - fn join<'a>( - left: &>::Data, - right: &>::Data, - ) -> Self::Data { + fn join( + left: &::Data, + right: &::Data, + ) -> Self::Data { let mut result = [T::default(); 8]; result[..4].copy_from_slice(left); result[4..].copy_from_slice(right); @@ -259,43 +244,40 @@ where } #[inline(always)] - fn from_fn(f: impl Fn(usize) -> T) -> Self::Data { + fn from_fn(f: impl FnMut(usize) -> T) -> Self::Data { array::from_fn(f) } } pub struct TowerLevel4; -impl TowerLevel for TowerLevel4 -where - T: Default + Copy, -{ +impl TowerLevel for TowerLevel4 { const WIDTH: usize = 4; - type Data = [T; 4]; + type Data = [T; 4]; type Base = TowerLevel2; #[inline(always)] - fn split( - data: &Self::Data, - ) -> (&>::Data, &>::Data) { + fn split( + data: &Self::Data, + ) -> (&::Data, &::Data) { ((data[0..2].try_into().unwrap()), (data[2..4].try_into().unwrap())) } #[inline(always)] - fn split_mut( - data: &mut Self::Data, - ) -> (&mut >::Data, &mut >::Data) { + fn split_mut( + data: &mut Self::Data, + ) -> (&mut ::Data, &mut ::Data) { let (chunk_1, chunk_2) = data.split_at_mut(2); ((chunk_1.try_into().unwrap()), (chunk_2.try_into().unwrap())) } #[inline(always)] - fn join<'a>( - left: &>::Data, - right: &>::Data, - ) -> Self::Data { + fn join( + left: &::Data, + right: &::Data, + ) -> Self::Data { let mut result = [T::default(); 4]; result[..2].copy_from_slice(left); result[2..].copy_from_slice(right); @@ -303,43 +285,40 @@ where } #[inline(always)] - fn from_fn(f: impl Fn(usize) -> T) -> Self::Data { + fn from_fn(f: impl FnMut(usize) -> T) -> Self::Data { array::from_fn(f) } } pub struct TowerLevel2; -impl TowerLevel for TowerLevel2 -where - T: Default + Copy, -{ +impl TowerLevel for TowerLevel2 { const WIDTH: usize = 2; - type Data = [T; 2]; + type Data = [T; 2]; type Base = TowerLevel1; #[inline(always)] - fn split( - data: &Self::Data, - ) -> (&>::Data, &>::Data) { + fn split( + data: &Self::Data, + ) -> (&::Data, &::Data) { ((data[0..1].try_into().unwrap()), (data[1..2].try_into().unwrap())) } #[inline(always)] - fn split_mut( - data: &mut Self::Data, - ) -> (&mut >::Data, &mut >::Data) { + fn split_mut( + data: &mut Self::Data, + ) -> (&mut ::Data, &mut ::Data) { let (chunk_1, chunk_2) = data.split_at_mut(1); ((chunk_1.try_into().unwrap()), (chunk_2.try_into().unwrap())) } #[inline(always)] - fn join<'a>( - left: &>::Data, - right: &>::Data, - ) -> Self::Data { + fn join( + left: &::Data, + right: &::Data, + ) -> Self::Data { let mut result = [T::default(); 2]; result[..1].copy_from_slice(left); result[1..].copy_from_slice(right); @@ -347,48 +326,45 @@ where } #[inline(always)] - fn from_fn(f: impl Fn(usize) -> T) -> Self::Data { + fn from_fn(f: impl FnMut(usize) -> T) -> Self::Data { array::from_fn(f) } } pub struct TowerLevel1; -impl TowerLevel for TowerLevel1 -where - T: Default + Copy, -{ +impl TowerLevel for TowerLevel1 { const WIDTH: usize = 1; - type Data = [T; 1]; + type Data = [T; 1]; type Base = Self; // Level 1 is the atomic unit of backing data and must not be split. #[inline(always)] - fn split( - _data: &Self::Data, - ) -> (&>::Data, &>::Data) { + fn split( + _data: &Self::Data, + ) -> (&::Data, &::Data) { unreachable!() } #[inline(always)] - fn split_mut( - _data: &mut Self::Data, - ) -> (&mut >::Data, &mut >::Data) { + fn split_mut( + _data: &mut Self::Data, + ) -> (&mut ::Data, &mut ::Data) { unreachable!() } #[inline(always)] - fn join<'a>( - _left: &>::Data, - _right: &>::Data, - ) -> Self::Data { + fn join( + _left: &::Data, + _right: &::Data, + ) -> Self::Data { unreachable!() } #[inline(always)] - fn from_fn(f: impl Fn(usize) -> T) -> Self::Data { + fn from_fn(f: impl FnMut(usize) -> T) -> Self::Data { array::from_fn(f) } } diff --git a/crates/field/src/underlier/scaled.rs b/crates/field/src/underlier/scaled.rs index 4210bc653..40cabc5c2 100644 --- a/crates/field/src/underlier/scaled.rs +++ b/crates/field/src/underlier/scaled.rs @@ -1,13 +1,16 @@ // Copyright 2024-2025 Irreducible Inc. -use std::array; +use std::{ + array, + ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, Shr}, +}; use binius_utils::checked_arithmetics::checked_log_2; use bytemuck::{must_cast_mut, must_cast_ref, NoUninit, Pod, Zeroable}; use rand::RngCore; use subtle::{Choice, ConstantTimeEq}; -use super::{Divisible, Random, UnderlierType}; +use super::{Divisible, Random, UnderlierType, UnderlierWithBitOps}; /// A type that represents a pair of elements of the same underlier type. /// We use it as an underlier for the `ScaledPAckedField` type. @@ -104,3 +107,189 @@ where must_cast_mut::(self) } } + +impl + Copy, const N: usize> BitAnd for ScaledUnderlier { + type Output = Self; + + fn bitand(self, rhs: Self) -> Self::Output { + Self(array::from_fn(|i| self.0[i] & rhs.0[i])) + } +} + +impl BitAndAssign for ScaledUnderlier { + fn bitand_assign(&mut self, rhs: Self) { + for i in 0..N { + self.0[i] &= rhs.0[i]; + } + } +} + +impl + Copy, const N: usize> BitOr for ScaledUnderlier { + type Output = Self; + + fn bitor(self, rhs: Self) -> Self::Output { + Self(array::from_fn(|i| self.0[i] | rhs.0[i])) + } +} + +impl BitOrAssign for ScaledUnderlier { + fn bitor_assign(&mut self, rhs: Self) { + for i in 0..N { + self.0[i] |= rhs.0[i]; + } + } +} + +impl + Copy, const N: usize> BitXor for ScaledUnderlier { + type Output = Self; + + fn bitxor(self, rhs: Self) -> Self::Output { + Self(array::from_fn(|i| self.0[i] ^ rhs.0[i])) + } +} + +impl BitXorAssign for ScaledUnderlier { + fn bitxor_assign(&mut self, rhs: Self) { + for i in 0..N { + self.0[i] ^= rhs.0[i]; + } + } +} + +impl Shr for ScaledUnderlier { + type Output = Self; + + fn shr(self, rhs: usize) -> Self::Output { + let mut result = Self::default(); + + let shift_in_items = rhs / U::BITS; + for i in 0..N.saturating_sub(shift_in_items.saturating_sub(1)) { + if i + shift_in_items < N { + result.0[i] |= self.0[i + shift_in_items] >> (rhs % U::BITS); + } + if i + shift_in_items + 1 < N && rhs % U::BITS != 0 { + result.0[i] |= self.0[i + shift_in_items + 1] << (U::BITS - (rhs % U::BITS)); + } + } + + result + } +} + +impl Shl for ScaledUnderlier { + type Output = Self; + + fn shl(self, rhs: usize) -> Self::Output { + let mut result = Self::default(); + + let shift_in_items = rhs / U::BITS; + for i in shift_in_items.saturating_sub(1)..N { + if i >= shift_in_items { + result.0[i] |= self.0[i - shift_in_items] << (rhs % U::BITS); + } + if i > shift_in_items && rhs % U::BITS != 0 { + result.0[i] |= self.0[i - shift_in_items - 1] >> (U::BITS - (rhs % U::BITS)); + } + } + + result + } +} + +impl, const N: usize> Not for ScaledUnderlier { + type Output = Self; + + fn not(self) -> Self::Output { + Self(self.0.map(U::not)) + } +} + +impl UnderlierWithBitOps for ScaledUnderlier { + const ZERO: Self = Self([U::ZERO; N]); + const ONE: Self = { + let mut arr = [U::ZERO; N]; + arr[0] = U::ONE; + Self(arr) + }; + const ONES: Self = Self([U::ONES; N]); + + #[inline] + fn fill_with_bit(val: u8) -> Self { + Self(array::from_fn(|_| U::fill_with_bit(val))) + } + + #[inline] + fn shl_128b_lanes(self, rhs: usize) -> Self { + // We assume that the underlier type has at least 128 bits as the current implementation + // is valid for this case only. + // On practice, we don't use scaled underliers with underlier types that have less than 128 bits. + assert!(U::BITS >= 128); + + Self(self.0.map(|x| x.shl_128b_lanes(rhs))) + } + + #[inline] + fn shr_128b_lanes(self, rhs: usize) -> Self { + // We assume that the underlier type has at least 128 bits as the current implementation + // is valid for this case only. + // On practice, we don't use scaled underliers with underlier types that have less than 128 bits. + assert!(U::BITS >= 128); + + Self(self.0.map(|x| x.shr_128b_lanes(rhs))) + } + + #[inline] + fn unpack_lo_128b_lanes(self, other: Self, log_block_len: usize) -> Self { + // We assume that the underlier type has at least 128 bits as the current implementation + // is valid for this case only. + // On practice, we don't use scaled underliers with underlier types that have less than 128 bits. + assert!(U::BITS >= 128); + + Self(array::from_fn(|i| self.0[i].unpack_lo_128b_lanes(other.0[i], log_block_len))) + } + + #[inline] + fn unpack_hi_128b_lanes(self, other: Self, log_block_len: usize) -> Self { + // We assume that the underlier type has at least 128 bits as the current implementation + // is valid for this case only. + // On practice, we don't use scaled underliers with underlier types that have less than 128 bits. + assert!(U::BITS >= 128); + + Self(array::from_fn(|i| self.0[i].unpack_hi_128b_lanes(other.0[i], log_block_len))) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_shr() { + let val = ScaledUnderlier::([0, 1, 2, 3]); + assert_eq!( + val >> 1, + ScaledUnderlier::([0b10000000, 0b00000000, 0b10000001, 0b00000001]) + ); + assert_eq!( + val >> 2, + ScaledUnderlier::([0b01000000, 0b10000000, 0b11000000, 0b00000000]) + ); + assert_eq!( + val >> 8, + ScaledUnderlier::([0b00000001, 0b00000010, 0b00000011, 0b00000000]) + ); + assert_eq!( + val >> 9, + ScaledUnderlier::([0b00000000, 0b10000001, 0b00000001, 0b00000000]) + ); + } + + #[test] + fn test_shl() { + let val = ScaledUnderlier::([0, 1, 2, 3]); + assert_eq!(val << 1, ScaledUnderlier::([0, 2, 4, 6])); + assert_eq!(val << 2, ScaledUnderlier::([0, 4, 8, 12])); + assert_eq!(val << 8, ScaledUnderlier::([0, 0, 1, 2])); + assert_eq!(val << 9, ScaledUnderlier::([0, 0, 2, 4])); + } +} diff --git a/crates/field/src/underlier/small_uint.rs b/crates/field/src/underlier/small_uint.rs index 33a1e9b1c..3413f5cab 100644 --- a/crates/field/src/underlier/small_uint.rs +++ b/crates/field/src/underlier/small_uint.rs @@ -159,6 +159,14 @@ impl UnderlierWithBitOps for SmallU { fn fill_with_bit(val: u8) -> Self { Self(u8::fill_with_bit(val)) & Self::ONES } + + fn shl_128b_lanes(self, rhs: usize) -> Self { + self << rhs + } + + fn shr_128b_lanes(self, rhs: usize) -> Self { + self >> rhs + } } impl From> for u8 { diff --git a/crates/field/src/underlier/underlier_impls.rs b/crates/field/src/underlier/underlier_impls.rs index bb52b689a..e28e27dca 100644 --- a/crates/field/src/underlier/underlier_impls.rs +++ b/crates/field/src/underlier/underlier_impls.rs @@ -23,6 +23,16 @@ macro_rules! impl_underlier_type { debug_assert!(val == 0 || val == 1); (val as Self).wrapping_neg() } + + #[inline(always)] + fn shl_128b_lanes(self, rhs: usize) -> Self { + self << rhs + } + + #[inline(always)] + fn shr_128b_lanes(self, rhs: usize) -> Self { + self >> rhs + } } }; () => {}; diff --git a/crates/field/src/underlier/underlier_with_bit_ops.rs b/crates/field/src/underlier/underlier_with_bit_ops.rs index e151f5010..8cff6d187 100644 --- a/crates/field/src/underlier/underlier_with_bit_ops.rs +++ b/crates/field/src/underlier/underlier_with_bit_ops.rs @@ -118,6 +118,38 @@ pub trait UnderlierWithBitOps: { spread_fallback(self, log_block_len, block_idx) } + + /// Left shift within 128-bit lanes. + /// This can be more efficient than the full `Shl` implementation. + fn shl_128b_lanes(self, shift: usize) -> Self; + + /// Right shift within 128-bit lanes. + /// This can be more efficient than the full `Shr` implementation. + fn shr_128b_lanes(self, shift: usize) -> Self; + + /// Unpacks `1 << log_block_len`-bit values from low parts of `self` and `other` within 128-bit lanes. + /// + /// Example: + /// self: [a_0, a_1, a_2, a_3, a_4, a_5, a_6, a_7] + /// other: [b_0, b_1, b_2, b_3, b_4, b_5, b_6, b_7] + /// log_block_len: 1 + /// + /// result: [a_0, a_0, b_0, b_1, a_2, a_3, b_2, b_3] + fn unpack_lo_128b_lanes(self, other: Self, log_block_len: usize) -> Self { + unpack_lo_128b_fallback(self, other, log_block_len) + } + + /// Unpacks `1 << log_block_len`-bit values from high parts of `self` and `other` within 128-bit lanes. + /// + /// Example: + /// self: [a_0, a_1, a_2, a_3, a_4, a_5, a_6, a_7] + /// other: [b_0, b_1, b_2, b_3, b_4, b_5, b_6, b_7] + /// log_block_len: 1 + /// + /// result: [a_4, a_5, b_4, b_5, a_6, a_7, b_6, b_7] + fn unpack_hi_128b_lanes(self, other: Self, log_block_len: usize) -> Self { + unpack_hi_128b_fallback(self, other, log_block_len) + } } /// Returns a bit mask for a single `T` element inside underlier type. @@ -171,6 +203,55 @@ where result } +#[inline(always)] +fn single_element_mask_bits_128b_lanes(log_block_len: usize) -> T { + let mut mask = single_element_mask_bits(1 << log_block_len); + for i in 1..T::BITS / 128 { + mask |= mask << (i * 128); + } + + mask +} + +pub(crate) fn unpack_lo_128b_fallback( + lhs: T, + rhs: T, + log_block_len: usize, +) -> T { + assert!(log_block_len <= 6); + + let mask = single_element_mask_bits_128b_lanes::(log_block_len); + + let mut result = T::ZERO; + for i in 0..1 << (6 - log_block_len) { + result |= ((lhs.shr_128b_lanes(i << log_block_len)) & mask) + .shl_128b_lanes(i << (log_block_len + 1)); + result |= ((rhs.shr_128b_lanes(i << log_block_len)) & mask) + .shl_128b_lanes((2 * i + 1) << log_block_len); + } + + result +} + +pub(crate) fn unpack_hi_128b_fallback( + lhs: T, + rhs: T, + log_block_len: usize, +) -> T { + assert!(log_block_len <= 6); + + let mask = single_element_mask_bits_128b_lanes::(log_block_len); + let mut result = T::ZERO; + for i in 0..1 << (6 - log_block_len) { + result |= ((lhs.shr_128b_lanes(64 + (i << log_block_len))) & mask) + .shl_128b_lanes(i << (log_block_len + 1)); + result |= ((rhs.shr_128b_lanes(64 + (i << log_block_len))) & mask) + .shl_128b_lanes((2 * i + 1) << log_block_len); + } + + result +} + pub(crate) fn single_element_mask_bits(bits_count: usize) -> T { if bits_count == T::BITS { !T::ZERO From 04a29d0ee833f4980a0b30b284fbc153cd925147 Mon Sep 17 00:00:00 2001 From: Artem Storozhuk Date: Thu, 23 Jan 2025 18:43:48 +0200 Subject: [PATCH 41/50] feat: Add example of LinearCombination column usage --- examples/Cargo.toml | 4 ++ examples/acc-linear-combination.rs | 106 +++++++++++++++++++++++++++++ 2 files changed, 110 insertions(+) create mode 100644 examples/acc-linear-combination.rs diff --git a/examples/Cargo.toml b/examples/Cargo.toml index 07edfc0a3..a4196a240 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -72,6 +72,10 @@ path = "bitwise_ops.rs" name = "b32_mul" path = "b32_mul.rs" +[[example]] +name = "acc-linear-combination" +path = "acc-linear-combination.rs" + [lints.clippy] needless_range_loop = "allow" diff --git a/examples/acc-linear-combination.rs b/examples/acc-linear-combination.rs new file mode 100644 index 000000000..a1ea45d94 --- /dev/null +++ b/examples/acc-linear-combination.rs @@ -0,0 +1,106 @@ +use binius_circuits::{builder::ConstraintSystemBuilder, unconstrained::unconstrained}; +use binius_core::{constraint_system::validate::validate_witness, oracle::OracleId}; +use binius_field::{ + arch::OptimalUnderlier, packed::set_packed_slice, BinaryField128b, BinaryField1b, + BinaryField8b, ExtensionField, TowerField, +}; +use binius_macros::arith_expr; + +type U = OptimalUnderlier; +type F128 = BinaryField128b; +type F8 = BinaryField8b; +type F1 = BinaryField1b; + +fn bytes_decomposition_gadget( + builder: &mut ConstraintSystemBuilder, + name: impl ToString, + log_size: usize, + input: OracleId, +) -> Result { + builder.push_namespace(name); + + // Define 8 separate variables that represent bits (F1) of the particular byte behind `input` variable + let output_bits: [OracleId; 8] = + builder.add_committed_multiple("output_bits", log_size, F1::TOWER_LEVEL); + + // Define `output` variable that will store `input` bytes (we will compare this in our constraint below). + // Since we want to enforce decomposition, we use `LinearCombination` column which naturally fits for this purpose. + // We need to specify our coefficients now and later take care of defining bit columns and setting bit values appropriately + let output = builder.add_linear_combination( + "output", + log_size, + (0..8).map(|b| { + // Our coefficients are: + // + // 00000001 + // 00000010 + // 00000100 + // 00001000 + // 00010000 + // 00100000 + // 01000000 + // 10000000 + // + let basis = + >::basis(b).expect("index is less than extension degree"); + (output_bits[b], basis.into()) + }), + )?; + + if let Some(witness) = builder.witness() { + // Let's get actual value of bytes from memory of `input` variable + let input = witness.get::(input)?.as_slice::(); + + // Create exactly 8 columns in the witness each representing 1 bit from decomposition + let mut output_bits_witness: [_; 8] = output_bits.map(|id| witness.new_column::(id)); + + // Here we use packed type. Since constraint system is instantiated with F128, the packed type for our bits would be Packed128x1 + let output_bits = output_bits_witness.each_mut().map(|bit| bit.packed()); + + // Create 1 column where we will write bytes from input to compare in the constraint later + let mut output = witness.new_column::(output); + + // Get its memory + let output = output.as_mut_slice::(); + + // For each byte from the `input` we need to just copy it to the `output` and also + // we need to perform actual decomposition and write it in a form of packed bits to the `output_bits` + for z in 0..input.len() { + output[z] = input[z]; + + // Decompose particular byte value from the 'input' + let input_bits_bases = ExtensionField::::iter_bases(&input[z]); + + // Write decomposed bits to the memory of `output_bits` which expects them in the form of Packed128x1. + // It is important step, since our `output` variable is actually composed from `output_bits` + // and it expects to contain exactly the result of executing linear combination over bits stored + // behind `output_bits` + for (b, bit) in input_bits_bases.enumerate() { + set_packed_slice(output_bits[b], z, bit); + } + } + } + + // We just assert that every byte from `input` equals to correspondent byte from `output` + builder.assert_zero("s_box", [input, output], arith_expr!([i, o] = i - o).convert_field()); + builder.pop_namespace(); + Ok(output) +} + +fn main() { + let allocator = bumpalo::Bump::new(); + let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + + let log_size = 1usize; + + // Define set of bytes that we want to decompose + let p_in = unconstrained::(&mut builder, format!("p_in"), log_size).unwrap(); + + let _ = + bytes_decomposition_gadget(&mut builder, "bytes decomposition", log_size, p_in).unwrap(); + + let witness = builder.take_witness().unwrap(); + let cs = builder.build().unwrap(); + + validate_witness(&cs, &[], &witness).unwrap(); +} From af3b373e2a1fe1c45ab9550b93ea2cc68046fc5b Mon Sep 17 00:00:00 2001 From: Samuel Burnham <45365069+samuelburnham@users.noreply.github.com> Date: Fri, 24 Jan 2025 09:40:33 -0500 Subject: [PATCH 42/50] ci: Add basic Rust CI (#2) * ci: Add basic Rust CI * Fix test flags --- .github/workflows/benchmark.yml | 78 --------------- .github/workflows/ci.yml | 156 ++++++----------------------- .github/workflows/mirror.yml | 26 ----- examples/acc-linear-combination.rs | 2 +- 4 files changed, 33 insertions(+), 229 deletions(-) delete mode 100644 .github/workflows/benchmark.yml delete mode 100644 .github/workflows/mirror.yml diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml deleted file mode 100644 index 0c5fef7ea..000000000 --- a/.github/workflows/benchmark.yml +++ /dev/null @@ -1,78 +0,0 @@ -name: Nightly Benchmark - -on: - push: - branches: [ main ] - workflow_dispatch: - inputs: - ec2_instance_type: - description: 'Select EC2 instance type' - required: true - default: 'c7a-4xlarge' - type: choice - options: - - c7a-2xlarge - - c7a-4xlarge - - c8g-2xlarge - -permissions: - contents: write - checks: write - pull-requests: write - -jobs: - benchmark: - name: Continuous Benchmarking with Bencher - container: rustlang/rust:nightly - permissions: - checks: write - actions: write - runs-on: ${{ github.event_name == 'push' && github.ref_name == 'main' && 'c7a-4xlarge' || github.event.inputs.ec2_instance_type }} - steps: - - name: Checkout Repository - uses: actions/checkout@v4 - - name: Setup Bencher - uses: bencherdev/bencher@main - - name: Create Output Directory - run: mkdir output - - name: Execute Benchmark Tests - run: ./scripts/nightly_benchmarks.py --export-file output/result.json - - name: Track base branch benchmarks with Bencher - run: | - bencher run \ - --project ben \ - --token '${{ secrets.BENCHER_API_TOKEN }}' \ - --branch main \ - --testbed c7a-4xlarge \ - --threshold-measure latency \ - --threshold-test t_test \ - --threshold-max-sample-size 64 \ - --threshold-upper-boundary 0.99 \ - --thresholds-reset \ - --err \ - --adapter json \ - --github-actions '${{ secrets.GITHUB_TOKEN }}' \ - --file output/result.json - - name: Upload artifact - uses: actions/upload-artifact@v4 - with: - name: gh-pages - path: output/ - publish_results: - name: Publish Results to Github Page - needs: [benchmark] - runs-on: ubuntu-latest - steps: - - name: Download artifact - uses: actions/download-artifact@v4 - with: - name: gh-pages - - name: Deploy to GitHub Pages - uses: crazy-max/ghaction-github-pages@v4 - with: - repo: irreducibleoss/binius-benchmark - fqdn: benchmark.binius.xyz - target_branch: main - build_dir: ./ - env: - GITHUB_TOKEN: ${{ secrets.GH_TOKEN }} diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3341a164f..6962bb43d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,135 +1,43 @@ -name: Rust CI +name: Tests on: - push: - branches: [ main ] pull_request: - branches: [ main ] + push: + branches: main concurrency: - group: ${{ github.event_name }}-${{ github.ref }} - cancel-in-progress: ${{ github.event_name == 'pull_request' }} + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true jobs: - lint: - name: ${{ matrix.expand.name }} - runs-on: ${{ matrix.expand.runner }} - container: rustlang/rust:nightly - strategy: - matrix: - expand: - - runner: "ubuntu-latest" - name: "copyright-check" - cmd: "./scripts/check_copyright_notice.sh" - - runner: "ubuntu-latest" - name: "cargofmt" - cmd: "cargo fmt --check" - - runner: "ubuntu-latest" - name: "clippy" - cmd: "cargo clippy --all --all-features --tests --benches --examples -- -D warnings" - steps: - - name: Checkout Repository - uses: actions/checkout@v4 - - - name: Run Command - run: ${{ matrix.expand.cmd }} - build: - name: build-${{ matrix.expand.name }} - needs: [lint] - runs-on: ${{ matrix.expand.runner }} - env: - RUST_VERSION: 1.83.0 - container: rustlang/rust:nightly - strategy: - matrix: - expand: - - runner: "c7a-2xlarge" - name: "debug-wasm" - cmd: "rustup target add wasm32-unknown-unknown && cargo build --package binius_field --target wasm32-unknown-unknown" - - runner: "c7a-2xlarge" - name: "debug-amd" - cmd: "cargo build --tests --benches --examples" - - runner: "c7a-2xlarge" - name: "debug-amd-no-default-features" - cmd: "cargo build --tests --benches --examples --no-default-features" - - runner: "c7a-2xlarge" - name: "debug-amd-stable" - cmd: "cargo +$RUST_VERSION build --tests --benches --examples -p binius_core --features stable_only" - - runner: "c8g-2xlarge" - name: "debug-arm" - cmd: "cargo build --tests --benches --examples" - - runner: "c7a-2xlarge" - name: "docs" - cmd: 'cargo doc --no-deps; echo "" > target/doc/index.html' - steps: - - name: Checkout Repository - uses: actions/checkout@v4 - - name: AMD job configuration template with stable Rust - if: ${{ matrix.expand.name == 'debug-amd-stable' }} - run: | - rustup set auto-self-update disable - rustup toolchain install $RUST_VERSION - - name: Run Command - run: ${{ matrix.expand.cmd }} - - name: Upload static files as artifact - if: ${{ matrix.expand.name == 'docs' }} - id: deployment - uses: actions/upload-pages-artifact@v3 - with: - path: "target/doc" test: - name: unit-test-${{ matrix.expand.name }} - needs: [build] - runs-on: ${{ matrix.expand.runner }} - env: - RUST_VERSION: 1.83.0 - container: rustlang/rust:nightly - strategy: - matrix: - expand: - - runner: "c7a-2xlarge" - name: "amd" - cmd: 'RUSTFLAGS="-C target-cpu=native" ./scripts/run_tests_and_examples.sh' - - runner: "c7a-2xlarge" - name: "amd-portable" - cmd: 'RUSTFLAGS="-C target-cpu=generic" ./scripts/run_tests_and_examples.sh' - - runner: "c7a-2xlarge" - name: "amd-stable" - cmd: 'RUSTFLAGS="-C target-cpu=native" CARGO_STABLE=true ./scripts/run_tests_and_examples.sh' - - runner: "c7a-2xlarge" - name: "single-threaded" - cmd: 'RAYON_NUM_THREADS=1 RUSTFLAGS="-C target-cpu=native" ./scripts/run_tests_and_examples.sh' - - runner: "c7a-2xlarge" - name: "no-default-features" - cmd: 'CARGO_EXTRA_FLAGS="--no-default-features" RUSTFLAGS="-C target-cpu=native" ./scripts/run_tests_and_examples.sh' - - runner: "c8g-2xlarge" - name: "arm" - cmd: 'RUSTFLAGS="-C target-cpu=native -C target-feature=+aes" ./scripts/run_tests_and_examples.sh' - - runner: "c8g-2xlarge" - name: "arm-portable" - cmd: 'RUSTFLAGS="-C target-cpu=generic" ./scripts/run_tests_and_examples.sh' + runs-on: ubuntu-latest steps: - - name: Checkout Repository - uses: actions/checkout@v4 - - name: AMD job configuration template with stable Rust - if: ${{ matrix.expand.name == 'amd-stable' }} - run: | - rustup set auto-self-update disable - rustup toolchain install $RUST_VERSION - - name: Run Command - run: ${{ matrix.expand.cmd }} - deploy: - name: deploy-pages - needs: [build] + - uses: actions/checkout@v4 + with: + repository: argumentcomputer/ci-workflows + - uses: ./.github/actions/ci-env + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable + - uses: taiki-e/install-action@nextest + - uses: Swatinem/rust-cache@v2 + - name: Tests + run: cargo nextest run --cargo-profile test --workspace --run-ignored all + + lint: runs-on: ubuntu-latest - if: github.ref_name == 'main' - permissions: - pages: write - id-token: write - environment: - name: github-pages - url: ${{ steps.deployment.outputs.page_url }} steps: - - name: Deploy to GitHub Pages - id: deployment - uses: actions/deploy-pages@v4 + - uses: actions/checkout@v4 + with: + repository: argumentcomputer/ci-workflows + - uses: ./.github/actions/ci-env + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable + - name: Check Rustfmt Code Style + run: cargo fmt --all --check + - name: check *everything* compiles + run: cargo check --workspace --all-targets --all-features + - name: Check clippy warnings + run: cargo clippy --workspace --all-targets --all-features -- -D warnings + - name: Doctests + run: cargo test --doc --workspace diff --git a/.github/workflows/mirror.yml b/.github/workflows/mirror.yml deleted file mode 100644 index 7d10296ee..000000000 --- a/.github/workflows/mirror.yml +++ /dev/null @@ -1,26 +0,0 @@ -name: Mirror Repository - -on: - push: - branches: [ main ] - -permissions: - contents: read - -jobs: - mirror: - name: Mirror Repository to GitLab - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - uses: shimataro/ssh-key-action@v2 - with: - key: ${{ secrets.GIT_SSH_PRIVATE_KEY }} - name: id_rsa - known_hosts: ${{ secrets.GIT_SSH_KNOWN_HOSTS }} - - name: Mirror current ref to GitLab - run: | - git remote add gitlab ssh://git@gitlab.com/IrreducibleOSS/binius.git - git push gitlab ${{ github.ref }} diff --git a/examples/acc-linear-combination.rs b/examples/acc-linear-combination.rs index a1ea45d94..d73efbddf 100644 --- a/examples/acc-linear-combination.rs +++ b/examples/acc-linear-combination.rs @@ -94,7 +94,7 @@ fn main() { let log_size = 1usize; // Define set of bytes that we want to decompose - let p_in = unconstrained::(&mut builder, format!("p_in"), log_size).unwrap(); + let p_in = unconstrained::(&mut builder, "p_in".to_string(), log_size).unwrap(); let _ = bytes_decomposition_gadget(&mut builder, "bytes decomposition", log_size, p_in).unwrap(); From 2af6726a1f501da94063f979213bd1e5b05d6223 Mon Sep 17 00:00:00 2001 From: Artem Storozhuk Date: Fri, 31 Jan 2025 15:23:26 +0200 Subject: [PATCH 43/50] example: Linear combination with offset (#4) * example: Add linear-combination-with-offset usage example * chore: Add example for bit masking using LinearCombination * chore: Add byte decomposition constraint --- examples/Cargo.toml | 5 + .../acc-linear-combination-with-offset.rs | 127 +++++++++++++++ examples/acc-linear-combination.rs | 152 ++++++++++++++++++ 3 files changed, 284 insertions(+) create mode 100644 examples/acc-linear-combination-with-offset.rs diff --git a/examples/Cargo.toml b/examples/Cargo.toml index a4196a240..8d7baccec 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -76,6 +76,11 @@ path = "b32_mul.rs" name = "acc-linear-combination" path = "acc-linear-combination.rs" + +[[example]] +name = "acc-linear-combination-with-offset" +path = "acc-linear-combination-with-offset.rs" + [lints.clippy] needless_range_loop = "allow" diff --git a/examples/acc-linear-combination-with-offset.rs b/examples/acc-linear-combination-with-offset.rs new file mode 100644 index 000000000..2bb595cf0 --- /dev/null +++ b/examples/acc-linear-combination-with-offset.rs @@ -0,0 +1,127 @@ +use binius_circuits::{builder::ConstraintSystemBuilder, unconstrained::unconstrained}; +use binius_core::{constraint_system::validate::validate_witness, oracle::OracleId}; +use binius_field::{ + arch::OptimalUnderlier, packed::set_packed_slice, AESTowerField128b, AESTowerField8b, + BinaryField1b, ExtensionField, PackedField, TowerField, +}; + +type U = OptimalUnderlier; +type F128 = AESTowerField128b; +type F8 = AESTowerField8b; +type F1 = BinaryField1b; + +fn aes_s_box(x: F8) -> F8 { + #[rustfmt::skip] + const S_BOX: [u8; 256] = [ + 0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, + 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76, + 0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, + 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0, + 0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, + 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15, + 0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, + 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75, + 0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, + 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84, + 0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, + 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf, + 0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, + 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8, + 0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, + 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2, + 0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, + 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73, + 0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, + 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb, + 0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, + 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79, + 0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, + 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08, + 0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, + 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a, + 0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, + 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e, + 0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, + 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf, + 0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, + 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16, + ]; + let idx = u8::from(x) as usize; + F8::from(S_BOX[idx]) +} + +// AES s-box is equivalent to the affine transformation defined as follows: +// +// s[i] = b[i] + +// b[(i+4) mod 8] + +// b[(i+5) mod 8] + +// b[(i+6) mod 8] + +// b[(i+7) mod 8] + +// c[i] +// +// where 'b' is input byte, 's' is output byte, 'c' is constant which is equal to 0x63 (0b01100011) and 'i' is a bit position. +// The '+' operation is defined over Rijndael finite field : GF(2^8) = GF(2) [x] / (x^8 + x^4 + x^3 + x + 1). +// +const C: F8 = F8::new(0x63); +const AES_AFFINE_TRANSFORMATION: [F8; 8] = [ + F8::new(0b00011111), + F8::new(0b00111110), + F8::new(0b01111100), + F8::new(0b11111000), + F8::new(0b11110001), + F8::new(0b11100011), + F8::new(0b11000111), + F8::new(0b10001111), +]; + +fn main() { + let allocator = bumpalo::Bump::new(); + let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + + let log_size = 1usize; + let byte_in = unconstrained::(&mut builder, "input_byte", log_size).unwrap(); + + let bits: [OracleId; 8] = + builder.add_committed_multiple("decomposition", log_size, F1::TOWER_LEVEL); + + let byte_out = builder + .add_linear_combination_with_offset( + "lc", + log_size, + C.into(), + (0..8).map(|i| (bits[i], AES_AFFINE_TRANSFORMATION[i].into())), + ) + .unwrap(); + + if let Some(witness) = builder.witness() { + // get initial values of input bytes + let byte_in_values = witness.get::(byte_in).unwrap().as_slice::(); + + // create column for expected values of the output bytes + let mut byte_out_witness = witness.new_column::(byte_out); + let byte_out_values = byte_out_witness.as_mut_slice::(); + + // For each (inverted!) input byte, write correspondent bits to the decomposition + let mut bits_witness = bits.map(|bit| witness.new_column::(bit)); + let packed_bits = bits_witness.each_mut().map(|bit| bit.packed()); + + for byte_position in 0..byte_in_values.len() { + // write expected byte value to the output after applying s_box + byte_out_values[byte_position] = aes_s_box(byte_in_values[byte_position]); + + // invert input byte and write it to a decomposition bits + let input_inverted = byte_in_values[byte_position].invert_or_zero(); + + let bases = ExtensionField::::iter_bases(&input_inverted); + + for (bit_position, bit) in bases.clone().enumerate() { + set_packed_slice(packed_bits[bit_position], byte_position, bit); + } + } + } + + let witness = builder.take_witness().unwrap(); + let cs = builder.build().unwrap(); + + validate_witness(&cs, &[], &witness).unwrap(); +} diff --git a/examples/acc-linear-combination.rs b/examples/acc-linear-combination.rs index d73efbddf..e6a58a855 100644 --- a/examples/acc-linear-combination.rs +++ b/examples/acc-linear-combination.rs @@ -23,6 +23,20 @@ fn bytes_decomposition_gadget( let output_bits: [OracleId; 8] = builder.add_committed_multiple("output_bits", log_size, F1::TOWER_LEVEL); + let coefficients: [OracleId; 8] = + builder.add_committed_multiple("coeffs", log_size, F8::TOWER_LEVEL); + + let coeff_vals = [ + F8::new(0b00000001), + F8::new(0b00000010), + F8::new(0b00000100), + F8::new(0b00001000), + F8::new(0b00010000), + F8::new(0b00100000), + F8::new(0b01000000), + F8::new(0b10000000), + ]; + // Define `output` variable that will store `input` bytes (we will compare this in our constraint below). // Since we want to enforce decomposition, we use `LinearCombination` column which naturally fits for this purpose. // We need to specify our coefficients now and later take care of defining bit columns and setting bit values appropriately @@ -63,6 +77,18 @@ fn bytes_decomposition_gadget( // Get its memory let output = output.as_mut_slice::(); + // Write coefficients into the witness + let mut coeff_witness = + coefficients.map(|coefficient| witness.new_column::(coefficient)); + let coeff_witness = coeff_witness + .each_mut() + .map(|coeff| coeff.as_mut_slice::()); + for (idx, v) in coeff_witness.into_iter().enumerate() { + for vv in v { + *vv = coeff_vals[idx]; + } + } + // For each byte from the `input` we need to just copy it to the `output` and also // we need to perform actual decomposition and write it in a form of packed bits to the `output_bits` for z in 0..input.len() { @@ -83,6 +109,130 @@ fn bytes_decomposition_gadget( // We just assert that every byte from `input` equals to correspondent byte from `output` builder.assert_zero("s_box", [input, output], arith_expr!([i, o] = i - o).convert_field()); + + // Assert decomposition + builder.assert_zero( + "decomposition", + [ + input, + output_bits[0], + output_bits[1], + output_bits[2], + output_bits[3], + output_bits[4], + output_bits[5], + output_bits[6], + output_bits[7], + coefficients[0], + coefficients[1], + coefficients[2], + coefficients[3], + coefficients[4], + coefficients[5], + coefficients[6], + coefficients[7], + ], + arith_expr!( + [i, b0, b1, b2, b3, b4, b5, b6, b7, c0, c1, c2, c3, c4, c5, c6, c7] = + b0 * c0 + b1 * c1 + b2 * c2 + b3 * c3 + b4 * c4 + b5 * c5 + b6 * c6 + b7 * c7 - i + ) + .convert_field(), + ); + + builder.pop_namespace(); + Ok(output) +} + +fn elder_4bits_masking_gadget( + builder: &mut ConstraintSystemBuilder, + name: impl ToString, + log_size: usize, + input: OracleId, +) -> Result { + builder.push_namespace(name); + let output_bits: [OracleId; 8] = + builder.add_committed_multiple("output_bits", log_size, F1::TOWER_LEVEL); + + // we want to mask 4 elder bits in input byte + let lc_coefficients = [ + F8::new(0b00000001), + F8::new(0b00000010), + F8::new(0b00000100), + F8::new(0b00001000), + F8::new(0b00000000), + F8::new(0b00000000), + F8::new(0b00000000), + F8::new(0b00000000), + ]; + + let coefficients: [OracleId; 8] = + builder.add_committed_multiple("coeffs", log_size, F8::TOWER_LEVEL); + + let output = builder.add_linear_combination( + "output", + log_size, + (0..8).map(|b| (output_bits[b], lc_coefficients[b].into())), + )?; + + if let Some(witness) = builder.witness() { + // Write coefficients into the witness + let mut coeff_witness = + coefficients.map(|coefficient| witness.new_column::(coefficient)); + let coeff_witness = coeff_witness + .each_mut() + .map(|coeff| coeff.as_mut_slice::()); + for (idx, v) in coeff_witness.into_iter().enumerate() { + for vv in v { + *vv = lc_coefficients[idx]; + } + } + + let input = witness.get::(input)?.as_slice::(); + let mut output_bits_witness: [_; 8] = output_bits.map(|id| witness.new_column::(id)); + let output_bits = output_bits_witness.each_mut().map(|bit| bit.packed()); + let mut output = witness.new_column::(output); + let output = output.as_mut_slice::(); + for z in 0..input.len() { + // apply mask to the input byte + let byte_out_val = u8::from(input[z]) & 0x0F; + output[z] = F8::from(byte_out_val); + + let input_bits_bases = ExtensionField::::iter_bases(&input[z]); + for (b, bit) in input_bits_bases.enumerate() { + set_packed_slice(output_bits[b], z, bit); + } + } + } + + // Assert decomposition + builder.assert_zero( + "decomposition", + [ + output, + output_bits[0], + output_bits[1], + output_bits[2], + output_bits[3], + output_bits[4], + output_bits[5], + output_bits[6], + output_bits[7], + coefficients[0], + coefficients[1], + coefficients[2], + coefficients[3], + coefficients[4], + coefficients[5], + coefficients[6], + coefficients[7], + ], + arith_expr!( + [o, b0, b1, b2, b3, b4, b5, b6, b7, c0, c1, c2, c3, c4, c5, c6, c7] = + b0 * c0 + b1 * c1 + b2 * c2 + b3 * c3 + b4 * c4 + b5 * c5 + b6 * c6 + b7 * c7 - o + ) + .convert_field(), + ); + builder.pop_namespace(); Ok(output) } @@ -99,6 +249,8 @@ fn main() { let _ = bytes_decomposition_gadget(&mut builder, "bytes decomposition", log_size, p_in).unwrap(); + let _ = elder_4bits_masking_gadget(&mut builder, "masking", log_size, p_in).unwrap(); + let witness = builder.take_witness().unwrap(); let cs = builder.build().unwrap(); From fc05666f0efe06522bba8bc56b885249ddd2333f Mon Sep 17 00:00:00 2001 From: Artem Storozhuk Date: Mon, 3 Feb 2025 16:44:47 +0200 Subject: [PATCH 44/50] example: Implement bit-shifting/rotating and packing (#5) * example: Add example of Shifted column usage * example: Add example of Packed column usage * chore: Add 'unconstrained gadgets' warning --- examples/Cargo.toml | 9 +- .../acc-linear-combination-with-offset.rs | 2 + examples/acc-linear-combination.rs | 2 + examples/acc-packed.rs | 107 ++++++++++++++ examples/acc-shifted.rs | 138 ++++++++++++++++++ 5 files changed, 257 insertions(+), 1 deletion(-) create mode 100644 examples/acc-packed.rs create mode 100644 examples/acc-shifted.rs diff --git a/examples/Cargo.toml b/examples/Cargo.toml index 8d7baccec..fbf0c564c 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -76,11 +76,18 @@ path = "b32_mul.rs" name = "acc-linear-combination" path = "acc-linear-combination.rs" - [[example]] name = "acc-linear-combination-with-offset" path = "acc-linear-combination-with-offset.rs" +[[example]] +name = "acc-shifted" +path = "acc-shifted.rs" + +[[example]] +name = "acc-packed" +path = "acc-packed.rs" + [lints.clippy] needless_range_loop = "allow" diff --git a/examples/acc-linear-combination-with-offset.rs b/examples/acc-linear-combination-with-offset.rs index 2bb595cf0..619eeaec4 100644 --- a/examples/acc-linear-combination-with-offset.rs +++ b/examples/acc-linear-combination-with-offset.rs @@ -74,6 +74,8 @@ const AES_AFFINE_TRANSFORMATION: [F8; 8] = [ F8::new(0b10001111), ]; +// FIXME: Following gadget is unconstrained. Only for demonstrative purpose, don't use in production + fn main() { let allocator = bumpalo::Bump::new(); let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); diff --git a/examples/acc-linear-combination.rs b/examples/acc-linear-combination.rs index e6a58a855..5cb1eab54 100644 --- a/examples/acc-linear-combination.rs +++ b/examples/acc-linear-combination.rs @@ -11,6 +11,8 @@ type F128 = BinaryField128b; type F8 = BinaryField8b; type F1 = BinaryField1b; +// FIXME: Following gadgets are unconstrained. Only for demonstrative purpose, don't use in production + fn bytes_decomposition_gadget( builder: &mut ConstraintSystemBuilder, name: impl ToString, diff --git a/examples/acc-packed.rs b/examples/acc-packed.rs new file mode 100644 index 000000000..6963d8780 --- /dev/null +++ b/examples/acc-packed.rs @@ -0,0 +1,107 @@ +use binius_circuits::{builder::ConstraintSystemBuilder, unconstrained::unconstrained}; +use binius_core::constraint_system::validate::validate_witness; +use binius_field::{ + arch::OptimalUnderlier, BinaryField128b, BinaryField16b, BinaryField1b, BinaryField32b, + BinaryField8b, TowerField, +}; + +type U = OptimalUnderlier; +type F128 = BinaryField128b; +type F32 = BinaryField32b; +type F16 = BinaryField16b; +type F8 = BinaryField8b; +type F1 = BinaryField1b; + +// FIXME: Following gadgets are unconstrained. Only for demonstrative purpose, don't use in production + +fn packing_32_bits_to_u32(builder: &mut ConstraintSystemBuilder) { + builder.push_namespace("packing_32_bits_to_u32"); + + let bits = unconstrained::(builder, "bits", F32::TOWER_LEVEL).unwrap(); + let packed = builder + .add_packed("packed", bits, F32::TOWER_LEVEL) + .unwrap(); + + if let Some(witness) = builder.witness() { + let bits = witness.get::(bits).unwrap(); + assert_eq!(bits.as_slice::().len(), 16); // 16x u8 + + let composition = bits.repacked::(); + assert_eq!(composition.as_slice::().len(), 4); // 4x u32 + + witness.set(packed, composition).unwrap(); + } + + // setting witness above is logically identical to the following "manual" data writing (using Little-Endian format): + /* + if let Some(witness) = builder.witness() { + let bytes_values = witness.get::(bits).unwrap().as_slice::(); + + let mut packed_witness = witness.new_column::(packed); + let slice = packed_witness.as_mut_slice::(); + + bytes_values.chunks(4).zip(slice.into_iter()).for_each(|(chunk, val)| { + *val = u32::from_le_bytes(chunk.try_into().unwrap()); + }); + } + */ + + builder.pop_namespace(); +} + +fn packing_4_bytes_to_u32(builder: &mut ConstraintSystemBuilder) { + builder.push_namespace("packing_4_bytes_to_u32"); + + let bytes = unconstrained::(builder, "bytes", F16::TOWER_LEVEL).unwrap(); + let packed = builder + .add_packed("packed", bytes, F16::TOWER_LEVEL) + .unwrap(); + + // 'repacked' approach doesn't work for this case, so let's write data to the witness "manually" + + if let Some(witness) = builder.witness() { + let bytes_val = witness.get::(bytes).unwrap().as_slice::(); + + let mut packed_witness = witness.new_column::(packed); + let slice = packed_witness.as_mut_slice::(); + + bytes_val + .chunks(4) + .zip(slice.iter_mut()) + .for_each(|(chunk, val)| { + *val = u32::from_le_bytes(chunk.try_into().unwrap()); + }); + } + + builder.pop_namespace(); +} + +fn packing_8_bits_to_u8(builder: &mut ConstraintSystemBuilder) { + builder.push_namespace("packing_8_bits_to_u8"); + + let bits = unconstrained::(builder, "bits", F8::TOWER_LEVEL).unwrap(); + let packed = builder.add_packed("packed", bits, F8::TOWER_LEVEL).unwrap(); + + if let Some(witness) = builder.witness() { + let bits_values = witness.get::(bits).unwrap(); + + witness.set::(packed, bits_values.repacked()).unwrap(); + } + + builder.pop_namespace(); +} + +fn main() { + let allocator = bumpalo::Bump::new(); + + let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + + packing_32_bits_to_u32(&mut builder); + packing_4_bytes_to_u32(&mut builder); + packing_8_bits_to_u8(&mut builder); + + let witness = builder.take_witness().unwrap(); + let cs = builder.build().unwrap(); + + validate_witness(&cs, &[], &witness).unwrap(); +} diff --git a/examples/acc-shifted.rs b/examples/acc-shifted.rs new file mode 100644 index 000000000..924fb8234 --- /dev/null +++ b/examples/acc-shifted.rs @@ -0,0 +1,138 @@ +use binius_circuits::{builder::ConstraintSystemBuilder, unconstrained::unconstrained}; +use binius_core::{constraint_system::validate::validate_witness, oracle::ShiftVariant}; +use binius_field::{arch::OptimalUnderlier, BinaryField128b, BinaryField1b}; + +type U = OptimalUnderlier; +type F128 = BinaryField128b; +type F1 = BinaryField1b; + +// FIXME: Following gadgets are unconstrained. Only for demonstrative purpose, don't use in production + +fn shift_right_gadget_u32(builder: &mut ConstraintSystemBuilder) { + builder.push_namespace("u32_right_shift"); + + // defined empirically and it is the same as 'block_bits' defined below + let log_size = 5usize; + + // create column and write arbitrary bytes to it + let input = unconstrained::(builder, "input", log_size).unwrap(); + + // we want to shift our u32 variable on 1 bit + let shift_offset = 1; + let shift_type = ShiftVariant::LogicalRight; + + // 'block_bits' defines type of integer to shift. Binius must understand how to treat actual data in memory behind the variable + // So for u32 we have 32 bits of data, 32 = 2 ^ 5. + let block_bits = 5; + let shifted = builder + .add_shifted("shifted", input, shift_offset, block_bits, shift_type) + .unwrap(); + + if let Some(witness) = builder.witness() { + // get input values from the witness + let input_values = witness.get::(input).unwrap().as_slice::(); // u32 + + // write shifted input to the output + let mut output_values = witness.new_column::(shifted); + let output_values = output_values.as_mut_slice::(); // u32 + for i in 0..input_values.len() { + output_values[i] = input_values[i] >> shift_offset; // shift right + } + } + + builder.pop_namespace(); +} + +fn shift_left_gadget_u8(builder: &mut ConstraintSystemBuilder) { + builder.push_namespace("u8_left_shift"); + let log_size = 3usize; + + let input = unconstrained::(builder, "input", log_size).unwrap(); + let shift_offset = 4; + let shift_type = ShiftVariant::LogicalLeft; + let block_bits = 3; + let shifted = builder + .add_shifted("shifted", input, shift_offset, block_bits, shift_type) + .unwrap(); + + if let Some(witness) = builder.witness() { + let input_values = witness.get::(input).unwrap().as_slice::(); // u8 + let mut output_values = witness.new_column::(shifted); + let output_values = output_values.as_mut_slice::(); // u8 + for i in 0..input_values.len() { + output_values[i] = input_values[i] << shift_offset; // shift left + } + } + + builder.pop_namespace(); +} + +fn rotate_left_gadget_u16(builder: &mut ConstraintSystemBuilder) { + builder.push_namespace("u16_rotate_right"); + let log_size = 4usize; + + let input = unconstrained::(builder, "input", log_size).unwrap(); + let rotation_offset = 5; + let rotation_type = ShiftVariant::CircularLeft; + let block_bits = 4usize; + let shifted = builder + .add_shifted("shifted", input, rotation_offset, block_bits, rotation_type) + .unwrap(); + + if let Some(witness) = builder.witness() { + // write rotated input to the output + let input_values = witness.get::(input).unwrap().as_slice::(); // u16 + let mut output_values = witness.new_column::(shifted); + let output_values = output_values.as_mut_slice::(); // u16 + for i in 0..input_values.len() { + output_values[i] = input_values[i].rotate_left(rotation_offset as u32) // rotate left + } + } + + builder.pop_namespace(); +} + +fn rotate_right_gadget_u64(builder: &mut ConstraintSystemBuilder) { + builder.push_namespace("u64_rotate_right"); + let log_size = 6usize; + + let input = unconstrained::(builder, "input", log_size).unwrap(); + + // Right rotation to X bits is achieved using 'ShiftVariant::CircularLeft' with the offset, + // computed as size in bits of the variable type - X (e.g. if we want to right-rotate u64 to 8 bits, + // we have to use CircularLeft with the offset = 64 - 8). + let rotation_offset = 8; + let rotation_type = ShiftVariant::CircularLeft; + let block_bits = 6usize; + let shifted = builder + .add_shifted("shifted", input, 64 - rotation_offset, block_bits, rotation_type) + .unwrap(); + + if let Some(witness) = builder.witness() { + // write rotated input to the output + let input_values = witness.get::(input).unwrap().as_slice::(); // u64 + let mut output_values = witness.new_column::(shifted); + let output_values = output_values.as_mut_slice::(); // u64 + for i in 0..input_values.len() { + output_values[i] = input_values[i].rotate_right(rotation_offset as u32) // rotate right + } + } + + builder.pop_namespace(); +} + +fn main() { + let allocator = bumpalo::Bump::new(); + + let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + + shift_right_gadget_u32(&mut builder); + shift_left_gadget_u8(&mut builder); + rotate_left_gadget_u16(&mut builder); + rotate_right_gadget_u64(&mut builder); + + let witness = builder.take_witness().unwrap(); + let cs = builder.build().unwrap(); + + validate_witness(&cs, &[], &witness).unwrap(); +} From 6928690250b2ac4cd77ccbfbe95bbafe7174e7c9 Mon Sep 17 00:00:00 2001 From: Artem Storozhuk Date: Thu, 6 Feb 2025 20:01:32 +0200 Subject: [PATCH 45/50] example: Projected / Repeated columns usage (#6) * example: Add example of Projected column usage * example: Add example of Repeated column usage * example: Add example of ZeroPadded column usage --- examples/Cargo.toml | 12 ++ examples/acc-projected.rs | 248 +++++++++++++++++++++++++++++++++++++ examples/acc-repeated.rs | 124 +++++++++++++++++++ examples/acc-zeropadded.rs | 45 +++++++ 4 files changed, 429 insertions(+) create mode 100644 examples/acc-projected.rs create mode 100644 examples/acc-repeated.rs create mode 100644 examples/acc-zeropadded.rs diff --git a/examples/Cargo.toml b/examples/Cargo.toml index fbf0c564c..f719c24d5 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -88,6 +88,18 @@ path = "acc-shifted.rs" name = "acc-packed" path = "acc-packed.rs" +[[example]] +name = "acc-projected" +path = "acc-projected.rs" + +[[example]] +name = "acc-repeated" +path = "acc-repeated.rs" + +[[example]] +name = "acc-zeropadded" +path = "acc-zeropadded.rs" + [lints.clippy] needless_range_loop = "allow" diff --git a/examples/acc-projected.rs b/examples/acc-projected.rs new file mode 100644 index 000000000..8d4de121d --- /dev/null +++ b/examples/acc-projected.rs @@ -0,0 +1,248 @@ +use binius_circuits::{builder::ConstraintSystemBuilder, unconstrained::unconstrained}; +use binius_core::{constraint_system::validate::validate_witness, oracle::ProjectionVariant}; +use binius_field::{arch::OptimalUnderlier, BinaryField128b, BinaryField8b}; + +type U = OptimalUnderlier; +type F128 = BinaryField128b; +type F8 = BinaryField8b; + +#[derive(Clone)] +struct U8U128ProjectionInfo { + log_size: usize, + decimal: usize, + binary: Vec, + variant: ProjectionVariant, +} + +// The idea behind projection is that data from a column of some given field (F8) +// can be interpreted as a data of some or greater field (F128) and written to another column with equal or smaller length, +// which depends on LOG_SIZE and values of projection. Also two possible variants of projections are available, which +// has significant impact on input data processing. +// In the following example we have input column with bytes (u8) projected to the output column with u128 values. +fn projection( + builder: &mut ConstraintSystemBuilder, + projection_info: U8U128ProjectionInfo, + namespace: &str, +) { + builder.push_namespace(format!("projection {}", namespace)); + + let input = + unconstrained::(builder, "in", projection_info.clone().log_size).unwrap(); + + let projected = builder + .add_projected( + "projected", + input, + projection_info.clone().binary, + projection_info.clone().variant, + ) + .unwrap(); + + if let Some(witness) = builder.witness() { + let input_values = witness.get::(input).unwrap().as_slice::(); + let mut projected_witness = witness.new_column::(projected); + let projected_values = projected_witness.as_mut_slice::(); + + assert_eq!(projected_values.len(), projection_info.expected_projection_len()); + + match projection_info.variant { + ProjectionVariant::FirstVars => { + // Quite elaborated regularity, on my opinion + for idx in 0..projected_values.len() { + projected_values[idx] = F128::new( + input_values[(idx + * 2usize.pow(projection_info.clone().binary.len() as u32)) + + projection_info.clone().decimal] as u128, + ); + } + } + ProjectionVariant::LastVars => { + // decimal representation of the binary values is used as a simple offset + for idx in 0..projected_values.len() { + projected_values[idx] = + F128::new(input_values[projection_info.clone().decimal + idx] as u128); + } + } + }; + } + builder.pop_namespace(); +} + +impl U8U128ProjectionInfo { + fn new( + log_size: usize, + decimal: usize, + binary: Vec, + variant: ProjectionVariant, + ) -> U8U128ProjectionInfo { + assert!(log_size >= binary.len()); + + if variant == ProjectionVariant::LastVars { + // Pad with zeroes to LOG_SIZE len iterator. + // In this case we interpret binary values in a reverse order, meaning that the very first + // element is elder byte, so zeroes must be explicitly appended + let mut binary_clone = binary.clone(); + let mut zeroes = vec![F128::new(0u128); log_size - binary.len()]; + binary_clone.append(&mut zeroes); + + let coefficients = (0..binary_clone.len()) + .map(|degree| F128::new(2usize.pow(degree as u32) as u128)) + .collect::>(); + + let value = binary_clone + .iter() + .zip(coefficients.iter().rev()) + .fold(F128::new(0u128), |acc, (byte, coefficient)| acc + (*byte) * (*coefficient)); + + assert_eq!(decimal as u128, value.val()); + } + + U8U128ProjectionInfo { + log_size, + decimal, + binary, + variant, + } + } + + fn expected_projection_len(&self) -> usize { + 2usize.pow((self.log_size - self.binary.len()) as u32) + } +} + +fn main() { + let allocator = bumpalo::Bump::new(); + let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + + let projection_data = U8U128ProjectionInfo::new( + 4usize, + 9usize, + vec![ + F128::from(1u128), + F128::from(0u128), + F128::from(0u128), + F128::from(1u128), + ], + ProjectionVariant::FirstVars, + ); + projection(&mut builder, projection_data, "test_1"); + + let projection_data = U8U128ProjectionInfo::new( + 16usize, + 34816usize, + vec![ + F128::from(1u128), + F128::from(0u128), + F128::from(0u128), + F128::from(0u128), + F128::from(1u128), + ], + ProjectionVariant::LastVars, + ); + projection(&mut builder, projection_data, "test_2"); + + let projection_data = U8U128ProjectionInfo::new( + 4usize, + 15usize, + vec![ + F128::from(1u128), + F128::from(1u128), + F128::from(1u128), + F128::from(1u128), + ], + ProjectionVariant::LastVars, + ); + projection(&mut builder, projection_data, "test_3"); + + let projection_data = U8U128ProjectionInfo::new( + 6usize, + 60usize, + vec![ + F128::from(1u128), + F128::from(1u128), + F128::from(1u128), + F128::from(1u128), + ], + ProjectionVariant::LastVars, + ); + /* + With projection_data defined above we have 2^LOG_SIZE = 2^6 bytes in the input, + the size of projection is computed as follows: 2.pow(LOG_SIZE - binary.len()) = 2.pow(6 - 4) = 4. + the index of the input byte to use as projection is computed as follows (according to + a LastVars projection variant regularity): + + idx + decimal, e.g.: + + 0 + 60 + 1 + 60 + 2 + 60 + 3 + 60 + + where idx is [0..4]. + + Memory layout: + + input: [a5, a2, b1, 60, 91, ed, 5e, fb, ae, 1c, b2, 14, 92, 73, 92, c8, 56, 6d, fa, de, a8, 46, 77, 48, e1, cc, 90, 75, 78, d5, 19, be, 0c, 86, 39, 28, 0c, cc, e9, 4e, 46, d9, 84, 65, 4a, a2, b4, 64, eb, 59, 7b, fd, 3f, 0e, 2d, ea, 06, 42, a9, ea, (19), (8f), (19), (52)], len: 64 + output: [ + BinaryField128b(0x00000000000000000000000000000019), + BinaryField128b(0x0000000000000000000000000000008f), + BinaryField128b(0x00000000000000000000000000000019), + BinaryField128b(0x00000000000000000000000000000052) + ] + */ + projection(&mut builder, projection_data, "test_4"); + + let projection_data = U8U128ProjectionInfo::new( + 4usize, + 15usize, + vec![ + F128::from(1u128), + F128::from(1u128), + F128::from(1u128), + F128::from(1u128), + ], + ProjectionVariant::FirstVars, + ); + projection(&mut builder, projection_data, "test_5"); + + let projection_data = U8U128ProjectionInfo::new( + 6usize, + 13usize, + vec![ + F128::from(1u128), + F128::from(0u128), + F128::from(1u128), + F128::from(1u128), + ], + ProjectionVariant::FirstVars, + ); + /* + With projection_data defined above we have 2^LOG_SIZE = 2^6 bytes in the input, + the size of projection is computed as follows: 2.pow(LOG_SIZE - binary.len()) = 2.pow(6 - 4) = 4. + the index of the input byte to use as projection is computed as follows: + + idx * 2usize.pow(binary.len()) + decimal, e.g.: + + 0 * 2.pow(4) + 13 = 13, so input[13] + 1 * 2.pow(4) + 13 = 29, so input[29] + 2 * 2.pow(4) + 13 = 45, so input[45] + 3 * 2.pow(4) + 13 = 61, so input[61] + + where idx is [0..4] according to a FirstVars projection variant regularity. + + Memory layout: + + input: [18, d8, 58, d3, 24, f1, 8b, ec, 74, 1c, ab, 78, 13, (3e), 57, d7, 36, 15, 54, 50, 9a, cb, 98, 90, 58, cb, 79, 05, 83, (72), ea, 4d, f6, 3d, f3, 2f, af, e3, 32, 11, c9, 97, fb, ba, 24, (36), e9, 38, 7e, c7, a9, 68, bf, 31, 51, cf, 7b, 12, 20, 53, d8, (df), d7, cc], len: 64 + + output: BinaryField128b(0x0000000000000000000000000000003e) + output: BinaryField128b(0x00000000000000000000000000000072) + output: BinaryField128b(0x00000000000000000000000000000036) + output: BinaryField128b(0x000000000000000000000000000000df) + */ + projection(&mut builder, projection_data, "test_6"); + + let witness = builder.take_witness().unwrap(); + let cs = builder.build().unwrap(); + + validate_witness(&cs, &[], &witness).unwrap(); +} diff --git a/examples/acc-repeated.rs b/examples/acc-repeated.rs new file mode 100644 index 000000000..749bb6fee --- /dev/null +++ b/examples/acc-repeated.rs @@ -0,0 +1,124 @@ +use binius_circuits::{builder::ConstraintSystemBuilder, unconstrained::unconstrained}; +use binius_core::constraint_system::validate::validate_witness; +use binius_field::{ + arch::OptimalUnderlier, packed::set_packed_slice, BinaryField128b, BinaryField1b, + BinaryField8b, PackedBinaryField128x1b, +}; + +type U = OptimalUnderlier; +type F128 = BinaryField128b; +type F8 = BinaryField8b; +type F1 = BinaryField1b; + +// FIXME: Following gadgets are unconstrained. Only for demonstrative purpose, don't use in production + +const LOG_SIZE: usize = 8; + +// The idea of 'Repeated' column is that one can just copy data from initial column multiple times, +// so new column is X times bigger than original one. The following gadget operates over bytes, e.g. +// it creates column with some input bytes written and then creates one more 'Repeated' column +// where the same bytes are copied multiple times. +fn bytes_repeat_gadget(builder: &mut ConstraintSystemBuilder) { + builder.push_namespace("bytes_repeat_gadget"); + + let bytes = unconstrained::(builder, "input", LOG_SIZE).unwrap(); + + let repeat_times_log = 4usize; + let repeating = builder + .add_repeating("repeating", bytes, repeat_times_log) + .unwrap(); + + if let Some(witness) = builder.witness() { + let input_values = witness.get::(bytes).unwrap().as_slice::(); + + let mut repeating_witness = witness.new_column::(repeating); + let repeating_values = repeating_witness.as_mut_slice::(); + + let repeat_times = 2usize.pow(repeat_times_log as u32); + assert_eq!(2usize.pow(LOG_SIZE as u32), input_values.len()); + assert_eq!(input_values.len() * repeat_times, repeating_values.len()); + + for idx in 0..repeat_times { + let start = idx * input_values.len(); + let end = start + input_values.len(); + repeating_values[start..end].copy_from_slice(input_values); + } + } + + builder.pop_namespace(); +} + +// Bit-oriented repeating is more elaborated due to a specifics of memory layout in Binius. +// In the following example, we use LOG_SIZE=8, which gives 2.pow(8) = 32 bytes written in the memory +// layout. This gives 32 * 8 = 256 bits of input information. Having that Repeated' column +// is instantiated with 'repeat_times_log = 2', this means that we have to repeat our bytes +// 2.pow(repeat_times_log) = 4 times ultimately. For setting bit values we use PackedBinaryField128x1b, +// so for 32 bytes (256 bits) of input data we use 2 PackedBinaryField128x1b elements. Considering 4 +// repetitions Binius creates column with 8 PackedBinaryField128x1b elements totally. +// Proper writing bits requires separate iterating over PackedBinaryField128x1b elements and input bytes +// with extracting particular bit values from the input and setting appropriate bit in a given PackedBinaryField128x1b. +fn bits_repeat_gadget(builder: &mut ConstraintSystemBuilder) { + builder.push_namespace("bits_repeat_gadget"); + + let bits = unconstrained::(builder, "input", LOG_SIZE).unwrap(); + let repeat_times_log = 2usize; + + // Binius will create column with appropriate height for us + let repeating = builder + .add_repeating("repeating", bits, repeat_times_log) + .unwrap(); + + if let Some(witness) = builder.witness() { + let input_values = witness.get::(bits).unwrap().as_slice::(); + let mut repeating_witness = witness.new_column::(repeating); + let output_values = repeating_witness.packed(); + + // this performs writing input bits exactly 1 time. Depending on number of repetitions we + // need to call this multiple times, providing offset for output values (PackedBinaryField128x1b elements) + fn write_input( + input_values: &[u8], + output_values: &mut [PackedBinaryField128x1b], + output_packed_offset: usize, + ) { + let mut output_index = output_packed_offset; + for (input_index, _) in (0..input_values.len()).enumerate() { + let byte = input_values[input_index]; + + set_packed_slice(output_values, output_index, F1::from(byte)); + set_packed_slice(output_values, output_index + 1, F1::from((byte >> 1) & 0x01)); + set_packed_slice(output_values, output_index + 2, F1::from((byte >> 2) & 0x01)); + set_packed_slice(output_values, output_index + 3, F1::from((byte >> 3) & 0x01)); + set_packed_slice(output_values, output_index + 4, F1::from((byte >> 4) & 0x01)); + set_packed_slice(output_values, output_index + 5, F1::from((byte >> 5) & 0x01)); + set_packed_slice(output_values, output_index + 6, F1::from((byte >> 6) & 0x01)); + set_packed_slice(output_values, output_index + 7, F1::from((byte >> 7) & 0x01)); + + output_index += 8; + } + } + + let repeat_times = 2u32.pow(repeat_times_log as u32); + + let mut offset = 0; + for _ in 0..repeat_times { + write_input(input_values, output_values, offset); + offset += input_values.len() * 8; + } + } + + builder.pop_namespace(); +} + +fn main() { + let allocator = bumpalo::Bump::new(); + + let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + + bytes_repeat_gadget(&mut builder); + bits_repeat_gadget(&mut builder); + + let witness = builder.take_witness().unwrap(); + let cs = builder.build().unwrap(); + + validate_witness(&cs, &[], &witness).unwrap(); +} diff --git a/examples/acc-zeropadded.rs b/examples/acc-zeropadded.rs new file mode 100644 index 000000000..2cbb954ca --- /dev/null +++ b/examples/acc-zeropadded.rs @@ -0,0 +1,45 @@ +use binius_circuits::{builder::ConstraintSystemBuilder, unconstrained::unconstrained}; +use binius_core::constraint_system::validate::validate_witness; +use binius_field::{arch::OptimalUnderlier, BinaryField128b, BinaryField8b}; + +type U = OptimalUnderlier; +type F128 = BinaryField128b; +type F8 = BinaryField8b; + +const LOG_SIZE: usize = 4; + +fn main() { + let allocator = bumpalo::Bump::new(); + + let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + + let bytes = unconstrained::(&mut builder, "bytes", LOG_SIZE).unwrap(); + + // Height of ZeroPadded column can't be smaller than input one. + // If n_vars equals to LOG_SIZE, then no padding is required, + // the ZeroPadded column will have same length as input one, so we use bigger number + let n_vars = 5usize; + let zeropadded = builder + .add_zero_padded("zeropadded", bytes, n_vars) + .unwrap(); + + if let Some(witness) = builder.witness() { + let input_values = witness.get::(bytes).unwrap().as_slice::(); + + let mut zeropadded_witness = witness.new_column::(zeropadded); + let zeropadded_values = zeropadded_witness.as_mut_slice::(); + + // padding naturally happens in the end, so we just copy input data to ZeroPadded column + zeropadded_values[..input_values.len()].copy_from_slice(input_values); + + assert_eq!(zeropadded_values.len(), 2usize.pow(n_vars as u32)); + assert!(n_vars >= LOG_SIZE); + let zeroes_to_pad = 2usize.pow(n_vars as u32) - 2usize.pow(LOG_SIZE as u32); + assert_eq!(zeroes_to_pad, zeropadded_values.len() - input_values.len()); + } + + let witness = builder.take_witness().unwrap(); + let cs = builder.build().unwrap(); + + validate_witness(&cs, &[], &witness).unwrap(); +} From 9ce3e3e8915d2206ce5e436d3c66dd150ae526e2 Mon Sep 17 00:00:00 2001 From: Artem Storozhuk Date: Mon, 17 Feb 2025 15:36:25 +0200 Subject: [PATCH 46/50] examples: Transparent columns usage (part 1) (#8) * feat: Add example of Transparent (Constant) column usage * example: Add example of Transparent (Powers) column usage * example: Add example of Transparent (DisjointProduct) column usage * example: Add example of Transparent (EqIndPartialEval) column usage --- examples/Cargo.toml | 16 ++++++ examples/acc-constants.rs | 88 +++++++++++++++++++++++++++++ examples/acc-disjoint-product.rs | 64 +++++++++++++++++++++ examples/acc-eq-ind-partial-eval.rs | 81 ++++++++++++++++++++++++++ examples/acc-powers.rs | 77 +++++++++++++++++++++++++ 5 files changed, 326 insertions(+) create mode 100644 examples/acc-constants.rs create mode 100644 examples/acc-disjoint-product.rs create mode 100644 examples/acc-eq-ind-partial-eval.rs create mode 100644 examples/acc-powers.rs diff --git a/examples/Cargo.toml b/examples/Cargo.toml index f719c24d5..eb270f246 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -100,6 +100,22 @@ path = "acc-repeated.rs" name = "acc-zeropadded" path = "acc-zeropadded.rs" +[[example]] +name = "acc-powers" +path = "acc-powers.rs" + +[[example]] +name = "acc-constants" +path = "acc-constants.rs" + +[[example]] +name = "acc-disjoint-product" +path = "acc-disjoint-product.rs" + +[[example]] +name = "acc-eq-ind-partial-eval" +path = "acc-eq-ind-partial-eval.rs" + [lints.clippy] needless_range_loop = "allow" diff --git a/examples/acc-constants.rs b/examples/acc-constants.rs new file mode 100644 index 000000000..0947ed107 --- /dev/null +++ b/examples/acc-constants.rs @@ -0,0 +1,88 @@ +use binius_circuits::{builder::ConstraintSystemBuilder, sha256::u32const_repeating}; +use binius_core::{ + constraint_system::validate::validate_witness, oracle::OracleId, + transparent::constant::Constant, +}; +use binius_field::{arch::OptimalUnderlier, BinaryField128b, BinaryField1b, BinaryField32b}; + +type U = OptimalUnderlier; +type F128 = BinaryField128b; +type F32 = BinaryField32b; +type F1 = BinaryField1b; + +const LOG_SIZE: usize = 4; + +// FIXME: Following gadgets are unconstrained. Only for demonstrative purpose, don't use in production + +fn constants_gadget( + name: impl ToString, + log_size: usize, + builder: &mut ConstraintSystemBuilder, + constant_value: u32, +) -> OracleId { + builder.push_namespace(name); + + let c = Constant::new(log_size, F32::new(constant_value)); + + let oracle = builder.add_transparent("constant", c).unwrap(); + + if let Some(witness) = builder.witness() { + let mut oracle_witness = witness.new_column::(oracle); + let values = oracle_witness.as_mut_slice::(); + for v in values { + *v = constant_value; + } + } + + builder.pop_namespace(); + + oracle +} + +// Transparent column can also naturally be used for storing some constants (also available for verifier). +// For example there is a 'u32const_repeating' function (in sha256 gadget) that does exactly this +// using Transparent + Repeated columns. Alternatively one can use Constant abstraction to create equivalent +// Transparent column. +fn main() { + let allocator = bumpalo::Bump::new(); + let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + + pub const SHA256_INIT: [u32; 8] = [ + 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, + 0x5be0cd19, + ]; + + let oracles: [OracleId; 8] = + SHA256_INIT.map(|c| u32const_repeating(LOG_SIZE, &mut builder, c, "INIT").unwrap()); + if let Some(witness) = builder.witness() { + for (index, oracle) in oracles.into_iter().enumerate() { + let values = witness.get::(oracle).unwrap().as_slice::(); + + // every value in the column should match the expected one + for value in values { + assert_eq!(*value, SHA256_INIT[index]); + } + } + } + + let oracles: [OracleId; 8] = + SHA256_INIT.map(|c| constants_gadget("constants_gadget", LOG_SIZE, &mut builder, c)); + if let Some(witness) = builder.witness() { + for (index, oracle) in oracles.into_iter().enumerate() { + // The difference is here. With Constant we have to operate over F32, while + // with Transparent + Repeated approach as in 'u32const_repeating' we operate over F1, + // which can be more convenient in the bit-oriented computations + let values = witness.get::(oracle).unwrap().as_slice::(); + + // every value in the column should match the expected one + for value in values { + assert_eq!(*value, SHA256_INIT[index]); + } + } + } + + let witness = builder.take_witness().unwrap(); + let constraints_system = builder.build().unwrap(); + + validate_witness(&constraints_system, &[], &witness).unwrap(); +} diff --git a/examples/acc-disjoint-product.rs b/examples/acc-disjoint-product.rs new file mode 100644 index 000000000..05f4b59d9 --- /dev/null +++ b/examples/acc-disjoint-product.rs @@ -0,0 +1,64 @@ +use binius_circuits::builder::ConstraintSystemBuilder; +use binius_core::{ + constraint_system::validate::validate_witness, + transparent::{constant::Constant, disjoint_product::DisjointProduct, powers::Powers}, +}; +use binius_field::{ + arch::OptimalUnderlier, BinaryField, BinaryField128b, BinaryField8b, PackedField, +}; + +type U = OptimalUnderlier; +type F128 = BinaryField128b; +type F8 = BinaryField8b; + +const LOG_SIZE: usize = 4; + +// FIXME: Following gadgets are unconstrained. Only for demonstrative purpose, don't use in production + +// DisjointProduct can be used for creating some more elaborated regularities over public data. +// In the following example we have a Transparent column with DisjointProduct instantiated over Powers +// and Constant. In this regularity, the DisjointProduct would be represented as a following expression: +// +// [ c * F8(x)^0, c * F8(x)^1, c * F8(x)^2, ... c * F8(x)^(2^LOG_SIZE) ], +// +// where +// 'x' is a multiplicative generator - a public value that exists for every BinaryField, +// 'c' is some (F8) constant. +// +// Also note, that DisjointProduct makes eventual Transparent column to have height (n_vars) which is sum +// of heights (n_vars) of Powers and Constant, so actual data could be repeated multiple times +fn main() { + let allocator = bumpalo::Bump::new(); + let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + + let generator = F8::MULTIPLICATIVE_GENERATOR; + let powers = Powers::new(LOG_SIZE, generator.into()); + + let constant_value = F8::new(0xf0); + let constant = Constant::new(LOG_SIZE, constant_value); + + let disjoint_product = DisjointProduct(powers, constant); + let disjoint_product_id = builder + .add_transparent("disjoint_product", disjoint_product) + .unwrap(); + + if let Some(witness) = builder.witness() { + let mut disjoint_product_witness = witness.new_column::(disjoint_product_id); + + let values = disjoint_product_witness.as_mut_slice::(); + + let mut exponent = 0u64; + for val in values.iter_mut() { + if exponent == 2u64.pow(LOG_SIZE as u32) { + exponent = 0; + } + *val = generator.pow(exponent) * constant_value; + exponent += 1; + } + } + + let witness = builder.take_witness().unwrap(); + let constraints_system = builder.build().unwrap(); + + validate_witness(&constraints_system, &[], &witness).unwrap(); +} diff --git a/examples/acc-eq-ind-partial-eval.rs b/examples/acc-eq-ind-partial-eval.rs new file mode 100644 index 000000000..7fcdc4b54 --- /dev/null +++ b/examples/acc-eq-ind-partial-eval.rs @@ -0,0 +1,81 @@ +use binius_circuits::builder::ConstraintSystemBuilder; +use binius_core::{ + constraint_system::validate::validate_witness, transparent::eq_ind::EqIndPartialEval, +}; +use binius_field::{arch::OptimalUnderlier, BinaryField128b, PackedField}; + +type U = OptimalUnderlier; +type F128 = BinaryField128b; + +const LOG_SIZE: usize = 3; + +// FIXME: Following gadgets are unconstrained. Only for demonstrative purpose, don't use in production + +// Currently, it is hard for me to imagine some real world use-cases where Transparent column specified by +// EqIndPartialEval could be useful. The program can use some of its data as challenges and the Transparent +// column with EqIndPartialEval will expect witness values defined as following: +// +// x_i * y_i + (1 - x_i) * (1 - y_i) +// +// where 'x_i' is an element from a particular row of basis matrix, and y_i is a given challenge. +// +fn main() { + let allocator = bumpalo::Bump::new(); + let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + + // A truth table [000, 001, 010, 011 ... 111] where each row is in reversed order + let rev_basis = [ + vec![0, 0, 0], + vec![1, 0, 0], + vec![0, 1, 0], + vec![1, 1, 0], + vec![0, 0, 1], + vec![1, 0, 1], + vec![0, 1, 1], + vec![1, 1, 1], + ]; + + // rev_basis size correlates with LOG_SIZE + assert_eq!(1 << LOG_SIZE, rev_basis.len()); + + // let's choose some random challenges (each not greater than 1 << LOG_SIZE bits for this example) + let challenges = vec![F128::from(110), F128::from(190), F128::from(200)]; + + // challenges size correlates with LOG_SIZE + assert_eq!(challenges.len(), LOG_SIZE); + + let eq_ind_partial_eval = EqIndPartialEval::new(LOG_SIZE, challenges.clone()).unwrap(); + + let id = builder + .add_transparent("eq_ind_partial_eval", eq_ind_partial_eval) + .unwrap(); + + if let Some(witness) = builder.witness() { + let mut eq_witness = witness.new_column::(id); + + let column_values = eq_witness.as_mut_slice::(); + assert_eq!(column_values.len(), 1 << LOG_SIZE); + + let one = F128::one(); + + for (inv_basis_item, val) in rev_basis.iter().zip(column_values.iter_mut()) { + let mut value = F128::one(); + inv_basis_item + .iter() + .zip(challenges.iter()) + .for_each(|(x, y)| { + let x = F128::new(*x); + let y = *y; + + // following expression is defined in the EqIndPartialEval implementation + value *= x * y + (one - x) * (one - y); + }); + *val = value; + } + } + + let witness = builder.take_witness().unwrap(); + let constraints_system = builder.build().unwrap(); + + validate_witness(&constraints_system, &[], &witness).unwrap(); +} diff --git a/examples/acc-powers.rs b/examples/acc-powers.rs new file mode 100644 index 000000000..159aa1136 --- /dev/null +++ b/examples/acc-powers.rs @@ -0,0 +1,77 @@ +use binius_circuits::builder::ConstraintSystemBuilder; +use binius_core::constraint_system::validate::validate_witness; +use binius_field::{ + arch::OptimalUnderlier, BinaryField, BinaryField128b, BinaryField16b, BinaryField32b, + PackedField, +}; + +type U = OptimalUnderlier; +type F128 = BinaryField128b; +type F32 = BinaryField32b; +type F16 = BinaryField16b; + +const LOG_SIZE: usize = 3; + +// FIXME: Following gadgets are unconstrained. Only for demonstrative purpose, don't use in production + +// Values for the Transparent columns are known to verifier, so they can be used for storing non-private data +// (like constants for example). The following gadget demonstrates how to use Powers abstraction to build a +// Transparent column that keeps following values (we write them during witness population): +// +// [ F32(x)^0, F32(x)^1 , F32(x)^2, ... F32(x)^(2^LOG_SIZE) ], + +// where 'x' is a multiplicative generator - a public value that exists for every BinaryField +// +fn powers_gadget_f32(builder: &mut ConstraintSystemBuilder, name: impl ToString) { + builder.push_namespace(name); + + let generator = F32::MULTIPLICATIVE_GENERATOR; + let powers = binius_core::transparent::powers::Powers::new(LOG_SIZE, generator.into()); + let transparent = builder + .add_transparent("Powers of F32 gen", powers) + .unwrap(); + + if let Some(witness) = builder.witness() { + let mut transparent_witness = witness.new_column::(transparent); + let transparent_values = transparent_witness.as_mut_slice::(); + for (exp, val) in transparent_values.iter_mut().enumerate() { + *val = generator.pow(exp as u64); + } + } + + builder.pop_namespace(); +} + +// Only Field is being changed +fn powers_gadget_f16(builder: &mut ConstraintSystemBuilder, name: impl ToString) { + builder.push_namespace(name); + + let generator = F16::MULTIPLICATIVE_GENERATOR; + let powers = binius_core::transparent::powers::Powers::new(LOG_SIZE, generator.into()); + let transparent = builder + .add_transparent("Powers of F16 gen", powers) + .unwrap(); + + if let Some(witness) = builder.witness() { + let mut transparent_witness = witness.new_column::(transparent); + let transparent_values = transparent_witness.as_mut_slice::(); + for (exp, val) in transparent_values.iter_mut().enumerate() { + *val = generator.pow(exp as u64); + } + } + + builder.pop_namespace(); +} + +fn main() { + let allocator = bumpalo::Bump::new(); + let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + + powers_gadget_f16(&mut builder, "f16"); + powers_gadget_f32(&mut builder, "f32"); + + let witness = builder.take_witness().unwrap(); + let constraints_system = builder.build().unwrap(); + + validate_witness(&constraints_system, &[], &witness).unwrap(); +} From 5bb44de2c83323184125dc22030f562205da73f0 Mon Sep 17 00:00:00 2001 From: Artem Storozhuk Date: Mon, 17 Feb 2025 20:31:20 +0200 Subject: [PATCH 47/50] examples: Transparent columns usage (part 2) (#9) * example: Add example of Transparent (MultilinearExtensionTransparent) column usage * example: Add example of Transparent (SelectRow) column usage * example: Add example of Transparent (ShiftIndPartialEval) column usage * example: Add example of Transparent (StepDown) column usage * example: Add example of Transparent (StepUp) column usage * example: Add example of Transparent (TowerBasis) column usage --- examples/Cargo.toml | 25 +++++ .../acc-multilinear-extension-transparent.rs | 95 +++++++++++++++++++ examples/acc-select-row.rs | 35 +++++++ examples/acc-shift-ind-partial-eq.rs | 93 ++++++++++++++++++ examples/acc-step-down.rs | 34 +++++++ examples/acc-step-up.rs | 32 +++++++ examples/acc-tower-basis.rs | 60 ++++++++++++ 7 files changed, 374 insertions(+) create mode 100644 examples/acc-multilinear-extension-transparent.rs create mode 100644 examples/acc-select-row.rs create mode 100644 examples/acc-shift-ind-partial-eq.rs create mode 100644 examples/acc-step-down.rs create mode 100644 examples/acc-step-up.rs create mode 100644 examples/acc-tower-basis.rs diff --git a/examples/Cargo.toml b/examples/Cargo.toml index eb270f246..86c0af570 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -15,6 +15,7 @@ binius_hash = { path = "../crates/hash" } binius_macros = { path = "../crates/macros" } binius_math = { path = "../crates/math" } binius_utils = { path = "../crates/utils", default-features = false } +bytemuck.workspace = true bumpalo.workspace = true bytesize.workspace = true clap = { version = "4.5.20", features = ["derive"] } @@ -116,6 +117,30 @@ path = "acc-disjoint-product.rs" name = "acc-eq-ind-partial-eval" path = "acc-eq-ind-partial-eval.rs" +[[example]] +name = "acc-multilinear-extension-transparent" +path = "acc-multilinear-extension-transparent.rs" + +[[example]] +name = "acc-select-row" +path = "acc-select-row.rs" + +[[example]] +name = "acc-shift-ind-partial-eq" +path = "acc-shift-ind-partial-eq.rs" + +[[example]] +name = "acc-step-down" +path = "acc-step-down.rs" + +[[example]] +name = "acc-step-up" +path = "acc-step-up.rs" + +[[example]] +name = "acc-tower-basis" +path = "acc-tower-basis.rs" + [lints.clippy] needless_range_loop = "allow" diff --git a/examples/acc-multilinear-extension-transparent.rs b/examples/acc-multilinear-extension-transparent.rs new file mode 100644 index 000000000..f9dd45706 --- /dev/null +++ b/examples/acc-multilinear-extension-transparent.rs @@ -0,0 +1,95 @@ +use binius_circuits::builder::ConstraintSystemBuilder; +use binius_core::{ + constraint_system::validate::validate_witness, transparent::MultilinearExtensionTransparent, +}; +use binius_field::{ + arch::OptimalUnderlier, as_packed_field::PackedType, underlier::WithUnderlier, BinaryField128b, + BinaryField1b, PackedField, +}; +use binius_utils::checked_arithmetics::log2_ceil_usize; +use bytemuck::{pod_collect_to_vec, Pod}; + +type U = OptimalUnderlier; +type F128 = BinaryField128b; +type F1 = BinaryField1b; + +// From a perspective of circuits creation, MultilinearExtensionTransparent can be used naturally for decomposing integers to bits +fn decompose_transparent_u64(builder: &mut ConstraintSystemBuilder, x: u64) { + builder.push_namespace("decompose_transparent_u64"); + + let log_bits = log2_ceil_usize(64); + + let broadcasted = vec![x; 1 << (PackedType::::LOG_WIDTH.saturating_sub(log_bits))]; + + let broadcasted_decomposed = into_packed_vec::>(&broadcasted); + + let transparent_id = builder + .add_transparent( + "transparent", + MultilinearExtensionTransparent::<_, PackedType, _>::from_values( + broadcasted_decomposed, + ) + .unwrap(), + ) + .unwrap(); + + if let Some(witness) = builder.witness() { + let mut transparent_witness = witness.new_column::(transparent_id); + let values = transparent_witness.as_mut_slice::(); + values.fill(x); + } + + builder.pop_namespace(); +} + +fn decompose_transparent_u32(builder: &mut ConstraintSystemBuilder, x: u32) { + builder.push_namespace("decompose_transparent_u32"); + + let log_bits = log2_ceil_usize(32); + + let broadcasted = vec![x; 1 << (PackedType::::LOG_WIDTH.saturating_sub(log_bits))]; + + let broadcasted_decomposed = into_packed_vec::>(&broadcasted); + + let transparent_id = builder + .add_transparent( + "transparent", + MultilinearExtensionTransparent::<_, PackedType, _>::from_values( + broadcasted_decomposed, + ) + .unwrap(), + ) + .unwrap(); + + if let Some(witness) = builder.witness() { + let mut transparent_witness = witness.new_column::(transparent_id); + let values = transparent_witness.as_mut_slice::(); + values.fill(x); + } + + builder.pop_namespace(); +} + +fn main() { + let allocator = bumpalo::Bump::new(); + let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + + decompose_transparent_u64(&mut builder, 0xff00ff00ff00ff00); + decompose_transparent_u32(&mut builder, 0x00ff00ff); + + let witness = builder.take_witness().unwrap(); + let constraints_system = builder.build().unwrap(); + + validate_witness(&constraints_system, &[], &witness).unwrap(); +} + +fn into_packed_vec

(src: &[impl Pod]) -> Vec

+where + P: PackedField + WithUnderlier, + P::Underlier: Pod, +{ + pod_collect_to_vec::<_, P::Underlier>(src) + .into_iter() + .map(P::from_underlier) + .collect() +} diff --git a/examples/acc-select-row.rs b/examples/acc-select-row.rs new file mode 100644 index 000000000..4337540bd --- /dev/null +++ b/examples/acc-select-row.rs @@ -0,0 +1,35 @@ +use binius_circuits::builder::ConstraintSystemBuilder; +use binius_core::{ + constraint_system::validate::validate_witness, transparent::select_row::SelectRow, +}; +use binius_field::{arch::OptimalUnderlier, BinaryField128b, BinaryField8b}; + +type U = OptimalUnderlier; +type F128 = BinaryField128b; +type F8 = BinaryField8b; + +const LOG_SIZE: usize = 8; + +// SelectRow expects exactly one witness value at particular index to be set. +fn main() { + let allocator = bumpalo::Bump::new(); + let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + + let index = 58; + assert!(index < 1 << LOG_SIZE); + + let select_row = SelectRow::new(LOG_SIZE, index).unwrap(); + let transparent = builder.add_transparent("select_row", select_row).unwrap(); + + if let Some(witness) = builder.witness() { + let mut transparent_witness = witness.new_column::(transparent); + let values = transparent_witness.as_mut_slice::(); + + values[index] = 0x01; + } + + let witness = builder.take_witness().unwrap(); + let constraints_system = builder.build().unwrap(); + + validate_witness(&constraints_system, &[], &witness).unwrap(); +} diff --git a/examples/acc-shift-ind-partial-eq.rs b/examples/acc-shift-ind-partial-eq.rs new file mode 100644 index 000000000..74fa2bae7 --- /dev/null +++ b/examples/acc-shift-ind-partial-eq.rs @@ -0,0 +1,93 @@ +use binius_circuits::builder::ConstraintSystemBuilder; +use binius_core::{ + constraint_system::validate::validate_witness, oracle::ShiftVariant, + transparent::shift_ind::ShiftIndPartialEval, +}; +use binius_field::{arch::OptimalUnderlier, util::eq, BinaryField128b, Field}; + +type U = OptimalUnderlier; +type F128 = BinaryField128b; + +// ShiftIndPartialEval is a more elaborated version of EqIndPartialEval. Same idea with challenges, but a bit more +// elaborated evaluation algorithm is used +fn main() { + let allocator = bumpalo::Bump::new(); + let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + + let block_size = 3; + let shift_offset = 4; + // Challenges have to be F128, but actual values in the witness could be of smaller field + let challenges = vec![ + F128::new(0xff00ff00ff00ff00ff00ff00ff00ff00), + F128::new(0x1f1f1f1f1f1f1f1f1f1f1f1f1f1f1f1f), + F128::new(0x2f2f2f2f2f2f2f2f2f2f2f2f2f2f2f2f), + ]; + let shift_variant = ShiftVariant::LogicalLeft; + + assert_eq!(block_size, challenges.len()); + + let shift_ind = + ShiftIndPartialEval::new(block_size, shift_offset, shift_variant, challenges.clone()) + .unwrap(); + + let transparent = builder.add_transparent("shift_ind", shift_ind).unwrap(); + + if let Some(witness) = builder.witness() { + let mut transparent_witness = witness.new_column::(transparent); + let values = transparent_witness.as_mut_slice::(); + + let lexicographical_order_x = [ + vec![F128::new(0), F128::new(0), F128::new(0)], + vec![F128::new(1), F128::new(0), F128::new(0)], + vec![F128::new(0), F128::new(1), F128::new(0)], + vec![F128::new(1), F128::new(1), F128::new(0)], + vec![F128::new(0), F128::new(0), F128::new(1)], + vec![F128::new(1), F128::new(0), F128::new(1)], + vec![F128::new(0), F128::new(1), F128::new(1)], + vec![F128::new(1), F128::new(1), F128::new(1)], + ]; + + assert_eq!(lexicographical_order_x.len(), 1 << block_size); + + for (val, x) in values.iter_mut().zip(lexicographical_order_x.into_iter()) { + *val = compute(block_size, shift_offset, shift_variant, x, challenges.clone()).val(); + } + } + + let witness = builder.take_witness().unwrap(); + let constraints_system = builder.build().unwrap(); + + validate_witness(&constraints_system, &[], &witness).unwrap(); +} + +// Evaluation logic taken from ShiftIndPartialEval implementation +fn compute( + block_size: usize, + shift_offset: usize, + shift_variant: ShiftVariant, + x: Vec, + y: Vec, +) -> F128 { + let (mut s_ind_p, mut s_ind_pp) = (F128::ONE, F128::ZERO); + let (mut temp_p, mut temp_pp) = (F128::default(), F128::default()); + (0..block_size).for_each(|k| { + let o_k = shift_offset >> k; + let product = x[k] * y[k]; + if o_k % 2 == 1 { + temp_p = (y[k] - product) * s_ind_p; + temp_pp = (x[k] - product) * s_ind_p + eq(x[k], y[k]) * s_ind_pp; + } else { + temp_p = eq(x[k], y[k]) * s_ind_p + (y[k] - product) * s_ind_pp; + temp_pp = (x[k] - product) * s_ind_pp; + } + // roll over results + s_ind_p = temp_p; + s_ind_pp = temp_pp; + }); + + match shift_variant { + ShiftVariant::CircularLeft => s_ind_p + s_ind_pp, + ShiftVariant::LogicalLeft => s_ind_p, + ShiftVariant::LogicalRight => s_ind_pp, + } +} diff --git a/examples/acc-step-down.rs b/examples/acc-step-down.rs new file mode 100644 index 000000000..cbe3cd6e6 --- /dev/null +++ b/examples/acc-step-down.rs @@ -0,0 +1,34 @@ +use binius_circuits::builder::ConstraintSystemBuilder; +use binius_core::{ + constraint_system::validate::validate_witness, transparent::step_down::StepDown, +}; +use binius_field::{arch::OptimalUnderlier, BinaryField128b, BinaryField8b}; + +const LOG_SIZE: usize = 8; + +type U = OptimalUnderlier; +type F128 = BinaryField128b; +type F8 = BinaryField8b; + +// StepDown expects all bytes to be set before particular index specified as input +fn main() { + let allocator = bumpalo::Bump::new(); + let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + + let index = 10; + + let step_down = StepDown::new(LOG_SIZE, index).unwrap(); + let transparent = builder.add_transparent("step_down", step_down).unwrap(); + + if let Some(witness) = builder.witness() { + let mut transparent_witness = witness.new_column::(transparent); + let values = transparent_witness.as_mut_slice::(); + + values[0..index].fill(0x01); + } + + let witness = builder.take_witness().unwrap(); + let constraints_system = builder.build().unwrap(); + + validate_witness(&constraints_system, &[], &witness).unwrap(); +} diff --git a/examples/acc-step-up.rs b/examples/acc-step-up.rs new file mode 100644 index 000000000..e659352e8 --- /dev/null +++ b/examples/acc-step-up.rs @@ -0,0 +1,32 @@ +use binius_circuits::builder::ConstraintSystemBuilder; +use binius_core::{constraint_system::validate::validate_witness, transparent::step_up::StepUp}; +use binius_field::{arch::OptimalUnderlier, BinaryField128b, BinaryField8b}; + +type U = OptimalUnderlier; +type F128 = BinaryField128b; +type F8 = BinaryField8b; + +const LOG_SIZE: usize = 8; + +// StepUp expects all bytes to be unset before particular index specified as input (opposite to StepDown) +fn main() { + let allocator = bumpalo::Bump::new(); + let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + + let index = 10; + + let step_up = StepUp::new(LOG_SIZE, index).unwrap(); + let transparent = builder.add_transparent("step_up", step_up).unwrap(); + + if let Some(witness) = builder.witness() { + let mut transparent_witness = witness.new_column::(transparent); + let values = transparent_witness.as_mut_slice::(); + + values[index..].fill(0x01); + } + + let witness = builder.take_witness().unwrap(); + let constraints_system = builder.build().unwrap(); + + validate_witness(&constraints_system, &[], &witness).unwrap(); +} diff --git a/examples/acc-tower-basis.rs b/examples/acc-tower-basis.rs new file mode 100644 index 000000000..e99e6790b --- /dev/null +++ b/examples/acc-tower-basis.rs @@ -0,0 +1,60 @@ +use binius_circuits::builder::ConstraintSystemBuilder; +use binius_core::{ + constraint_system::validate::validate_witness, transparent::tower_basis::TowerBasis, +}; +use binius_field::{arch::OptimalUnderlier, BinaryField128b, Field, TowerField}; + +type U = OptimalUnderlier; +type F128 = BinaryField128b; + +// TowerBasis expects actually basis vectors written to the witness. +// The form of basis could vary depending on 'iota' and 'k' parameters +fn main() { + let allocator = bumpalo::Bump::new(); + let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + + let k = 3usize; + let iota = 4usize; + + assert!(k + iota < 8); + + let tower_basis = TowerBasis::new(k, iota).unwrap(); + let transparent = builder.add_transparent("tower_basis", tower_basis).unwrap(); + + if let Some(witness) = builder.witness() { + let mut transparent_witness = witness.new_column::(transparent); + let values = transparent_witness.as_mut_slice::(); + + let lexicographic_query = [ + vec![F128::new(0), F128::new(0), F128::new(0)], + vec![F128::new(1), F128::new(0), F128::new(0)], + vec![F128::new(0), F128::new(1), F128::new(0)], + vec![F128::new(1), F128::new(1), F128::new(0)], + vec![F128::new(0), F128::new(0), F128::new(1)], + vec![F128::new(1), F128::new(0), F128::new(1)], + vec![F128::new(0), F128::new(1), F128::new(1)], + vec![F128::new(1), F128::new(1), F128::new(1)], + ]; + + assert_eq!(lexicographic_query.len(), 1 << k); + + for (val, query) in values.iter_mut().zip(lexicographic_query.into_iter()) { + *val = compute(iota, query).val(); + } + } + + let witness = builder.take_witness().unwrap(); + let constraints_system = builder.build().unwrap(); + + validate_witness(&constraints_system, &[], &witness).unwrap(); +} + +fn compute(iota: usize, query: Vec) -> F128 { + let mut result = F128::ONE; + for (i, query_i) in query.iter().enumerate() { + let r_comp = F128::ONE - query_i; + let basis_elt = ::basis(iota + i, 1).unwrap(); + result *= r_comp + *query_i * basis_elt; + } + result +} From dcf4cac5767245b36e5d144c9a1147672de4cd2e Mon Sep 17 00:00:00 2001 From: Artem Storozhuk Date: Mon, 24 Feb 2025 16:49:28 +0200 Subject: [PATCH 48/50] chore: Forward port --- examples/Cargo.toml | 6 ++--- examples/acc-constants.rs | 8 +++---- examples/acc-disjoint-product.rs | 8 ++----- examples/acc-eq-ind-partial-eval.rs | 5 ++--- ...inear-combination-with-offset.rs.disabled} | 0 examples/acc-linear-combination.rs | 13 +++++------ .../acc-multilinear-extension-transparent.rs | 6 ++--- examples/acc-packed.rs | 21 +++++++----------- examples/acc-powers.rs | 13 ++++------- examples/acc-projected.rs | 10 ++++----- examples/acc-repeated.rs | 15 +++++-------- examples/acc-select-row.rs | 6 ++--- examples/acc-shift-ind-partial-eq.rs | 5 ++--- examples/acc-shifted.rs | 22 +++++++++---------- examples/acc-step-down.rs | 6 ++--- examples/acc-step-up.rs | 6 ++--- examples/acc-tower-basis.rs | 5 ++--- examples/acc-zeropadded.rs | 8 +++---- 18 files changed, 63 insertions(+), 100 deletions(-) rename examples/{acc-linear-combination-with-offset.rs => acc-linear-combination-with-offset.rs.disabled} (100%) diff --git a/examples/Cargo.toml b/examples/Cargo.toml index 86c0af570..c43d9b7b6 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -77,9 +77,9 @@ path = "b32_mul.rs" name = "acc-linear-combination" path = "acc-linear-combination.rs" -[[example]] -name = "acc-linear-combination-with-offset" -path = "acc-linear-combination-with-offset.rs" +#[[example]] +#name = "acc-linear-combination-with-offset" +#path = "acc-linear-combination-with-offset.rs" [[example]] name = "acc-shifted" diff --git a/examples/acc-constants.rs b/examples/acc-constants.rs index 0947ed107..5c9075b69 100644 --- a/examples/acc-constants.rs +++ b/examples/acc-constants.rs @@ -3,10 +3,8 @@ use binius_core::{ constraint_system::validate::validate_witness, oracle::OracleId, transparent::constant::Constant, }; -use binius_field::{arch::OptimalUnderlier, BinaryField128b, BinaryField1b, BinaryField32b}; +use binius_field::{BinaryField1b, BinaryField32b}; -type U = OptimalUnderlier; -type F128 = BinaryField128b; type F32 = BinaryField32b; type F1 = BinaryField1b; @@ -17,7 +15,7 @@ const LOG_SIZE: usize = 4; fn constants_gadget( name: impl ToString, log_size: usize, - builder: &mut ConstraintSystemBuilder, + builder: &mut ConstraintSystemBuilder, constant_value: u32, ) -> OracleId { builder.push_namespace(name); @@ -45,7 +43,7 @@ fn constants_gadget( // Transparent column. fn main() { let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); pub const SHA256_INIT: [u32; 8] = [ 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, diff --git a/examples/acc-disjoint-product.rs b/examples/acc-disjoint-product.rs index 05f4b59d9..2f6ed5812 100644 --- a/examples/acc-disjoint-product.rs +++ b/examples/acc-disjoint-product.rs @@ -3,12 +3,8 @@ use binius_core::{ constraint_system::validate::validate_witness, transparent::{constant::Constant, disjoint_product::DisjointProduct, powers::Powers}, }; -use binius_field::{ - arch::OptimalUnderlier, BinaryField, BinaryField128b, BinaryField8b, PackedField, -}; +use binius_field::{BinaryField, BinaryField8b, PackedField}; -type U = OptimalUnderlier; -type F128 = BinaryField128b; type F8 = BinaryField8b; const LOG_SIZE: usize = 4; @@ -29,7 +25,7 @@ const LOG_SIZE: usize = 4; // of heights (n_vars) of Powers and Constant, so actual data could be repeated multiple times fn main() { let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let generator = F8::MULTIPLICATIVE_GENERATOR; let powers = Powers::new(LOG_SIZE, generator.into()); diff --git a/examples/acc-eq-ind-partial-eval.rs b/examples/acc-eq-ind-partial-eval.rs index 7fcdc4b54..00f7777a3 100644 --- a/examples/acc-eq-ind-partial-eval.rs +++ b/examples/acc-eq-ind-partial-eval.rs @@ -2,9 +2,8 @@ use binius_circuits::builder::ConstraintSystemBuilder; use binius_core::{ constraint_system::validate::validate_witness, transparent::eq_ind::EqIndPartialEval, }; -use binius_field::{arch::OptimalUnderlier, BinaryField128b, PackedField}; +use binius_field::{BinaryField128b, PackedField}; -type U = OptimalUnderlier; type F128 = BinaryField128b; const LOG_SIZE: usize = 3; @@ -21,7 +20,7 @@ const LOG_SIZE: usize = 3; // fn main() { let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); // A truth table [000, 001, 010, 011 ... 111] where each row is in reversed order let rev_basis = [ diff --git a/examples/acc-linear-combination-with-offset.rs b/examples/acc-linear-combination-with-offset.rs.disabled similarity index 100% rename from examples/acc-linear-combination-with-offset.rs rename to examples/acc-linear-combination-with-offset.rs.disabled diff --git a/examples/acc-linear-combination.rs b/examples/acc-linear-combination.rs index 5cb1eab54..b069b5196 100644 --- a/examples/acc-linear-combination.rs +++ b/examples/acc-linear-combination.rs @@ -1,20 +1,17 @@ use binius_circuits::{builder::ConstraintSystemBuilder, unconstrained::unconstrained}; use binius_core::{constraint_system::validate::validate_witness, oracle::OracleId}; use binius_field::{ - arch::OptimalUnderlier, packed::set_packed_slice, BinaryField128b, BinaryField1b, - BinaryField8b, ExtensionField, TowerField, + packed::set_packed_slice, BinaryField1b, BinaryField8b, ExtensionField, TowerField, }; use binius_macros::arith_expr; -type U = OptimalUnderlier; -type F128 = BinaryField128b; type F8 = BinaryField8b; type F1 = BinaryField1b; // FIXME: Following gadgets are unconstrained. Only for demonstrative purpose, don't use in production fn bytes_decomposition_gadget( - builder: &mut ConstraintSystemBuilder, + builder: &mut ConstraintSystemBuilder, name: impl ToString, log_size: usize, input: OracleId, @@ -146,7 +143,7 @@ fn bytes_decomposition_gadget( } fn elder_4bits_masking_gadget( - builder: &mut ConstraintSystemBuilder, + builder: &mut ConstraintSystemBuilder, name: impl ToString, log_size: usize, input: OracleId, @@ -241,12 +238,12 @@ fn elder_4bits_masking_gadget( fn main() { let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let log_size = 1usize; // Define set of bytes that we want to decompose - let p_in = unconstrained::(&mut builder, "p_in".to_string(), log_size).unwrap(); + let p_in = unconstrained::(&mut builder, "p_in".to_string(), log_size).unwrap(); let _ = bytes_decomposition_gadget(&mut builder, "bytes decomposition", log_size, p_in).unwrap(); diff --git a/examples/acc-multilinear-extension-transparent.rs b/examples/acc-multilinear-extension-transparent.rs index f9dd45706..033184edd 100644 --- a/examples/acc-multilinear-extension-transparent.rs +++ b/examples/acc-multilinear-extension-transparent.rs @@ -14,7 +14,7 @@ type F128 = BinaryField128b; type F1 = BinaryField1b; // From a perspective of circuits creation, MultilinearExtensionTransparent can be used naturally for decomposing integers to bits -fn decompose_transparent_u64(builder: &mut ConstraintSystemBuilder, x: u64) { +fn decompose_transparent_u64(builder: &mut ConstraintSystemBuilder, x: u64) { builder.push_namespace("decompose_transparent_u64"); let log_bits = log2_ceil_usize(64); @@ -42,7 +42,7 @@ fn decompose_transparent_u64(builder: &mut ConstraintSystemBuilder, x: builder.pop_namespace(); } -fn decompose_transparent_u32(builder: &mut ConstraintSystemBuilder, x: u32) { +fn decompose_transparent_u32(builder: &mut ConstraintSystemBuilder, x: u32) { builder.push_namespace("decompose_transparent_u32"); let log_bits = log2_ceil_usize(32); @@ -72,7 +72,7 @@ fn decompose_transparent_u32(builder: &mut ConstraintSystemBuilder, x: fn main() { let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); decompose_transparent_u64(&mut builder, 0xff00ff00ff00ff00); decompose_transparent_u32(&mut builder, 0x00ff00ff); diff --git a/examples/acc-packed.rs b/examples/acc-packed.rs index 6963d8780..434a9d11c 100644 --- a/examples/acc-packed.rs +++ b/examples/acc-packed.rs @@ -1,12 +1,7 @@ use binius_circuits::{builder::ConstraintSystemBuilder, unconstrained::unconstrained}; use binius_core::constraint_system::validate::validate_witness; -use binius_field::{ - arch::OptimalUnderlier, BinaryField128b, BinaryField16b, BinaryField1b, BinaryField32b, - BinaryField8b, TowerField, -}; +use binius_field::{BinaryField16b, BinaryField1b, BinaryField32b, BinaryField8b, TowerField}; -type U = OptimalUnderlier; -type F128 = BinaryField128b; type F32 = BinaryField32b; type F16 = BinaryField16b; type F8 = BinaryField8b; @@ -14,10 +9,10 @@ type F1 = BinaryField1b; // FIXME: Following gadgets are unconstrained. Only for demonstrative purpose, don't use in production -fn packing_32_bits_to_u32(builder: &mut ConstraintSystemBuilder) { +fn packing_32_bits_to_u32(builder: &mut ConstraintSystemBuilder) { builder.push_namespace("packing_32_bits_to_u32"); - let bits = unconstrained::(builder, "bits", F32::TOWER_LEVEL).unwrap(); + let bits = unconstrained::(builder, "bits", F32::TOWER_LEVEL).unwrap(); let packed = builder .add_packed("packed", bits, F32::TOWER_LEVEL) .unwrap(); @@ -49,10 +44,10 @@ fn packing_32_bits_to_u32(builder: &mut ConstraintSystemBuilder) { builder.pop_namespace(); } -fn packing_4_bytes_to_u32(builder: &mut ConstraintSystemBuilder) { +fn packing_4_bytes_to_u32(builder: &mut ConstraintSystemBuilder) { builder.push_namespace("packing_4_bytes_to_u32"); - let bytes = unconstrained::(builder, "bytes", F16::TOWER_LEVEL).unwrap(); + let bytes = unconstrained::(builder, "bytes", F16::TOWER_LEVEL).unwrap(); let packed = builder .add_packed("packed", bytes, F16::TOWER_LEVEL) .unwrap(); @@ -76,10 +71,10 @@ fn packing_4_bytes_to_u32(builder: &mut ConstraintSystemBuilder) { builder.pop_namespace(); } -fn packing_8_bits_to_u8(builder: &mut ConstraintSystemBuilder) { +fn packing_8_bits_to_u8(builder: &mut ConstraintSystemBuilder) { builder.push_namespace("packing_8_bits_to_u8"); - let bits = unconstrained::(builder, "bits", F8::TOWER_LEVEL).unwrap(); + let bits = unconstrained::(builder, "bits", F8::TOWER_LEVEL).unwrap(); let packed = builder.add_packed("packed", bits, F8::TOWER_LEVEL).unwrap(); if let Some(witness) = builder.witness() { @@ -94,7 +89,7 @@ fn packing_8_bits_to_u8(builder: &mut ConstraintSystemBuilder) { fn main() { let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); packing_32_bits_to_u32(&mut builder); packing_4_bytes_to_u32(&mut builder); diff --git a/examples/acc-powers.rs b/examples/acc-powers.rs index 159aa1136..c266191e8 100644 --- a/examples/acc-powers.rs +++ b/examples/acc-powers.rs @@ -1,12 +1,7 @@ use binius_circuits::builder::ConstraintSystemBuilder; use binius_core::constraint_system::validate::validate_witness; -use binius_field::{ - arch::OptimalUnderlier, BinaryField, BinaryField128b, BinaryField16b, BinaryField32b, - PackedField, -}; +use binius_field::{BinaryField, BinaryField16b, BinaryField32b, PackedField}; -type U = OptimalUnderlier; -type F128 = BinaryField128b; type F32 = BinaryField32b; type F16 = BinaryField16b; @@ -22,7 +17,7 @@ const LOG_SIZE: usize = 3; // where 'x' is a multiplicative generator - a public value that exists for every BinaryField // -fn powers_gadget_f32(builder: &mut ConstraintSystemBuilder, name: impl ToString) { +fn powers_gadget_f32(builder: &mut ConstraintSystemBuilder, name: impl ToString) { builder.push_namespace(name); let generator = F32::MULTIPLICATIVE_GENERATOR; @@ -43,7 +38,7 @@ fn powers_gadget_f32(builder: &mut ConstraintSystemBuilder, name: impl } // Only Field is being changed -fn powers_gadget_f16(builder: &mut ConstraintSystemBuilder, name: impl ToString) { +fn powers_gadget_f16(builder: &mut ConstraintSystemBuilder, name: impl ToString) { builder.push_namespace(name); let generator = F16::MULTIPLICATIVE_GENERATOR; @@ -65,7 +60,7 @@ fn powers_gadget_f16(builder: &mut ConstraintSystemBuilder, name: impl fn main() { let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); powers_gadget_f16(&mut builder, "f16"); powers_gadget_f32(&mut builder, "f32"); diff --git a/examples/acc-projected.rs b/examples/acc-projected.rs index 8d4de121d..c7ac8150d 100644 --- a/examples/acc-projected.rs +++ b/examples/acc-projected.rs @@ -1,8 +1,7 @@ use binius_circuits::{builder::ConstraintSystemBuilder, unconstrained::unconstrained}; use binius_core::{constraint_system::validate::validate_witness, oracle::ProjectionVariant}; -use binius_field::{arch::OptimalUnderlier, BinaryField128b, BinaryField8b}; +use binius_field::{BinaryField128b, BinaryField8b}; -type U = OptimalUnderlier; type F128 = BinaryField128b; type F8 = BinaryField8b; @@ -20,14 +19,13 @@ struct U8U128ProjectionInfo { // has significant impact on input data processing. // In the following example we have input column with bytes (u8) projected to the output column with u128 values. fn projection( - builder: &mut ConstraintSystemBuilder, + builder: &mut ConstraintSystemBuilder, projection_info: U8U128ProjectionInfo, namespace: &str, ) { builder.push_namespace(format!("projection {}", namespace)); - let input = - unconstrained::(builder, "in", projection_info.clone().log_size).unwrap(); + let input = unconstrained::(builder, "in", projection_info.clone().log_size).unwrap(); let projected = builder .add_projected( @@ -112,7 +110,7 @@ impl U8U128ProjectionInfo { fn main() { let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let projection_data = U8U128ProjectionInfo::new( 4usize, diff --git a/examples/acc-repeated.rs b/examples/acc-repeated.rs index 749bb6fee..ef4bf6431 100644 --- a/examples/acc-repeated.rs +++ b/examples/acc-repeated.rs @@ -1,12 +1,9 @@ use binius_circuits::{builder::ConstraintSystemBuilder, unconstrained::unconstrained}; use binius_core::constraint_system::validate::validate_witness; use binius_field::{ - arch::OptimalUnderlier, packed::set_packed_slice, BinaryField128b, BinaryField1b, - BinaryField8b, PackedBinaryField128x1b, + packed::set_packed_slice, BinaryField1b, BinaryField8b, PackedBinaryField128x1b, }; -type U = OptimalUnderlier; -type F128 = BinaryField128b; type F8 = BinaryField8b; type F1 = BinaryField1b; @@ -18,10 +15,10 @@ const LOG_SIZE: usize = 8; // so new column is X times bigger than original one. The following gadget operates over bytes, e.g. // it creates column with some input bytes written and then creates one more 'Repeated' column // where the same bytes are copied multiple times. -fn bytes_repeat_gadget(builder: &mut ConstraintSystemBuilder) { +fn bytes_repeat_gadget(builder: &mut ConstraintSystemBuilder) { builder.push_namespace("bytes_repeat_gadget"); - let bytes = unconstrained::(builder, "input", LOG_SIZE).unwrap(); + let bytes = unconstrained::(builder, "input", LOG_SIZE).unwrap(); let repeat_times_log = 4usize; let repeating = builder @@ -57,10 +54,10 @@ fn bytes_repeat_gadget(builder: &mut ConstraintSystemBuilder) { // repetitions Binius creates column with 8 PackedBinaryField128x1b elements totally. // Proper writing bits requires separate iterating over PackedBinaryField128x1b elements and input bytes // with extracting particular bit values from the input and setting appropriate bit in a given PackedBinaryField128x1b. -fn bits_repeat_gadget(builder: &mut ConstraintSystemBuilder) { +fn bits_repeat_gadget(builder: &mut ConstraintSystemBuilder) { builder.push_namespace("bits_repeat_gadget"); - let bits = unconstrained::(builder, "input", LOG_SIZE).unwrap(); + let bits = unconstrained::(builder, "input", LOG_SIZE).unwrap(); let repeat_times_log = 2usize; // Binius will create column with appropriate height for us @@ -112,7 +109,7 @@ fn bits_repeat_gadget(builder: &mut ConstraintSystemBuilder) { fn main() { let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); bytes_repeat_gadget(&mut builder); bits_repeat_gadget(&mut builder); diff --git a/examples/acc-select-row.rs b/examples/acc-select-row.rs index 4337540bd..7f45bee02 100644 --- a/examples/acc-select-row.rs +++ b/examples/acc-select-row.rs @@ -2,10 +2,8 @@ use binius_circuits::builder::ConstraintSystemBuilder; use binius_core::{ constraint_system::validate::validate_witness, transparent::select_row::SelectRow, }; -use binius_field::{arch::OptimalUnderlier, BinaryField128b, BinaryField8b}; +use binius_field::BinaryField8b; -type U = OptimalUnderlier; -type F128 = BinaryField128b; type F8 = BinaryField8b; const LOG_SIZE: usize = 8; @@ -13,7 +11,7 @@ const LOG_SIZE: usize = 8; // SelectRow expects exactly one witness value at particular index to be set. fn main() { let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let index = 58; assert!(index < 1 << LOG_SIZE); diff --git a/examples/acc-shift-ind-partial-eq.rs b/examples/acc-shift-ind-partial-eq.rs index 74fa2bae7..385a07b85 100644 --- a/examples/acc-shift-ind-partial-eq.rs +++ b/examples/acc-shift-ind-partial-eq.rs @@ -3,16 +3,15 @@ use binius_core::{ constraint_system::validate::validate_witness, oracle::ShiftVariant, transparent::shift_ind::ShiftIndPartialEval, }; -use binius_field::{arch::OptimalUnderlier, util::eq, BinaryField128b, Field}; +use binius_field::{util::eq, BinaryField128b, Field}; -type U = OptimalUnderlier; type F128 = BinaryField128b; // ShiftIndPartialEval is a more elaborated version of EqIndPartialEval. Same idea with challenges, but a bit more // elaborated evaluation algorithm is used fn main() { let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let block_size = 3; let shift_offset = 4; diff --git a/examples/acc-shifted.rs b/examples/acc-shifted.rs index 924fb8234..a81a6610d 100644 --- a/examples/acc-shifted.rs +++ b/examples/acc-shifted.rs @@ -1,21 +1,19 @@ use binius_circuits::{builder::ConstraintSystemBuilder, unconstrained::unconstrained}; use binius_core::{constraint_system::validate::validate_witness, oracle::ShiftVariant}; -use binius_field::{arch::OptimalUnderlier, BinaryField128b, BinaryField1b}; +use binius_field::BinaryField1b; -type U = OptimalUnderlier; -type F128 = BinaryField128b; type F1 = BinaryField1b; // FIXME: Following gadgets are unconstrained. Only for demonstrative purpose, don't use in production -fn shift_right_gadget_u32(builder: &mut ConstraintSystemBuilder) { +fn shift_right_gadget_u32(builder: &mut ConstraintSystemBuilder) { builder.push_namespace("u32_right_shift"); // defined empirically and it is the same as 'block_bits' defined below let log_size = 5usize; // create column and write arbitrary bytes to it - let input = unconstrained::(builder, "input", log_size).unwrap(); + let input = unconstrained::(builder, "input", log_size).unwrap(); // we want to shift our u32 variable on 1 bit let shift_offset = 1; @@ -43,11 +41,11 @@ fn shift_right_gadget_u32(builder: &mut ConstraintSystemBuilder) { builder.pop_namespace(); } -fn shift_left_gadget_u8(builder: &mut ConstraintSystemBuilder) { +fn shift_left_gadget_u8(builder: &mut ConstraintSystemBuilder) { builder.push_namespace("u8_left_shift"); let log_size = 3usize; - let input = unconstrained::(builder, "input", log_size).unwrap(); + let input = unconstrained::(builder, "input", log_size).unwrap(); let shift_offset = 4; let shift_type = ShiftVariant::LogicalLeft; let block_bits = 3; @@ -67,11 +65,11 @@ fn shift_left_gadget_u8(builder: &mut ConstraintSystemBuilder) { builder.pop_namespace(); } -fn rotate_left_gadget_u16(builder: &mut ConstraintSystemBuilder) { +fn rotate_left_gadget_u16(builder: &mut ConstraintSystemBuilder) { builder.push_namespace("u16_rotate_right"); let log_size = 4usize; - let input = unconstrained::(builder, "input", log_size).unwrap(); + let input = unconstrained::(builder, "input", log_size).unwrap(); let rotation_offset = 5; let rotation_type = ShiftVariant::CircularLeft; let block_bits = 4usize; @@ -92,11 +90,11 @@ fn rotate_left_gadget_u16(builder: &mut ConstraintSystemBuilder) { builder.pop_namespace(); } -fn rotate_right_gadget_u64(builder: &mut ConstraintSystemBuilder) { +fn rotate_right_gadget_u64(builder: &mut ConstraintSystemBuilder) { builder.push_namespace("u64_rotate_right"); let log_size = 6usize; - let input = unconstrained::(builder, "input", log_size).unwrap(); + let input = unconstrained::(builder, "input", log_size).unwrap(); // Right rotation to X bits is achieved using 'ShiftVariant::CircularLeft' with the offset, // computed as size in bits of the variable type - X (e.g. if we want to right-rotate u64 to 8 bits, @@ -124,7 +122,7 @@ fn rotate_right_gadget_u64(builder: &mut ConstraintSystemBuilder) { fn main() { let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); shift_right_gadget_u32(&mut builder); shift_left_gadget_u8(&mut builder); diff --git a/examples/acc-step-down.rs b/examples/acc-step-down.rs index cbe3cd6e6..a2ecf3ba2 100644 --- a/examples/acc-step-down.rs +++ b/examples/acc-step-down.rs @@ -2,18 +2,16 @@ use binius_circuits::builder::ConstraintSystemBuilder; use binius_core::{ constraint_system::validate::validate_witness, transparent::step_down::StepDown, }; -use binius_field::{arch::OptimalUnderlier, BinaryField128b, BinaryField8b}; +use binius_field::BinaryField8b; const LOG_SIZE: usize = 8; -type U = OptimalUnderlier; -type F128 = BinaryField128b; type F8 = BinaryField8b; // StepDown expects all bytes to be set before particular index specified as input fn main() { let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let index = 10; diff --git a/examples/acc-step-up.rs b/examples/acc-step-up.rs index e659352e8..0d65820e5 100644 --- a/examples/acc-step-up.rs +++ b/examples/acc-step-up.rs @@ -1,9 +1,7 @@ use binius_circuits::builder::ConstraintSystemBuilder; use binius_core::{constraint_system::validate::validate_witness, transparent::step_up::StepUp}; -use binius_field::{arch::OptimalUnderlier, BinaryField128b, BinaryField8b}; +use binius_field::BinaryField8b; -type U = OptimalUnderlier; -type F128 = BinaryField128b; type F8 = BinaryField8b; const LOG_SIZE: usize = 8; @@ -11,7 +9,7 @@ const LOG_SIZE: usize = 8; // StepUp expects all bytes to be unset before particular index specified as input (opposite to StepDown) fn main() { let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let index = 10; diff --git a/examples/acc-tower-basis.rs b/examples/acc-tower-basis.rs index e99e6790b..fcc9a3837 100644 --- a/examples/acc-tower-basis.rs +++ b/examples/acc-tower-basis.rs @@ -2,16 +2,15 @@ use binius_circuits::builder::ConstraintSystemBuilder; use binius_core::{ constraint_system::validate::validate_witness, transparent::tower_basis::TowerBasis, }; -use binius_field::{arch::OptimalUnderlier, BinaryField128b, Field, TowerField}; +use binius_field::{BinaryField128b, Field, TowerField}; -type U = OptimalUnderlier; type F128 = BinaryField128b; // TowerBasis expects actually basis vectors written to the witness. // The form of basis could vary depending on 'iota' and 'k' parameters fn main() { let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); let k = 3usize; let iota = 4usize; diff --git a/examples/acc-zeropadded.rs b/examples/acc-zeropadded.rs index 2cbb954ca..b306bd63a 100644 --- a/examples/acc-zeropadded.rs +++ b/examples/acc-zeropadded.rs @@ -1,9 +1,7 @@ use binius_circuits::{builder::ConstraintSystemBuilder, unconstrained::unconstrained}; use binius_core::constraint_system::validate::validate_witness; -use binius_field::{arch::OptimalUnderlier, BinaryField128b, BinaryField8b}; +use binius_field::BinaryField8b; -type U = OptimalUnderlier; -type F128 = BinaryField128b; type F8 = BinaryField8b; const LOG_SIZE: usize = 4; @@ -11,9 +9,9 @@ const LOG_SIZE: usize = 4; fn main() { let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); - let bytes = unconstrained::(&mut builder, "bytes", LOG_SIZE).unwrap(); + let bytes = unconstrained::(&mut builder, "bytes", LOG_SIZE).unwrap(); // Height of ZeroPadded column can't be smaller than input one. // If n_vars equals to LOG_SIZE, then no padding is required, From 8b342f7c42204b02a76840122589e07eaf6493fd Mon Sep 17 00:00:00 2001 From: Artem Storozhuk Date: Mon, 24 Feb 2025 16:59:27 +0200 Subject: [PATCH 49/50] feat: Blake3 permutation using channels API --- examples/Cargo.toml | 4 ++ examples/acc-permutation-channels.rs | 97 ++++++++++++++++++++++++++++ 2 files changed, 101 insertions(+) create mode 100644 examples/acc-permutation-channels.rs diff --git a/examples/Cargo.toml b/examples/Cargo.toml index c43d9b7b6..0de51d976 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -141,6 +141,10 @@ path = "acc-step-up.rs" name = "acc-tower-basis" path = "acc-tower-basis.rs" +[[example]] +name = "acc-permutation-channels" +path = "acc-permutation-channels.rs" + [lints.clippy] needless_range_loop = "allow" diff --git a/examples/acc-permutation-channels.rs b/examples/acc-permutation-channels.rs new file mode 100644 index 000000000..11fe04074 --- /dev/null +++ b/examples/acc-permutation-channels.rs @@ -0,0 +1,97 @@ +use bumpalo::Bump; +use binius_circuits::builder::ConstraintSystemBuilder; +use binius_circuits::unconstrained::fixed_u32; +use binius_core::constraint_system::channel::{Boundary, FlushDirection}; +use binius_core::constraint_system::validate::validate_witness; +use binius_field::{BinaryField128b, BinaryField32b}; + +type F128 = BinaryField128b; +type F32 = BinaryField32b; + +const MSG_PERMUTATION: [usize; 16] = [2, 6, 3, 10, 7, 0, 4, 13, 1, 11, 12, 5, 9, 14, 15, 8]; + + +// Permutation is a classic construction in a traditional cryptography. It has well-defined security properties +// and high performance due to implementation via lookups. One can possible to implement gadget for permutations using +// channels API from Binius. The following examples shows how to enforce Blake3 permutation - verifier pulls pairs of +// input/output of the permutation (encoded as a BinaryField128b elements, to reduce number of flushes), +// while prover is expected to push similar IO to make channel balanced. +fn permute(m: &mut [u32; 16]) { + let mut permuted = [0; 16]; + for i in 0..16 { + permuted[i] = m[MSG_PERMUTATION[i]]; + } + *m = permuted; +} + +fn main() { + let log_size = 4usize; + + let allocator = Bump::new(); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); + + let m = [0xfffffff0, 0xfffffff1, 0xfffffff2, 0xfffffff3, 0xfffffff4, 0xfffffff5, 0xfffffff6, 0xfffffff7, 0xfffffff8, 0xfffffff9, 0xfffffffa, 0xfffffffb, 0xfffffffc, 0xfffffffd, 0xfffffffe, 0xffffffff]; + + let mut m_clone = m.clone(); + permute(&mut m_clone); + + let expected = [0xfffffff2, 0xfffffff6, 0xfffffff3, 0xfffffffa, 0xfffffff7, 0xfffffff0, 0xfffffff4, 0xfffffffd, 0xfffffff1, 0xfffffffb, 0xfffffffc, 0xfffffff5, 0xfffffff9, 0xfffffffe, 0xffffffff, 0xfffffff8]; + assert_eq!(m_clone, expected); + + + let u32_in = fixed_u32::(&mut builder, "in", log_size, m.to_vec()).unwrap(); + let u32_out = fixed_u32::(&mut builder, "out", log_size, expected.to_vec()).unwrap(); + + // we pack 4-u32 (F32) tuples of permutation IO into F128 columns and use them for flushing + let u128_in = builder.add_packed("in_packed", u32_in, 2).unwrap(); + let u128_out = builder.add_packed("out_packed", u32_out, 2).unwrap(); + + // populate memory layout (witness) + if let Some(witness) = builder.witness() { + let in_f32 = witness.get::(u32_in).unwrap(); + let out_f32 = witness.get::(u32_out).unwrap(); + witness.new_column::(u128_in); + witness.new_column::(u128_out); + + witness.set(u128_in, in_f32.repacked::()).unwrap(); + witness.set(u128_out, out_f32.repacked::()).unwrap(); + } + + let channel = builder.add_channel(); + // count defines how many values ( 0 .. count ) from a given columns to send (pushing to a channel) + builder.send(channel, 4, [u128_in, u128_out]).unwrap(); + + let witness = builder.take_witness().unwrap(); + let cs = builder.build().unwrap(); + + // consider our 4-u32 values from a given tupple as 4 limbs of u128 + let f = |limb0: u32, limb1: u32, limb2: u32, limb3: u32| { + let mut x = 0u128; + + x ^= (limb3 as u128) << 96; + x ^= (limb2 as u128) << 64; + x ^= (limb1 as u128) << 32; + x ^= limb0 as u128; + + F128::new(x) + }; + + // Boundaries define actual data (encoded in a set of Flushes) that verifier can push or pull from a given channel + // in order to check if prover is able to balance that channel + let mut offset = 0usize; + let boundaries = (0..4).into_iter().map(|_| { + let boundary = Boundary { + values: vec![ + f(m[offset], m[offset + 1], m[offset + 2], m[offset + 3]), + f(expected[offset], expected[offset + 1], expected[offset + 2], expected[offset + 3]) + ], + channel_id: channel, + direction: FlushDirection::Pull, + multiplicity: 1 + }; + offset += 4; + boundary + }).collect::>>(); + + validate_witness(&cs, &boundaries, &witness).unwrap(); +} From 8eb44b73b7f020b58b113f1dd155f04b2f415273 Mon Sep 17 00:00:00 2001 From: Artem Storozhuk Date: Mon, 24 Feb 2025 17:57:44 +0200 Subject: [PATCH 50/50] chore: Formatting --- examples/acc-constants.rs | 2 +- examples/acc-eq-ind-partial-eval.rs | 2 +- .../acc-linear-combination-with-offset.rs | 129 ------------- examples/acc-permutation-channels.rs | 174 ++++++++++-------- 4 files changed, 96 insertions(+), 211 deletions(-) delete mode 100644 examples/acc-linear-combination-with-offset.rs diff --git a/examples/acc-constants.rs b/examples/acc-constants.rs index a180d129f..9536f172d 100644 --- a/examples/acc-constants.rs +++ b/examples/acc-constants.rs @@ -43,7 +43,7 @@ fn constants_gadget( fn main() { let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); pub const SHA256_INIT: [u32; 8] = [ 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, diff --git a/examples/acc-eq-ind-partial-eval.rs b/examples/acc-eq-ind-partial-eval.rs index 78021e63f..6d76fb45f 100644 --- a/examples/acc-eq-ind-partial-eval.rs +++ b/examples/acc-eq-ind-partial-eval.rs @@ -20,7 +20,7 @@ const LOG_SIZE: usize = 3; // fn main() { let allocator = bumpalo::Bump::new(); - + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); // A truth table [000, 001, 010, 011 ... 111] where each row is in reversed order diff --git a/examples/acc-linear-combination-with-offset.rs b/examples/acc-linear-combination-with-offset.rs deleted file mode 100644 index 619eeaec4..000000000 --- a/examples/acc-linear-combination-with-offset.rs +++ /dev/null @@ -1,129 +0,0 @@ -use binius_circuits::{builder::ConstraintSystemBuilder, unconstrained::unconstrained}; -use binius_core::{constraint_system::validate::validate_witness, oracle::OracleId}; -use binius_field::{ - arch::OptimalUnderlier, packed::set_packed_slice, AESTowerField128b, AESTowerField8b, - BinaryField1b, ExtensionField, PackedField, TowerField, -}; - -type U = OptimalUnderlier; -type F128 = AESTowerField128b; -type F8 = AESTowerField8b; -type F1 = BinaryField1b; - -fn aes_s_box(x: F8) -> F8 { - #[rustfmt::skip] - const S_BOX: [u8; 256] = [ - 0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, - 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76, - 0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, - 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0, - 0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, - 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15, - 0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, - 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75, - 0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, - 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84, - 0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, - 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf, - 0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, - 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8, - 0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, - 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2, - 0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, - 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73, - 0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, - 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb, - 0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, - 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79, - 0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, - 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08, - 0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, - 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a, - 0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, - 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e, - 0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, - 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf, - 0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, - 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16, - ]; - let idx = u8::from(x) as usize; - F8::from(S_BOX[idx]) -} - -// AES s-box is equivalent to the affine transformation defined as follows: -// -// s[i] = b[i] + -// b[(i+4) mod 8] + -// b[(i+5) mod 8] + -// b[(i+6) mod 8] + -// b[(i+7) mod 8] + -// c[i] -// -// where 'b' is input byte, 's' is output byte, 'c' is constant which is equal to 0x63 (0b01100011) and 'i' is a bit position. -// The '+' operation is defined over Rijndael finite field : GF(2^8) = GF(2) [x] / (x^8 + x^4 + x^3 + x + 1). -// -const C: F8 = F8::new(0x63); -const AES_AFFINE_TRANSFORMATION: [F8; 8] = [ - F8::new(0b00011111), - F8::new(0b00111110), - F8::new(0b01111100), - F8::new(0b11111000), - F8::new(0b11110001), - F8::new(0b11100011), - F8::new(0b11000111), - F8::new(0b10001111), -]; - -// FIXME: Following gadget is unconstrained. Only for demonstrative purpose, don't use in production - -fn main() { - let allocator = bumpalo::Bump::new(); - let mut builder = ConstraintSystemBuilder::::new_with_witness(&allocator); - - let log_size = 1usize; - let byte_in = unconstrained::(&mut builder, "input_byte", log_size).unwrap(); - - let bits: [OracleId; 8] = - builder.add_committed_multiple("decomposition", log_size, F1::TOWER_LEVEL); - - let byte_out = builder - .add_linear_combination_with_offset( - "lc", - log_size, - C.into(), - (0..8).map(|i| (bits[i], AES_AFFINE_TRANSFORMATION[i].into())), - ) - .unwrap(); - - if let Some(witness) = builder.witness() { - // get initial values of input bytes - let byte_in_values = witness.get::(byte_in).unwrap().as_slice::(); - - // create column for expected values of the output bytes - let mut byte_out_witness = witness.new_column::(byte_out); - let byte_out_values = byte_out_witness.as_mut_slice::(); - - // For each (inverted!) input byte, write correspondent bits to the decomposition - let mut bits_witness = bits.map(|bit| witness.new_column::(bit)); - let packed_bits = bits_witness.each_mut().map(|bit| bit.packed()); - - for byte_position in 0..byte_in_values.len() { - // write expected byte value to the output after applying s_box - byte_out_values[byte_position] = aes_s_box(byte_in_values[byte_position]); - - // invert input byte and write it to a decomposition bits - let input_inverted = byte_in_values[byte_position].invert_or_zero(); - - let bases = ExtensionField::::iter_bases(&input_inverted); - - for (bit_position, bit) in bases.clone().enumerate() { - set_packed_slice(packed_bits[bit_position], byte_position, bit); - } - } - } - - let witness = builder.take_witness().unwrap(); - let cs = builder.build().unwrap(); - - validate_witness(&cs, &[], &witness).unwrap(); -} diff --git a/examples/acc-permutation-channels.rs b/examples/acc-permutation-channels.rs index 11fe04074..a628bf350 100644 --- a/examples/acc-permutation-channels.rs +++ b/examples/acc-permutation-channels.rs @@ -1,97 +1,111 @@ -use bumpalo::Bump; -use binius_circuits::builder::ConstraintSystemBuilder; -use binius_circuits::unconstrained::fixed_u32; -use binius_core::constraint_system::channel::{Boundary, FlushDirection}; -use binius_core::constraint_system::validate::validate_witness; +use binius_circuits::{builder::ConstraintSystemBuilder, unconstrained::fixed_u32}; +use binius_core::constraint_system::{ + channel::{Boundary, FlushDirection}, + validate::validate_witness, +}; use binius_field::{BinaryField128b, BinaryField32b}; +use bumpalo::Bump; type F128 = BinaryField128b; type F32 = BinaryField32b; const MSG_PERMUTATION: [usize; 16] = [2, 6, 3, 10, 7, 0, 4, 13, 1, 11, 12, 5, 9, 14, 15, 8]; - // Permutation is a classic construction in a traditional cryptography. It has well-defined security properties // and high performance due to implementation via lookups. One can possible to implement gadget for permutations using // channels API from Binius. The following examples shows how to enforce Blake3 permutation - verifier pulls pairs of // input/output of the permutation (encoded as a BinaryField128b elements, to reduce number of flushes), // while prover is expected to push similar IO to make channel balanced. fn permute(m: &mut [u32; 16]) { - let mut permuted = [0; 16]; - for i in 0..16 { - permuted[i] = m[MSG_PERMUTATION[i]]; - } - *m = permuted; + let mut permuted = [0; 16]; + for i in 0..16 { + permuted[i] = m[MSG_PERMUTATION[i]]; + } + *m = permuted; } fn main() { - let log_size = 4usize; - - let allocator = Bump::new(); - let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); - - let m = [0xfffffff0, 0xfffffff1, 0xfffffff2, 0xfffffff3, 0xfffffff4, 0xfffffff5, 0xfffffff6, 0xfffffff7, 0xfffffff8, 0xfffffff9, 0xfffffffa, 0xfffffffb, 0xfffffffc, 0xfffffffd, 0xfffffffe, 0xffffffff]; - - let mut m_clone = m.clone(); - permute(&mut m_clone); - - let expected = [0xfffffff2, 0xfffffff6, 0xfffffff3, 0xfffffffa, 0xfffffff7, 0xfffffff0, 0xfffffff4, 0xfffffffd, 0xfffffff1, 0xfffffffb, 0xfffffffc, 0xfffffff5, 0xfffffff9, 0xfffffffe, 0xffffffff, 0xfffffff8]; - assert_eq!(m_clone, expected); - - - let u32_in = fixed_u32::(&mut builder, "in", log_size, m.to_vec()).unwrap(); - let u32_out = fixed_u32::(&mut builder, "out", log_size, expected.to_vec()).unwrap(); - - // we pack 4-u32 (F32) tuples of permutation IO into F128 columns and use them for flushing - let u128_in = builder.add_packed("in_packed", u32_in, 2).unwrap(); - let u128_out = builder.add_packed("out_packed", u32_out, 2).unwrap(); - - // populate memory layout (witness) - if let Some(witness) = builder.witness() { - let in_f32 = witness.get::(u32_in).unwrap(); - let out_f32 = witness.get::(u32_out).unwrap(); - witness.new_column::(u128_in); - witness.new_column::(u128_out); - - witness.set(u128_in, in_f32.repacked::()).unwrap(); - witness.set(u128_out, out_f32.repacked::()).unwrap(); - } - - let channel = builder.add_channel(); - // count defines how many values ( 0 .. count ) from a given columns to send (pushing to a channel) - builder.send(channel, 4, [u128_in, u128_out]).unwrap(); - - let witness = builder.take_witness().unwrap(); - let cs = builder.build().unwrap(); - - // consider our 4-u32 values from a given tupple as 4 limbs of u128 - let f = |limb0: u32, limb1: u32, limb2: u32, limb3: u32| { - let mut x = 0u128; - - x ^= (limb3 as u128) << 96; - x ^= (limb2 as u128) << 64; - x ^= (limb1 as u128) << 32; - x ^= limb0 as u128; - - F128::new(x) - }; - - // Boundaries define actual data (encoded in a set of Flushes) that verifier can push or pull from a given channel - // in order to check if prover is able to balance that channel - let mut offset = 0usize; - let boundaries = (0..4).into_iter().map(|_| { - let boundary = Boundary { - values: vec![ - f(m[offset], m[offset + 1], m[offset + 2], m[offset + 3]), - f(expected[offset], expected[offset + 1], expected[offset + 2], expected[offset + 3]) - ], - channel_id: channel, - direction: FlushDirection::Pull, - multiplicity: 1 - }; - offset += 4; - boundary - }).collect::>>(); - - validate_witness(&cs, &boundaries, &witness).unwrap(); + let log_size = 4usize; + + let allocator = Bump::new(); + let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator); + + let m = [ + 0xfffffff0, 0xfffffff1, 0xfffffff2, 0xfffffff3, 0xfffffff4, 0xfffffff5, 0xfffffff6, + 0xfffffff7, 0xfffffff8, 0xfffffff9, 0xfffffffa, 0xfffffffb, 0xfffffffc, 0xfffffffd, + 0xfffffffe, 0xffffffff, + ]; + + let mut m_clone = m; + permute(&mut m_clone); + + let expected = [ + 0xfffffff2, 0xfffffff6, 0xfffffff3, 0xfffffffa, 0xfffffff7, 0xfffffff0, 0xfffffff4, + 0xfffffffd, 0xfffffff1, 0xfffffffb, 0xfffffffc, 0xfffffff5, 0xfffffff9, 0xfffffffe, + 0xffffffff, 0xfffffff8, + ]; + assert_eq!(m_clone, expected); + + let u32_in = fixed_u32::(&mut builder, "in", log_size, m.to_vec()).unwrap(); + let u32_out = fixed_u32::(&mut builder, "out", log_size, expected.to_vec()).unwrap(); + + // we pack 4-u32 (F32) tuples of permutation IO into F128 columns and use them for flushing + let u128_in = builder.add_packed("in_packed", u32_in, 2).unwrap(); + let u128_out = builder.add_packed("out_packed", u32_out, 2).unwrap(); + + // populate memory layout (witness) + if let Some(witness) = builder.witness() { + let in_f32 = witness.get::(u32_in).unwrap(); + let out_f32 = witness.get::(u32_out).unwrap(); + witness.new_column::(u128_in); + witness.new_column::(u128_out); + + witness.set(u128_in, in_f32.repacked::()).unwrap(); + witness.set(u128_out, out_f32.repacked::()).unwrap(); + } + + let channel = builder.add_channel(); + // count defines how many values ( 0 .. count ) from a given columns to send (pushing to a channel) + builder.send(channel, 4, [u128_in, u128_out]).unwrap(); + + let witness = builder.take_witness().unwrap(); + let cs = builder.build().unwrap(); + + // consider our 4-u32 values from a given tupple as 4 limbs of u128 + let f = |limb0: u32, limb1: u32, limb2: u32, limb3: u32| { + let mut x = 0u128; + + x ^= (limb3 as u128) << 96; + x ^= (limb2 as u128) << 64; + x ^= (limb1 as u128) << 32; + x ^= limb0 as u128; + + F128::new(x) + }; + + // Boundaries define actual data (encoded in a set of Flushes) that verifier can push or pull from a given channel + // in order to check if prover is able to balance that channel + let mut offset = 0usize; + let boundaries = (0..4) + .map(|_| { + let boundary = Boundary { + values: vec![ + f(m[offset], m[offset + 1], m[offset + 2], m[offset + 3]), + f( + expected[offset], + expected[offset + 1], + expected[offset + 2], + expected[offset + 3], + ), + ], + channel_id: channel, + direction: FlushDirection::Pull, + multiplicity: 1, + }; + offset += 4; + boundary + }) + .collect::>>(); + + validate_witness(&cs, &boundaries, &witness).unwrap(); }