Skip to content
This repository was archived by the owner on Sep 9, 2025. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions crates/core/src/protocols/sumcheck/prove/batch_zerocheck.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// Copyright 2024-2025 Irreducible Inc.

use std::sync::Arc;
use std::{sync::Arc, time::Instant};

use binius_field::{ExtensionField, PackedExtension, PackedField, TowerField};
use binius_hal::{CpuBackend, make_portable_backend};
Expand Down Expand Up @@ -175,6 +175,7 @@ where
Prover: ZerocheckProver<'a, P>,
Challenger_: Challenger,
{

// Check that the provers are in non-descending order by n_vars
if !is_sorted_ascending(provers.iter().map(|prover| prover.n_vars())) {
bail!(Error::ClaimsOutOfOrder);
Expand All @@ -190,6 +191,7 @@ where
.max()
.unwrap_or(0);


// Sample batching coefficients while computing round polynomials per claim, then batch
// those in Lagrange domain.
let mut batch_coeffs = Vec::with_capacity(provers.len());
Expand All @@ -199,8 +201,12 @@ where
let next_batch_coeff = transcript.sample();
batch_coeffs.push(next_batch_coeff);

// let start = Instant::now();

let prover_round_evals =
prover.execute_univariate_round(skip_rounds, max_domain_size, next_batch_coeff)?;
// let total = start.elapsed();
// println!("batch_prove 1 {total:?}");

round_evals.add_assign_lagrange(&(prover_round_evals * next_batch_coeff))?;
}
Expand All @@ -209,12 +215,17 @@ where
transcript.message().write_scalar_slice(&round_evals.evals);
let univariate_challenge = transcript.sample();


// Prove reduced multilinear eq-ind sumchecks, high-to-low, with front-loaded batching
let mut sumcheck_provers = Vec::with_capacity(provers.len());


for prover in &mut provers {
let sumcheck_prover = prover.fold_univariate_round(univariate_challenge)?;
sumcheck_provers.push(sumcheck_prover);
}
}

// let start = Instant::now();

let regular_sumcheck_prover =
front_loaded::BatchProver::new_prebatched(batch_coeffs, sumcheck_provers)?;
Expand Down Expand Up @@ -289,5 +300,8 @@ where
concat_multilinear_evals,
};

// let total = start.elapsed();
// println!("batch_prove 2 {total:?}");

Ok(output)
}
145 changes: 110 additions & 35 deletions crates/core/src/protocols/sumcheck/prove/univariate.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// Copyright 2024-2025 Irreducible Inc.

use std::{collections::HashMap, iter::repeat_n};
use std::{collections::HashMap, iter::repeat_n, time::Instant};

use binius_field::{
BinaryField, ExtensionField, Field, PackedExtension, PackedField, PackedSubfield, TowerField,
Expand Down Expand Up @@ -335,43 +335,80 @@ where
let pbase_coset_composition_evals_len =
1 << subcube_vars.saturating_sub(P::LOG_WIDTH + log_embedding_degree);

// println!("zerocheck_univariate_evals number of staggered rounds {:?}", 1 << log_subcube_count);
// let start = Instant::now();
// NB: we avoid evaluation on the first 2^skip_rounds points because honest
// prover would always evaluate to zero there; we also factor out first
// skip_rounds terms of the equality indicator and apply them pointwise to
// the final round evaluations, which equates to lowering the composition_degree
// by one (this is an extension of Gruen section 3.2 trick)
let staggered_round_evals = (0..1 << log_subcube_count)
.into_par_iter()
// let staggered_round_evals = (0..1 << log_subcube_count)
let staggered_round_evals = gray_order_par_iter(log_subcube_count)
.try_fold(
|| {
ParFoldStates::<FBase, P>::new(
n_multilinears,
skip_rounds,
log_batch,
log_embedding_degree,
composition_degrees.clone(),
(
ParFoldStates::<FBase, P>::new(
n_multilinears,
skip_rounds,
log_batch,
log_embedding_degree,
composition_degrees.clone(),
),
None::<usize>,
)
},
|mut par_fold_states, subcube_index| -> Result<_, Error> {
|(mut par_fold_states, mut last_index), subcube_index| -> Result<_, Error> {
let ParFoldStates {
evals,
extrapolated_evals,
composition_evals,
packed_round_evals,
} = &mut par_fold_states;

let pbase_log_width = P::LOG_WIDTH;
// Which bit flipped since the previous sub-cube?
let (need_recompute, flip_dim) = match last_index {
None => (true, 0),
Some(prev) => {
let diff = prev ^ subcube_index;
let d = diff.trailing_zeros() as usize;
// elements, so we recompute; otherwise we can just permute.
(d >= pbase_log_width + log_embedding_degree, d)
}
};

let start = Instant::now();
// Interpolate multilinear evals for each multilinear
for (multilinear, extrapolated_evals) in
izip!(multilinears, extrapolated_evals.iter_mut())
{
// Sample evals subcube from a multilinear poly
multilinear.subcube_evals(
subcube_vars,
subcube_index,
log_embedding_degree,
evals,
)?;
// multilinear.subcube_evals(
// subcube_vars,
// subcube_index,
// log_embedding_degree,
// evals,
// )?;

if need_recompute {
// println!("need recompute");
multilinear.subcube_evals(
subcube_vars,
subcube_index,
log_embedding_degree,
evals,
)?;
} else {
// fast in-place update
gray_flip_axis(evals, flip_dim, pbase_log_width);
}

// if subcube_index == 1 || subcube_index == 0 {
// println!(
// "zerocheck_univariate_evals evals.len() subcube_index {} {:?}",
// subcube_index, evals
// );
// }
// Extrapolate evals using a conservative upper bound of the composition
// degree. We use Additive NTT to extrapolate evals beyond the first
// 2^skip_rounds, exploiting the fact that extension field NTT is a strided
Expand All @@ -391,6 +428,11 @@ where
)?
}

// if subcube_index == 1 {
// let total = start.elapsed();
// println!("zerocheck_univariate_evals_loop {total:?}");
// }

// Evaluate the compositions and accumulate round results
for (composition, packed_round_evals, &pbase_prefix_len) in
izip!(compositions, packed_round_evals, &pbase_prefix_lens)
Expand Down Expand Up @@ -436,31 +478,35 @@ where
);
}
}
// let total = start.elapsed();
// println!("zerocheck_univariate_evals__rayon_loop {total:?}");

Ok(par_fold_states)
last_index = Some(subcube_index);
Ok((par_fold_states, last_index))
},
)
.map(|states| -> Result<_, Error> {
let scalar_round_evals = izip!(composition_degrees.clone(), states?.packed_round_evals)
.map(|(composition_degree, packed_round_evals)| {
let mut composition_round_evals = Vec::with_capacity(
extrapolated_scalars_count(composition_degree, skip_rounds),
);

for packed_round_evals_coset in
packed_round_evals.chunks_exact(p_coset_round_evals_len)
{
let coset_scalars = packed_round_evals_coset
.iter()
.flat_map(|packed| packed.iter())
.take(1 << skip_rounds);
let scalar_round_evals =
izip!(composition_degrees.clone(), states?.0.packed_round_evals)
.map(|(composition_degree, packed_round_evals)| {
let mut composition_round_evals = Vec::with_capacity(
extrapolated_scalars_count(composition_degree, skip_rounds),
);

composition_round_evals.extend(coset_scalars);
}
for packed_round_evals_coset in
packed_round_evals.chunks_exact(p_coset_round_evals_len)
{
let coset_scalars = packed_round_evals_coset
.iter()
.flat_map(|packed| packed.iter())
.take(1 << skip_rounds);

composition_round_evals
})
.collect::<Vec<_>>();
composition_round_evals.extend(coset_scalars);
}

composition_round_evals
})
.collect::<Vec<_>>();

Ok(scalar_round_evals)
})
Expand Down Expand Up @@ -488,6 +534,8 @@ where
},
)?;

// let total = start.elapsed();
// println!("zerocheck_univariate_evals {total:?}");
// So far evals of each composition are "staggered" in a sense that they are evaluated on the
// smallest domain which guarantees uniqueness of the round polynomial. We extrapolate them to
// max_domain_size to aid in Gruen section 3.2 optimization below and batch mixing.
Expand All @@ -504,6 +552,31 @@ where
})
}

/// Returns a parallel iterator over `0..2^m` in binary-reflected Gray-code order.
///
/// Successive elements of the *sequential* Gray sequence differ in exactly one bit,
/// which lets a consumer that tracks the previously visited index update its state
/// incrementally (see `gray_flip_axis`) instead of recomputing it from scratch.
/// Note that rayon splits the range into contiguous chunks, so the single-bit-flip
/// property only holds *within* each sequentially-folded chunk — callers must
/// handle the "first index of a chunk" case by recomputing (the `last_index: None`
/// path in the fold state).
///
/// # Panics
/// Debug builds panic if `m >= usize::BITS`, where `1usize << m` would overflow.
fn gray_order_par_iter(
	m: usize,
) -> impl binius_maybe_rayon::prelude::ParallelIterator<Item = usize> {
	use binius_maybe_rayon::prelude::*;
	debug_assert!(
		m < usize::BITS as usize,
		"gray_order_par_iter: m={m} would overflow the shift"
	);
	// i ^ (i >> 1) is the standard binary-to-Gray-code conversion.
	(0..1usize << m).into_par_iter().map(|i| i ^ (i >> 1))
}

/// Swap the two halves of `slice` along the Boolean axis `dim`.
///
/// The slice is interpreted as a Boolean hypercube of packed elements, where each
/// element packs the lowest `pbase_log_width` axes. For an axis `dim` above the
/// packing width, the hypercube splits into contiguous runs of `2^(dim - pbase_log_width + 1)`
/// elements whose two halves differ only in bit `dim`; this routine swaps those
/// halves in place. Axes below the packing width (`dim < pbase_log_width`) live
/// *inside* a packed element and are a no-op here — the caller must recompute
/// the sub-cube instead (the `need_recompute` path).
///
/// This is the incremental update used when walking sub-cubes in Gray-code order:
/// consecutive Gray indices differ in a single bit `dim`, and flipping that axis
/// transforms the previous sub-cube's evals into the current one without a full
/// `subcube_evals` recomputation.
///
/// # Panics
/// Panics if `slice.len()` is not a multiple of the pair stride
/// `2^(dim - pbase_log_width + 1)` (the original index-based loop also panicked
/// on such inputs, inside `slice::swap`).
#[inline(always)]
fn gray_flip_axis<T>(slice: &mut [T], dim: usize, pbase_log_width: usize) {
	// Axes packed inside a single element cannot be permuted at slice granularity.
	if dim < pbase_log_width {
		return;
	}

	let block = 1usize << (dim - pbase_log_width); // elements in each half to swap
	let stride = block << 1; // distance between the starts of consecutive pairs
	assert!(
		slice.len() % stride == 0,
		"gray_flip_axis: slice length {} is not a multiple of stride {stride} (dim={dim})",
		slice.len(),
	);

	// chunks_exact_mut guarantees every chunk is exactly `stride` long, so the
	// split and the element swaps are bounds-check-free in the inner loop.
	for pair in slice.chunks_exact_mut(stride) {
		let (lo, hi) = pair.split_at_mut(block);
		lo.swap_with_slice(hi);
	}
}

// A helper to perform spread multiplication of small field composition evals by appropriate
// equality indicator scalars. See `zerocheck_univariate_evals` impl for intuition.
fn spread_product<P, FBase>(
Expand Down Expand Up @@ -919,6 +992,8 @@ mod tests {
.iter()
.map(|round_evals| round_evals[round_evals_index])
.collect::<Vec<_>>();
println!("univariate_skip_composition_sums {:?}", univariate_skip_composition_sums);

assert_eq!(univariate_skip_composition_sums, composition_sums);
}
}
Expand Down
12 changes: 11 additions & 1 deletion crates/core/src/protocols/sumcheck/prove/zerocheck.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// Copyright 2024-2025 Irreducible Inc.

use std::{marker::PhantomData, mem, sync::Arc};
use std::{marker::PhantomData, mem, sync::Arc, time::Instant};

use binius_field::{
ExtensionField, Field, PackedExtension, PackedField, PackedSubfield, RepackedExtension,
Expand Down Expand Up @@ -319,6 +319,7 @@ where
max_domain_size: usize,
batch_coeff: F,
) -> Result<ZerocheckRoundEvals<F>, Error> {

let ZerocheckProverState::RoundEval {
multilinears,
compositions,
Expand All @@ -340,6 +341,8 @@ where
.map(|(_, composition_base, _)| composition_base)
.collect::<Vec<_>>();

// let start = Instant::now();

// Output contains values that are needed for computations that happen after
// the round challenge has been sampled
let univariate_evals_output = zerocheck_univariate_evals::<_, _, FBase, _, _, _, _>(
Expand All @@ -351,6 +354,9 @@ where
self.backend,
)?;

// let total = start.elapsed();
// println!("execute_univariate_round {total:?}");

// Batch together Lagrange round evals using powers of batch_coeff
let batched_round_evals = univariate_evals_output
.round_evals
Expand Down Expand Up @@ -422,6 +428,7 @@ where
let lagrange_coeffs_query =
MultilinearQuery::with_expansion(skip_rounds, packed_subcube_lagrange_coeffs)?;

// let start = Instant::now();
let folded_multilinears = padded_multilinears
.par_iter()
.map(|multilinear| -> Result<_, Error> {
Expand All @@ -433,6 +440,9 @@ where
})
.collect::<Result<Vec<_>, _>>()?;

// let total = start.elapsed();
// println!("Folded multilinears in {total:?}");

let composite_claims = izip!(compositions, claimed_sums)
.map(|((_, _, composition), sum)| CompositeSumClaim { composition, sum })
.collect::<Vec<_>>();
Expand Down
17 changes: 14 additions & 3 deletions crates/math/src/mle_adapters.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
use std::{fmt::Debug, marker::PhantomData, ops::Deref, sync::Arc};

use binius_field::{
ExtensionField, Field, PackedField, RepackedExtension,
packed::{
arithmetic_traits::TaggedPackedTransformationFactory, packed::{
get_packed_slice, get_packed_slice_unchecked, set_packed_slice, set_packed_slice_unchecked,
},
}, ExtensionField, Field, PackedField, RepackedExtension
};
use binius_utils::bail;

Expand Down Expand Up @@ -214,6 +213,9 @@ where
log_embedding_degree: usize,
evals: &mut [PE],
) -> Result<(), Error> {
// println!("subcube_evals 1 subcube_vars {} subcube_index {} log_embedding_degree {} evals.len() {}",
// subcube_vars, subcube_index, log_embedding_degree, evals.len());

let log_extension_degree = PE::Scalar::LOG_DEGREE;

if subcube_vars > self.n_vars() {
Expand Down Expand Up @@ -250,6 +252,9 @@ where
}

let subcube_start = subcube_index << subcube_vars;
// if subcube_index == 1 {
// println!("subcube_start {}", subcube_start);
// }

if log_embedding_degree == 0 {
// One-to-one embedding can bypass the extension field construction overhead.
Expand All @@ -273,6 +278,9 @@ where
let bases_count = 1 << log_embedding_degree.min(subcube_vars);
for i in 0..1 << subcube_vars.saturating_sub(log_embedding_degree) {
for (j, base) in bases[..bases_count].iter_mut().enumerate() {
// if subcube_index == 1 {
// println!("subcube_evals inner loop i {} j {} base {}", i, j, base);
// }
// Safety: i > 0 iff log_embedding_degree < subcube_vars and subcube_index <
// max_index check
*base = unsafe {
Expand Down Expand Up @@ -517,6 +525,9 @@ where
log_embedding_degree: usize,
evals: &mut [P],
) -> Result<(), Error> {
// println!("subcube_evals subcube_vars {} subcube_index {} log_embedding_degree {} evals.len() {}",
// subcube_vars, subcube_index, log_embedding_degree, evals.len());

let n_vars = self.n_vars();
if subcube_vars > n_vars {
bail!(Error::ArgumentRangeError {
Expand Down
Loading