diff --git a/crates/core/src/protocols/sumcheck/prove/batch_zerocheck.rs b/crates/core/src/protocols/sumcheck/prove/batch_zerocheck.rs index 9287cf596..b3de29c8b 100644 --- a/crates/core/src/protocols/sumcheck/prove/batch_zerocheck.rs +++ b/crates/core/src/protocols/sumcheck/prove/batch_zerocheck.rs @@ -1,6 +1,6 @@ // Copyright 2024-2025 Irreducible Inc. -use std::sync::Arc; +use std::{sync::Arc, time::Instant}; use binius_field::{ExtensionField, PackedExtension, PackedField, TowerField}; use binius_hal::{CpuBackend, make_portable_backend}; @@ -175,6 +175,7 @@ where Prover: ZerocheckProver<'a, P>, Challenger_: Challenger, { + // Check that the provers are in non-descending order by n_vars if !is_sorted_ascending(provers.iter().map(|prover| prover.n_vars())) { bail!(Error::ClaimsOutOfOrder); @@ -190,6 +191,7 @@ where .max() .unwrap_or(0); + // Sample batching coefficients while computing round polynomials per claim, then batch // those in Lagrange domain. let mut batch_coeffs = Vec::with_capacity(provers.len()); @@ -199,8 +201,12 @@ where let next_batch_coeff = transcript.sample(); batch_coeffs.push(next_batch_coeff); + // let start = Instant::now(); + let prover_round_evals = prover.execute_univariate_round(skip_rounds, max_domain_size, next_batch_coeff)?; + // let total = start.elapsed(); + // println!("batch_prove 1 {total:?}"); round_evals.add_assign_lagrange(&(prover_round_evals * next_batch_coeff))?; } @@ -209,12 +215,17 @@ where transcript.message().write_scalar_slice(&round_evals.evals); let univariate_challenge = transcript.sample(); + // Prove reduced multilinear eq-ind sumchecks, high-to-low, with front-loaded batching let mut sumcheck_provers = Vec::with_capacity(provers.len()); + + for prover in &mut provers { let sumcheck_prover = prover.fold_univariate_round(univariate_challenge)?; sumcheck_provers.push(sumcheck_prover); - } +} + +// let start = Instant::now(); let regular_sumcheck_prover = front_loaded::BatchProver::new_prebatched(batch_coeffs, 
sumcheck_provers)?; @@ -289,5 +300,8 @@ where concat_multilinear_evals, }; + // let total = start.elapsed(); + // println!("batch_prove 2 {total:?}"); + Ok(output) } diff --git a/crates/core/src/protocols/sumcheck/prove/univariate.rs b/crates/core/src/protocols/sumcheck/prove/univariate.rs index 656fc5995..3dc8bf6d3 100644 --- a/crates/core/src/protocols/sumcheck/prove/univariate.rs +++ b/crates/core/src/protocols/sumcheck/prove/univariate.rs @@ -1,6 +1,6 @@ // Copyright 2024-2025 Irreducible Inc. -use std::{collections::HashMap, iter::repeat_n}; +use std::{collections::HashMap, iter::repeat_n, time::Instant}; use binius_field::{ BinaryField, ExtensionField, Field, PackedExtension, PackedField, PackedSubfield, TowerField, @@ -335,24 +335,29 @@ where let pbase_coset_composition_evals_len = 1 << subcube_vars.saturating_sub(P::LOG_WIDTH + log_embedding_degree); + // println!("zerocheck_univariate_evals number of staggered rounds {:?}", 1 << log_subcube_count); + // let start = Instant::now(); // NB: we avoid evaluation on the first 2^skip_rounds points because honest // prover would always evaluate to zero there; we also factor out first // skip_rounds terms of the equality indicator and apply them pointwise to // the final round evaluations, which equates to lowering the composition_degree // by one (this is an extension of Gruen section 3.2 trick) - let staggered_round_evals = (0..1 << log_subcube_count) - .into_par_iter() + // let staggered_round_evals = (0..1 << log_subcube_count) + let staggered_round_evals = gray_order_par_iter(log_subcube_count) .try_fold( || { - ParFoldStates::::new( - n_multilinears, - skip_rounds, - log_batch, - log_embedding_degree, - composition_degrees.clone(), + ( + ParFoldStates::::new( + n_multilinears, + skip_rounds, + log_batch, + log_embedding_degree, + composition_degrees.clone(), + ), + None::, ) }, - |mut par_fold_states, subcube_index| -> Result<_, Error> { + |(mut par_fold_states, mut last_index), subcube_index| -> Result<_, 
Error> { let ParFoldStates { evals, extrapolated_evals, @@ -360,18 +365,50 @@ where packed_round_evals, } = &mut par_fold_states; + let pbase_log_width = P::LOG_WIDTH; + // Which bit flipped since the previous sub-cube? + let (need_recompute, flip_dim) = match last_index { + None => (true, 0), + Some(prev) => { + let diff = prev ^ subcube_index; + let d = diff.trailing_zeros() as usize; + // elements, so we recompute; otherwise we can just permute. + (d >= pbase_log_width + log_embedding_degree, d) + } + }; + + let start = Instant::now(); // Interpolate multilinear evals for each multilinear for (multilinear, extrapolated_evals) in izip!(multilinears, extrapolated_evals.iter_mut()) { // Sample evals subcube from a multilinear poly - multilinear.subcube_evals( - subcube_vars, - subcube_index, - log_embedding_degree, - evals, - )?; + // multilinear.subcube_evals( + // subcube_vars, + // subcube_index, + // log_embedding_degree, + // evals, + // )?; + + if need_recompute { + // println!("need recompute"); + multilinear.subcube_evals( + subcube_vars, + subcube_index, + log_embedding_degree, + evals, + )?; + } else { + // fast in-place update + gray_flip_axis(evals, flip_dim, pbase_log_width); + } + // if subcube_index == 1 || subcube_index == 0 { + // println!( + // "zerocheck_univariate_evals evals.len() subcube_index {} {:?}", + // subcube_index, evals + // ); + // } // Extrapolate evals using a conservative upper bound of the composition // degree. We use Additive NTT to extrapolate evals beyond the first // 2^skip_rounds, exploiting the fact that extension field NTT is a strided @@ -391,6 +428,11 @@ where )? 
} + // if subcube_index == 1 { + // let total = start.elapsed(); + // println!("zerocheck_univariate_evals_loop {total:?}"); + // } + // Evaluate the compositions and accumulate round results for (composition, packed_round_evals, &pbase_prefix_len) in izip!(compositions, packed_round_evals, &pbase_prefix_lens) @@ -436,31 +478,35 @@ where ); } } + // let total = start.elapsed(); + // println!("zerocheck_univariate_evals__rayon_loop {total:?}"); - Ok(par_fold_states) + last_index = Some(subcube_index); + Ok((par_fold_states, last_index)) }, ) .map(|states| -> Result<_, Error> { - let scalar_round_evals = izip!(composition_degrees.clone(), states?.packed_round_evals) - .map(|(composition_degree, packed_round_evals)| { - let mut composition_round_evals = Vec::with_capacity( - extrapolated_scalars_count(composition_degree, skip_rounds), - ); - - for packed_round_evals_coset in - packed_round_evals.chunks_exact(p_coset_round_evals_len) - { - let coset_scalars = packed_round_evals_coset - .iter() - .flat_map(|packed| packed.iter()) - .take(1 << skip_rounds); + let scalar_round_evals = + izip!(composition_degrees.clone(), states?.0.packed_round_evals) + .map(|(composition_degree, packed_round_evals)| { + let mut composition_round_evals = Vec::with_capacity( + extrapolated_scalars_count(composition_degree, skip_rounds), + ); - composition_round_evals.extend(coset_scalars); - } + for packed_round_evals_coset in + packed_round_evals.chunks_exact(p_coset_round_evals_len) + { + let coset_scalars = packed_round_evals_coset + .iter() + .flat_map(|packed| packed.iter()) + .take(1 << skip_rounds); - composition_round_evals - }) - .collect::>(); + composition_round_evals.extend(coset_scalars); + } + + composition_round_evals + }) + .collect::>(); Ok(scalar_round_evals) }) @@ -488,6 +534,8 @@ where }, )?; + // let total = start.elapsed(); + // println!("zerocheck_univariate_evals {total:?}"); // So far evals of each composition are "staggered" in a sense that they are evaluated on 
the // smallest domain which guarantees uniqueness of the round polynomial. We extrapolate them to // max_domain_size to aid in Gruen section 3.2 optimization below and batch mixing. @@ -504,6 +552,31 @@ where }) } +fn gray_order_par_iter( + m: usize, +) -> impl binius_maybe_rayon::prelude::ParallelIterator { + use binius_maybe_rayon::prelude::*; + (0..1usize << m).into_par_iter().map(|i| i ^ (i >> 1)) +} + +/// Swap the two halves of `slice` along the Boolean axis `dim`. +#[inline(always)] +fn gray_flip_axis(slice: &mut [T], dim: usize, pbase_log_width: usize) { + if dim < pbase_log_width { return } + + let block = 1usize << (dim - pbase_log_width); // elements to swap + let stride = block << 1; // distance of pairs + for chunk in (0..slice.len()).step_by(stride) { + for i in 0..block { + if chunk + i >= slice.len() || chunk + i + block >= slice.len() { + println!("gray_flip_axis dim {} block {} stride {}", dim, block, stride); + println!("gray_flip_axis chunk {} i {}", chunk, i); + } + slice.swap(chunk + i, chunk + i + block); + } + } +} + // A helper to perform spread multiplication of small field composition evals by appropriate // equality indicator scalars. See `zerocheck_univariate_evals` impl for intuition. fn spread_product( @@ -919,6 +992,8 @@ mod tests { .iter() .map(|round_evals| round_evals[round_evals_index]) .collect::>(); + println!("univariate_skip_composition_sums {:?}", univariate_skip_composition_sums); + assert_eq!(univariate_skip_composition_sums, composition_sums); } } diff --git a/crates/core/src/protocols/sumcheck/prove/zerocheck.rs b/crates/core/src/protocols/sumcheck/prove/zerocheck.rs index 4952a8064..4e0cded05 100644 --- a/crates/core/src/protocols/sumcheck/prove/zerocheck.rs +++ b/crates/core/src/protocols/sumcheck/prove/zerocheck.rs @@ -1,6 +1,6 @@ // Copyright 2024-2025 Irreducible Inc. 
-use std::{marker::PhantomData, mem, sync::Arc}; +use std::{marker::PhantomData, mem, sync::Arc, time::Instant}; use binius_field::{ ExtensionField, Field, PackedExtension, PackedField, PackedSubfield, RepackedExtension, @@ -319,6 +319,7 @@ where max_domain_size: usize, batch_coeff: F, ) -> Result, Error> { + let ZerocheckProverState::RoundEval { multilinears, compositions, @@ -340,6 +341,8 @@ where .map(|(_, composition_base, _)| composition_base) .collect::>(); + // let start = Instant::now(); + // Output contains values that are needed for computations that happen after // the round challenge has been sampled let univariate_evals_output = zerocheck_univariate_evals::<_, _, FBase, _, _, _, _>( @@ -351,6 +354,9 @@ where self.backend, )?; + // let total = start.elapsed(); + // println!("execute_univariate_round {total:?}"); + // Batch together Lagrange round evals using powers of batch_coeff let batched_round_evals = univariate_evals_output .round_evals @@ -422,6 +428,7 @@ where let lagrange_coeffs_query = MultilinearQuery::with_expansion(skip_rounds, packed_subcube_lagrange_coeffs)?; + // let start = Instant::now(); let folded_multilinears = padded_multilinears .par_iter() .map(|multilinear| -> Result<_, Error> { @@ -433,6 +440,9 @@ where }) .collect::, _>>()?; + // let total = start.elapsed(); + // println!("Folded multilinears in {total:?}"); + let composite_claims = izip!(compositions, claimed_sums) .map(|((_, _, composition), sum)| CompositeSumClaim { composition, sum }) .collect::>(); diff --git a/crates/math/src/mle_adapters.rs b/crates/math/src/mle_adapters.rs index a48dc1494..b5958248b 100644 --- a/crates/math/src/mle_adapters.rs +++ b/crates/math/src/mle_adapters.rs @@ -3,10 +3,9 @@ use std::{fmt::Debug, marker::PhantomData, ops::Deref, sync::Arc}; use binius_field::{ - ExtensionField, Field, PackedField, RepackedExtension, - packed::{ + arithmetic_traits::TaggedPackedTransformationFactory, packed::{ get_packed_slice, get_packed_slice_unchecked, 
set_packed_slice, set_packed_slice_unchecked, - }, + }, ExtensionField, Field, PackedField, RepackedExtension }; use binius_utils::bail; @@ -214,6 +213,9 @@ where log_embedding_degree: usize, evals: &mut [PE], ) -> Result<(), Error> { + // println!("subcube_evals 1 subcube_vars {} subcube_index {} log_embedding_degree {} evals.len() {}", + // subcube_vars, subcube_index, log_embedding_degree, evals.len()); + let log_extension_degree = PE::Scalar::LOG_DEGREE; if subcube_vars > self.n_vars() { @@ -250,6 +252,9 @@ where } let subcube_start = subcube_index << subcube_vars; + // if subcube_index == 1 { + // println!("subcube_start {}", subcube_start); + // } if log_embedding_degree == 0 { // One-to-one embedding can bypass the extension field construction overhead. @@ -273,6 +278,9 @@ where let bases_count = 1 << log_embedding_degree.min(subcube_vars); for i in 0..1 << subcube_vars.saturating_sub(log_embedding_degree) { for (j, base) in bases[..bases_count].iter_mut().enumerate() { + // if subcube_index == 1 { + // println!("subcube_evals inner loop i {} j {} base {}", i, j, base); + // } // Safety: i > 0 iff log_embedding_degree < subcube_vars and subcube_index < // max_index check *base = unsafe { @@ -517,6 +525,9 @@ where log_embedding_degree: usize, evals: &mut [P], ) -> Result<(), Error> { + // println!("subcube_evals subcube_vars {} subcube_index {} log_embedding_degree {} evals.len() {}", + // subcube_vars, subcube_index, log_embedding_degree, evals.len()); + let n_vars = self.n_vars(); if subcube_vars > n_vars { bail!(Error::ArgumentRangeError { diff --git a/examples/benches/binary_zerocheck.rs b/examples/benches/binary_zerocheck.rs index 1b5ef6b68..16e561cb4 100644 --- a/examples/benches/binary_zerocheck.rs +++ b/examples/benches/binary_zerocheck.rs @@ -111,9 +111,19 @@ fn bench_univariate_skip_aes_tower(c: &mut Criterion) { group.finish() } +fn criterion_config() -> Criterion { + let default = 10; + let ss = std::env::var("CRITERION_SAMPLE_SIZE") + .ok() + 
	.and_then(|v| v.parse::<usize>().ok())
+		.unwrap_or(default);
+
+	Criterion::default().sample_size(ss)
+}
+
 criterion_group! {
 	name = binary_zerocheck;
-	config = Criterion::default().sample_size(10);
+	config = criterion_config();
 	targets = bench_univariate_skip_aes_tower
 }
 criterion_main!(binary_zerocheck);