diff --git a/src/poly_commit/mod.rs b/src/poly_commit/mod.rs index 3ec5d20..1b004f3 100644 --- a/src/poly_commit/mod.rs +++ b/src/poly_commit/mod.rs @@ -1,12 +1,12 @@ pub mod types; use crate::{ - poly_commit::types::{CRS, Statement, Witness}, + poly_commit::types::{CRS, PolyCommit, Statement, Witness}, vector_ops::inner_product, }; use ark_ec::CurveGroup; -use ark_ff::{Field, UniformRand, batch_inversion}; -use ark_poly::polynomial::DenseUVPolynomial; +use ark_ff::{Field, UniformRand, Zero, batch_inversion}; +use ark_poly::{polynomial::DenseUVPolynomial, univariate::DensePolynomial}; use ark_std::log2; use rayon::prelude::*; use spongefish::{ @@ -16,7 +16,7 @@ use spongefish::{ GroupToUnitDeserialize, GroupToUnitSerialize, UnitToField, }, }; -use std::ops::Mul; +use std::{marker::PhantomData, ops::Mul}; pub trait OpeningProofDomainSeparator { fn opening_proof_statement(self) -> Self; @@ -80,7 +80,7 @@ pub fn prove, Rng: rand::Rng statement: &Statement, witness: &Witness, rng: &mut Rng, -) -> ProofResult> { +) -> ProofResult> { let u: G = { let [u_coeff]: [G::ScalarField; 1] = prover_state.challenge_scalars()?; G::generator().mul(u_coeff) @@ -93,6 +93,8 @@ pub fn prove, Rng: rand::Rng let mut a: Vec = witness.p.coeffs().to_vec(); let mut b: Vec = powers_of_x(statement.x).take(n).collect(); + let mut ui: Vec = Vec::with_capacity(log2(n) as usize); + while n != 1 { n /= 2; let (g_lo, g_hi) = g.split_at(n); @@ -116,6 +118,7 @@ pub fn prove, Rng: rand::Rng prover_state.add_points(&[left_j, right_j])?; let [u_j]: [G::ScalarField; 1] = prover_state.challenge_scalars()?; + ui.push(u_j); let u_j_inv = u_j.inverse().expect("non-zero u_j"); let (new_a, new_b) = rayon::join( @@ -131,7 +134,17 @@ pub fn prove, Rng: rand::Rng } prover_state.add_scalars(&[a[0], r])?; - Ok(prover_state.narg_string().to_vec()) + { + let h_poly = HPoly { ui: ui.clone() }; + let ss = h_poly.coeffs(); + let g_comp = G::msm_unchecked(&crs.gs, &ss); + assert_eq!(g[0], g_comp.into_affine(), "Gs NEQ") + } + + Ok(Todo { + g: PolyCommit { g: g[0].into() }, + h_poly: HPoly { ui }, + }) } pub fn verify( @@ -163,41 +176,193 @@ pub fn verify( }); let ss: Vec = { - let challenge_powers: Vec = transcript.iter().map(|&(_, x)| x).collect(); - let challenge_inverses = { - let mut inverses = challenge_powers.clone(); + let h_poly = HPoly { + ui: transcript.iter().map(|&(_, x)| x).collect(), + }; + h_poly.coeffs() + }; + + let g = G::msm_unchecked(&crs.gs, &ss); + let b = { + let b = powers_of_x(statement.x).take(n); + inner_product(ss, b) + }; + + let [a, r]: [G::ScalarField; 2] = verifier_state.next_scalars()?; + let res = g.mul(a) + crs.h.mul(r) + u.mul(a * b) - q; + + assert!(res.is_zero(), "Q equality"); + Ok(()) +} + +#[derive(Debug, Clone, PartialEq)] +pub struct HPoly { + pub ui: Vec, +} + +impl HPoly { + pub fn evaluate(&self, x: F) -> F { + self.ui + .iter() + .rev() + .enumerate() + .map(|(i, &u_i)| { + let u_i_inv = u_i.inverse().unwrap(); + let exp = 2_u64.pow(i as u32); + u_i_inv + u_i * x.pow([exp]) + }) + .product() + } + + pub fn coeffs(&self) -> Vec { + let ui_inverses = { + let mut inverses = self.ui.clone(); batch_inversion(&mut inverses); inverses }; + let k = self.ui.len(); + let n = (2_u64).pow(k as u32); + (0..n) .into_par_iter() .map(|i| { - (0..log2_n) + (0..k) .map(|j| { - let idx = log2_n - j - 1; + let idx = k - j - 1; if (i >> j) & 1 == 1 { - challenge_powers[idx] + self.ui[idx] } else { - challenge_inverses[idx] + ui_inverses[idx] } }) .product() }) .collect() + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct Todo { + pub h_poly: HPoly, + pub g: PolyCommit, +} + +pub fn fold_todos_witness( + todos: &[Todo], + alpha: G::ScalarField, +) -> Witness> { + todos + .iter() + .map(|Todo { h_poly, .. }| Witness { + // It is important to put zero for the amortization, i.e. "Halo Trick" + r: G::ScalarField::zero(), + p: DensePolynomial::from_coefficients_vec(h_poly.coeffs()), + _group: PhantomData, + }) + .zip(powers_of_x(alpha)) + .map(|(witness, alpha_i)| witness.mul(alpha_i)) + .reduce(|x, y| x + y) + .expect("Non empty statement list") +} + +pub fn fold_todos_statement( + todos: &[Todo], + alpha: G::ScalarField, + x: G::ScalarField, +) -> Statement { + todos + .iter() + .map(|Todo { h_poly, g }| Statement { + commitment: *g, + evaluation: h_poly.evaluate(x), + x, + }) + .zip(powers_of_x(alpha)) + .map(|(s, alpha_i)| s.mul(alpha_i)) + .reduce(|x, y| x + y) + .expect("Non empty statement list") +} + +pub fn lazy_verify( + verifier_state: &mut VerifierState, + crs: &CRS, + statement: &Statement, + assumption: PolyCommit, + mut todos: Vec>, +) -> ProofResult>> { + let n = crs.size(); + let log2_n = log2(n) as usize; + + let u: G = { + let [u_coeff]: [G::ScalarField; 1] = verifier_state.challenge_scalars()?; + G::generator().mul(u_coeff) }; - let g = G::msm_unchecked(&crs.gs, &ss); - let b = { - let b = powers_of_x(statement.x).take(n); - inner_product(ss, b) + let p_prime = statement.commitment.g + u.mul(statement.evaluation); + + let transcript = (0..log2_n) + .map(|_| { + let [left, right]: [G; 2] = verifier_state.next_points()?; + let [x]: [G::ScalarField; 1] = verifier_state.challenge_scalars()?; + Ok(((left, right), x)) + }) + .collect::>>()?; + + let q: G = transcript.iter().fold(p_prime, |acc, ((l_j, r_j), u_j)| { + let u_j_inv = u_j.inverse().expect("non zero u_j"); + acc + l_j.mul(u_j.square()) + r_j.mul(u_j_inv.square()) + }); + + let h_poly = HPoly { + ui: transcript.iter().map(|&(_, x)| x).collect(), }; + let b = h_poly.evaluate(statement.x); + let [a, r]: [G::ScalarField; 2] = verifier_state.next_scalars()?; - let res = g.mul(a) + crs.h.mul(r) + u.mul(a * b) - q; + let res = assumption.g.mul(a) + crs.h.mul(r) + u.mul(a * b) - q; assert!(res.is_zero(), "Q equality"); - Ok(()) + let todo = Todo { + g: assumption, + h_poly, + }; + todos.push(todo); + Ok(todos) +} + +#[cfg(test)] +mod test_hpoly { + use super::*; + use ark_secp256k1::Fr; + use ark_std::UniformRand; + use rand::rngs::OsRng; + + #[test] + fn test_hpoly_evaluation_matches_expansion() { + let mut rng = OsRng; + + // Test with a small HPoly + let ui: Vec = (0..3).map(|_| Fr::rand(&mut rng)).collect(); + let h_poly = HPoly { ui }; + let x = Fr::rand(&mut rng); + + // Compute via efficient evaluate method + let eval_result = h_poly.evaluate(x); + + // Compute via expansion and inner product + let coeffs = h_poly.coeffs(); + let n = coeffs.len(); + let x_powers: Vec = powers_of_x(x).take(n).collect(); + let inner_result = inner_product(coeffs, x_powers); + + assert_eq!( + eval_result, inner_result, + "HPoly evaluation mismatch: eval={:?}, inner={:?}", + eval_result, inner_result + ); + } } #[cfg(test)] @@ -210,57 +375,164 @@ mod tests_proof { use ark_std::UniformRand; use proptest::{prelude::*, test_runner::Config}; use rand::rngs::OsRng; - use spongefish::DomainSeparator; use spongefish::codecs::arkworks_algebra::{CommonFieldToUnit, CommonGroupToUnit}; + use spongefish::{DomainSeparator, ProofError}; + use std::ops::Add; + + type Fr = + as PrimeGroup>::ScalarField; + + pub fn prove_verify( + crs: &CRS, + witness: &Witness>, + statement: &Statement, + rng: &mut OsRng, + ) -> ProofResult<()> { + let domain_separator = { + let domain_separator = DomainSeparator::new("test-poly-comm"); + // add the IO of the bulletproof statement + let domain_separator = + OpeningProofDomainSeparator::::opening_proof_statement( + domain_separator, + ) + .ratchet(); + // add the IO of the bulletproof protocol (the transcript) + OpeningProofDomainSeparator::::add_opening_proof( + domain_separator, + witness.size(), + ) + }; + + let mut prover_state = domain_separator.to_prover_state(); + let proof = { + prover_state + .public_points(&[statement.commitment.g]) + .unwrap(); + prover_state.public_scalars(&[statement.x]).unwrap(); + prover_state + .public_scalars(&[statement.evaluation]) + .unwrap(); + prover_state.ratchet().unwrap(); + prove(&mut prover_state, crs, statement, witness, rng) + .expect("proof should be generated"); + prover_state.narg_string() + }; + + let mut verifier_state = domain_separator.to_verifier_state(proof); + verifier_state + .public_points(&[statement.commitment.g]) + .expect("cannot add statement"); + verifier_state + .public_scalars(&[statement.x]) + .expect("cannot add statement"); + verifier_state + .public_scalars(&[statement.evaluation]) + .expect("cannot add statement"); + verifier_state.ratchet().expect("failed to ratchet"); + verify(&mut verifier_state, crs, statement) + } proptest! { #![proptest_config(Config::with_cases(2))] #[test] - fn test_poly_comm_prove_verify_works((crs, witness, x) in any::().prop_map(|crs_size| { + fn test_poly_comm_prove_verify_works((crs, witness1, witness2, statement1, statement2) in any::().prop_map(|crs_size| { let mut rng = OsRng; - type Fr = as PrimeGroup>::ScalarField; let crs = >::rand(crs_size, &mut rng); let n = crs.size() as u64; - let witness: Witness> = Witness::rand((n - 1) as usize, &mut rng); + let witness1: Witness> = Witness::rand((n - 1) as usize, &mut rng); + let witness2: Witness> = Witness::rand((n - 1) as usize, &mut rng); let x = Fr::rand(&mut rng); - (crs, witness, x) + let statement1 = witness1.statement(&crs,x); + let statement2 = witness2.statement(&crs,x); + (crs, witness1, witness2, statement1, statement2) })) { + let mut rng = OsRng; - { + // works in normal case + prove_verify(&crs, &witness1, &statement1, &mut rng).expect("normal"); - let mut rng = OsRng; + //can scale + let alpha1 = Fr::rand(&mut rng); + let alpha2 = Fr::rand(&mut rng); + let witness = witness1.mul(alpha1).add(witness2.mul(alpha2)); + let statement = statement1.mul(alpha1).add(statement2.mul(alpha2)); + assert_eq!(statement, witness.statement(&crs, statement1.x), "statements are linear"); + prove_verify(&crs, &witness, &statement, &mut rng).expect("linear"); - let domain_separator = { - let domain_separator = DomainSeparator::new("test-poly-comm"); - // add the IO of the bulletproof statement - let domain_separator = - OpeningProofDomainSeparator::::opening_proof_statement(domain_separator).ratchet(); - // add the IO of the bulletproof protocol (the transcript) - OpeningProofDomainSeparator::::add_opening_proof(domain_separator, witness.size()) - }; - let (statement, proof) = { + } + } - let statement: Statement = witness.statement(&crs, x); - let mut prover_state = domain_separator.to_prover_state(); - prover_state.public_points(&[statement.commitment.g]).unwrap(); - prover_state.public_scalars(&[statement.x]).unwrap(); - prover_state.public_scalars(&[statement.evaluation]).unwrap(); - prover_state.ratchet().unwrap(); - let proof = prove(&mut prover_state, &crs, &statement, &witness, &mut rng).expect("proof should be generated"); - (statement, proof) - }; + proptest! { + #![proptest_config(Config::with_cases(2))] + #[test] + fn test_poly_comm_amortize((crs, witnesses, points) in any::<(CrsSize, u8)>().prop_map(|(crs_size, _)| { + let mut rng = OsRng; + let crs = >::rand(crs_size, &mut rng); + let n = crs.size() as u64; + let m = 4; + let witnesses = { + let mut ws = Vec::with_capacity(m as usize); + for _ in 0..m { + ws.push(>>::rand((n - 1) as usize, &mut rng)); + }; + ws + }; - let mut verifier_state = domain_separator.to_verifier_state(&proof); - verifier_state.public_points(&[statement.commitment.g]).expect("cannot add statement"); - verifier_state.public_scalars(&[statement.x]).expect("cannot add statement"); - verifier_state.public_scalars(&[statement.evaluation]).expect("cannot add statement"); - verifier_state.ratchet().expect("failed to ratchet"); - verify(&mut verifier_state, &crs, &statement).expect("proof should verify"); + let points: Vec = (0..witnesses.len()).map(|_| Fr::rand(&mut rng)).collect(); + (crs, witnesses, points) + })) { + let mut rng = OsRng; + let domain_separator = DomainSeparator::new("test-poly-comm"); + + let proofs = points.iter().zip(witnesses.iter()).map(|(&x, witness)| { + let domain_separator = OpeningProofDomainSeparator::::opening_proof_statement(domain_separator.clone()).ratchet(); + let domain_separator = OpeningProofDomainSeparator::::add_opening_proof(domain_separator, crs.size()); + let mut prover_state = domain_separator.to_prover_state(); + let statement = witness.statement(&crs, x); + prover_state + .public_points(&[statement.commitment.g]) + .unwrap(); + prover_state.public_scalars(&[statement.x]).unwrap(); + prover_state + .public_scalars(&[statement.evaluation]) + .unwrap(); + prover_state.ratchet().unwrap(); + let proof = prove(&mut prover_state, &crs, &statement, witness, &mut rng)?; + Ok((domain_separator, prover_state.narg_string().to_vec(), statement, proof)) + }).collect::, ProofError>>()?; + + let verifier_todos = proofs.iter().try_fold(Vec::new(), |todos, (domain_separator, proof, statement, prover_todo)| { + let mut verifier_state = domain_separator.to_verifier_state(proof); + verifier_state + .public_points(&[statement.commitment.g]) + .expect("cannot add statement"); + verifier_state + .public_scalars(&[statement.x]) + .expect("cannot add statement"); + verifier_state + .public_scalars(&[statement.evaluation]) + .expect("cannot add statement"); + verifier_state.ratchet().expect("failed to ratchet"); + lazy_verify(&mut verifier_state, &crs, statement, prover_todo.g, todos) + })?; + + let prover_todos: Vec> = proofs.iter().map(|(_,_,_,todo)| todo).cloned().collect(); + + assert_eq!(prover_todos, verifier_todos, "Prover todos don't match verifier todos"); + + let alpha = Fr::rand(&mut rng); + let x = Fr::rand(&mut rng); + let witness = fold_todos_witness(&prover_todos, alpha); + let statement = fold_todos_statement(&verifier_todos, alpha, x); + { + let prover_statement = witness.statement(&crs, x); + assert_eq!(prover_statement, statement, "Statements match"); } + prove_verify(&crs, &witness, &statement, &mut rng)?; } } } diff --git a/src/poly_commit/types.rs b/src/poly_commit/types.rs index 37e57cf..7240fbe 100644 --- a/src/poly_commit/types.rs +++ b/src/poly_commit/types.rs @@ -2,6 +2,7 @@ use ark_ec::CurveGroup; use ark_ff::Zero; use ark_poly::{DenseUVPolynomial, polynomial}; use std::marker::PhantomData; +use std::ops::{Add, Mul}; use crate::ipa::types::CrsSize; @@ -26,16 +27,35 @@ impl CRS { } } -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, PartialEq)] pub struct PolyCommit { pub g: G, } -#[derive(Debug, Clone, Copy)] +impl PolyCommit { + #[allow(clippy::should_implement_trait)] + pub fn mul(self, alpha: G::ScalarField) -> Self { + Self { + g: self.g.mul(alpha), + } + } +} + +impl Add for PolyCommit { + type Output = Self; + + fn add(self, other: Self) -> Self { + PolyCommit { + g: self.g + other.g, + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq)] pub struct Witness> { pub p: P, pub r: G::ScalarField, - _group: PhantomData, + pub _group: PhantomData, } impl Witness> { @@ -51,6 +71,26 @@ impl Witness Self { + Self { + p: self.p.mul(alpha), + r: self.r * alpha, + _group: self._group, + } + } +} + +impl> Add for Witness { + type Output = Witness; + fn add(self, other: Self) -> Self { + Self { + p: self.p + other.p, + r: self.r + other.r, + _group: self._group, + } + } } impl> Witness { @@ -59,13 +99,39 @@ impl> Witness { } } -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, PartialEq)] pub struct Statement { pub commitment: PolyCommit, pub x: G::ScalarField, pub evaluation: G::ScalarField, } +impl Statement { + #[allow(clippy::should_implement_trait)] + pub fn mul(self, alpha: G::ScalarField) -> Self { + Self { + commitment: self.commitment.mul(alpha), + x: self.x, + evaluation: self.evaluation * alpha, + } + } +} + +impl Add for Statement { + type Output = Statement; + fn add(self, other: Self) -> Self { + assert!( + self.x == other.x, + "Can only add Statements where the evaluation points match" + ); + Self { + commitment: self.commitment + other.commitment, + x: self.x, + evaluation: self.evaluation + other.evaluation, + } + } +} + impl> Witness { pub fn statement(&self, crs: &CRS, x: G::ScalarField) -> Statement { let commitment = {