diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index eb495a02c..413b408f6 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -63,6 +63,18 @@ jobs: RUSTFLAGS: "-C opt-level=3" run: cargo run --release --package ceno_zkvm --bin e2e -- --platform=ceno --hints=10 --public-io=4191 examples/target/riscv32im-ceno-zkvm-elf/release/examples/fibonacci + - name: Run 3 shards fibonacci (debug) + env: + RUST_LOG: debug + RUSTFLAGS: "-C opt-level=3" + MOCK_PROVING: 1 + run: cargo run --package ceno_zkvm --features sanity-check --bin e2e -- --platform=ceno --min-cycle-per-shard=10 --max-cycle-per-shard=20000 --hints=10 --public-io=4191 examples/target/riscv32im-ceno-zkvm-elf/debug/examples/fibonacci + + - name: Run 3 shards fibonacci (release) + env: + RUSTFLAGS: "-C opt-level=3" + run: cargo run --release --package ceno_zkvm --features sanity-check --bin e2e -- --platform=ceno --min-cycle-per-shard=10 --max-cycle-per-shard=20000 --hints=10 --public-io=4191 examples/target/riscv32im-ceno-zkvm-elf/release/examples/fibonacci + # note: the global chip does not support goldilocks field yet # - name: Run fibonacci (release + goldilocks) # env: @@ -81,6 +93,18 @@ jobs: RUSTFLAGS: "-C opt-level=3" run: cargo run --release --package ceno_zkvm --bin e2e -- --platform=ceno examples/target/riscv32im-ceno-zkvm-elf/release/examples/ceno_rt_alloc + - name: Run 3 shards Guest Heap Alloc (debug) + env: + RUST_LOG: debug + RUSTFLAGS: "-C opt-level=3" + MOCK_PROVING: 1 + run: cargo run --package ceno_zkvm --features sanity-check --bin e2e -- --platform=ceno --min-cycle-per-shard=10 --max-cycle-per-shard=300 examples/target/riscv32im-ceno-zkvm-elf/debug/examples/ceno_rt_alloc + + - name: Run 3 shards Guest Heap Alloc (release) + env: + RUSTFLAGS: "-C opt-level=3" + run: cargo run --release --package ceno_zkvm --bin e2e -- --platform=ceno --min-cycle-per-shard=10 --max-cycle-per-shard=1600 
examples/target/riscv32im-ceno-zkvm-elf/release/examples/keccak_syscall + # note: the global chip does not support goldilocks field yet # - name: Run Guest Heap Alloc (release + goldilocks) # env: @@ -92,6 +116,11 @@ jobs: RUSTFLAGS: "-C opt-level=3" run: cargo run --release --package ceno_zkvm --bin e2e -- --platform=ceno examples/target/riscv32im-ceno-zkvm-elf/release/examples/keccak_syscall + - name: Run 3 shard keccak_syscall (release) + env: + RUSTFLAGS: "-C opt-level=3" + run: cargo run --release --package ceno_zkvm --bin e2e -- --platform=ceno --min-cycle-per-shard=10 --max-cycle-per-shard=1600 examples/target/riscv32im-ceno-zkvm-elf/release/examples/keccak_syscall + - name: Run secp256k1_add_syscall (release) env: RUSTFLAGS: "-C opt-level=3" diff --git a/Cargo.lock b/Cargo.lock index 64902df6f..f8e7a0523 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1904,7 +1904,7 @@ dependencies = [ [[package]] name = "ff_ext" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.12#5f6c787886163236c88c1f9c018aaeefb77e5801" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.13#89aa6add9f4d16cd2f10ec81f7c11d4507400c9b" dependencies = [ "once_cell", "p3", @@ -2565,7 +2565,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "07033963ba89ebaf1584d767badaa2e8fcec21aedea6b8c0346d487d49c28667" dependencies = [ "cfg-if", - "windows-targets 0.53.4", + "windows-targets 0.52.6", ] [[package]] @@ -2716,7 +2716,7 @@ dependencies = [ [[package]] name = "mpcs" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.12#5f6c787886163236c88c1f9c018aaeefb77e5801" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.13#89aa6add9f4d16cd2f10ec81f7c11d4507400c9b" dependencies = [ "bincode", "clap", @@ -2740,7 +2740,7 @@ dependencies = [ [[package]] name = "multilinear_extensions" version = "0.1.0" -source = 
"git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.12#5f6c787886163236c88c1f9c018aaeefb77e5801" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.13#89aa6add9f4d16cd2f10ec81f7c11d4507400c9b" dependencies = [ "either", "ff_ext", @@ -2972,7 +2972,7 @@ version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "77e878c846a8abae00dd069496dbe8751b16ac1c3d6bd2a7283a938e8228f90d" dependencies = [ - "proc-macro-crate 3.4.0", + "proc-macro-crate 1.3.1", "proc-macro2", "quote", "syn 2.0.101", @@ -3061,7 +3061,7 @@ dependencies = [ [[package]] name = "p3" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.12#5f6c787886163236c88c1f9c018aaeefb77e5801" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.13#89aa6add9f4d16cd2f10ec81f7c11d4507400c9b" dependencies = [ "p3-air", "p3-baby-bear", @@ -3498,7 +3498,7 @@ dependencies = [ [[package]] name = "poseidon" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.12#5f6c787886163236c88c1f9c018aaeefb77e5801" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.13#89aa6add9f4d16cd2f10ec81f7c11d4507400c9b" dependencies = [ "ff_ext", "p3", @@ -4482,7 +4482,7 @@ dependencies = [ [[package]] name = "sp1-curves" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.12#5f6c787886163236c88c1f9c018aaeefb77e5801" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.13#89aa6add9f4d16cd2f10ec81f7c11d4507400c9b" dependencies = [ "cfg-if", "dashu", @@ -4604,7 +4604,7 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "sumcheck" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.12#5f6c787886163236c88c1f9c018aaeefb77e5801" +source = 
"git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.13#89aa6add9f4d16cd2f10ec81f7c11d4507400c9b" dependencies = [ "either", "ff_ext", @@ -4622,7 +4622,7 @@ dependencies = [ [[package]] name = "sumcheck_macro" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.12#5f6c787886163236c88c1f9c018aaeefb77e5801" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.13#89aa6add9f4d16cd2f10ec81f7c11d4507400c9b" dependencies = [ "itertools 0.13.0", "p3", @@ -5017,7 +5017,7 @@ dependencies = [ [[package]] name = "transcript" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.12#5f6c787886163236c88c1f9c018aaeefb77e5801" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.13#89aa6add9f4d16cd2f10ec81f7c11d4507400c9b" dependencies = [ "ff_ext", "itertools 0.13.0", @@ -5289,7 +5289,7 @@ dependencies = [ [[package]] name = "whir" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.12#5f6c787886163236c88c1f9c018aaeefb77e5801" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.13#89aa6add9f4d16cd2f10ec81f7c11d4507400c9b" dependencies = [ "bincode", "clap", @@ -5576,7 +5576,7 @@ dependencies = [ [[package]] name = "witness" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.12#5f6c787886163236c88c1f9c018aaeefb77e5801" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.13#89aa6add9f4d16cd2f10ec81f7c11d4507400c9b" dependencies = [ "ff_ext", "multilinear_extensions", diff --git a/Cargo.toml b/Cargo.toml index 5db870151..abfa679ad 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,16 +23,16 @@ repository = "https://github.com/scroll-tech/ceno" version = "0.1.0" [workspace.dependencies] -ff_ext = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "ff_ext", tag = 
"v1.0.0-alpha.12" } -mpcs = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "mpcs", tag = "v1.0.0-alpha.12" } -multilinear_extensions = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "multilinear_extensions", tag = "v1.0.0-alpha.12" } -p3 = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "p3", tag = "v1.0.0-alpha.12" } -poseidon = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "poseidon", tag = "v1.0.0-alpha.12" } -sp1-curves = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "sp1-curves", tag = "v1.0.0-alpha.12" } -sumcheck = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "sumcheck", tag = "v1.0.0-alpha.12" } -transcript = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "transcript", tag = "v1.0.0-alpha.12" } -whir = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "whir", tag = "v1.0.0-alpha.12" } -witness = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "witness", tag = "v1.0.0-alpha.12" } +ff_ext = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "ff_ext", tag = "v1.0.0-alpha.13" } +mpcs = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "mpcs", tag = "v1.0.0-alpha.13" } +multilinear_extensions = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "multilinear_extensions", tag = "v1.0.0-alpha.13" } +p3 = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "p3", tag = "v1.0.0-alpha.13" } +poseidon = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "poseidon", tag = "v1.0.0-alpha.13" } +sp1-curves = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "sp1-curves", tag = "v1.0.0-alpha.13" } +sumcheck = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "sumcheck", tag = "v1.0.0-alpha.13" } +transcript = { git = "https://github.com/scroll-tech/gkr-backend.git", 
package = "transcript", tag = "v1.0.0-alpha.13" } +whir = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "whir", tag = "v1.0.0-alpha.13" } +witness = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "witness", tag = "v1.0.0-alpha.13" } alloy-primitives = "1.3" anyhow = { version = "1.0", default-features = false } diff --git a/ceno_cli/src/commands/common_args/ceno.rs b/ceno_cli/src/commands/common_args/ceno.rs index 5a0e8535a..0b4485ac6 100644 --- a/ceno_cli/src/commands/common_args/ceno.rs +++ b/ceno_cli/src/commands/common_args/ceno.rs @@ -78,13 +78,21 @@ pub struct CenoOptions { #[arg(long)] pub out_vk: Option, - /// shard id + /// prover id #[arg(long, default_value = "0")] - shard_id: u32, + prover_id: u32, - /// number of total shards. + /// number of available prover. #[arg(long, default_value = "1")] - max_num_shards: u32, + num_provers: u32, + + // min cycle per shard + #[arg(long, default_value = "16777216")] // 16777216 = 2^24 + min_cycle_per_shard: u64, + + // max cycle per shard + #[arg(long, default_value = "536870912")] // 536870912 = 2^29 + max_cycle_per_shard: u64, /// Profiling granularity. 
/// Setting any value restricts logs to profiling information @@ -345,7 +353,12 @@ fn run_elf_inner< std::fs::read(elf_path).context(format!("failed to read {}", elf_path.display()))?; let program = Program::load_elf(&elf_bytes, u32::MAX).context("failed to load elf")?; print_cargo_message("Loaded", format_args!("{}", elf_path.display())); - let shards = Shards::new(options.shard_id as usize, options.max_num_shards as usize); + let multi_prover = MultiProver::new( + options.prover_id as usize, + options.num_provers as usize, + options.min_cycle_per_shard, + options.max_cycle_per_shard, + ); let public_io = options .read_public_io() @@ -394,7 +407,7 @@ fn run_elf_inner< create_prover(backend.clone()), program, platform, - shards, + multi_prover, &hints, &public_io, options.max_steps, @@ -439,12 +452,12 @@ fn prove_inner< checkpoint: Checkpoint, ) -> anyhow::Result<()> { let result = run_elf_inner::(args, compilation_options, elf_path, checkpoint)?; - let zkvm_proof = result.proof.expect("PrepSanityCheck should yield proof."); + let zkvm_proofs = result.proofs.expect("PrepSanityCheck should yield proof."); let vk = result.vk.expect("PrepSanityCheck should yield vk."); let start = std::time::Instant::now(); let verifier = ZKVMVerifier::new(vk); - if let Err(e) = verify(&zkvm_proof, &verifier) { + if let Err(e) = verify(zkvm_proofs.clone(), &verifier) { bail!("Verification failed: {e:?}"); } print_cargo_message( @@ -457,7 +470,7 @@ fn prove_inner< print_cargo_message("Writing", format_args!("proof to {}", path.display())); let proof_file = File::create(&path).context(format!("failed to create {}", path.display()))?; - bincode::serialize_into(proof_file, &zkvm_proof) + bincode::serialize_into(proof_file, &zkvm_proofs) .context("failed to serialize zkvm proof")?; } if let Some(out_vk) = args.out_vk.as_ref() { diff --git a/ceno_cli/src/commands/verify.rs b/ceno_cli/src/commands/verify.rs index ec45032da..8a133b103 100644 --- a/ceno_cli/src/commands/verify.rs +++ 
b/ceno_cli/src/commands/verify.rs @@ -53,7 +53,7 @@ fn run_inner + Serialize>( ) -> anyhow::Result<()> { let start = std::time::Instant::now(); - let zkvm_proof: ZKVMProof = + let zkvm_proofs: Vec> = bincode::deserialize_from(File::open(&args.proof).context("Failed to open proof file")?) .context("Failed to deserialize proof file")?; print_cargo_message( @@ -80,7 +80,7 @@ fn run_inner + Serialize>( let start = std::time::Instant::now(); let verifier = ZKVMVerifier::new(vk); - if let Err(e) = verify(&zkvm_proof, &verifier) { + if let Err(e) = verify(zkvm_proofs, &verifier) { bail!("Verification failed: {e:?}"); } diff --git a/ceno_zkvm/benches/fibonacci.rs b/ceno_zkvm/benches/fibonacci.rs index 325c59f46..65bc83896 100644 --- a/ceno_zkvm/benches/fibonacci.rs +++ b/ceno_zkvm/benches/fibonacci.rs @@ -13,7 +13,7 @@ use criterion::*; use ff_ext::BabyBearExt4; use gkr_iop::cpu::default_backend_config; -use ceno_zkvm::{e2e::Shards, scheme::verifier::ZKVMVerifier}; +use ceno_zkvm::{e2e::MultiProver, scheme::verifier::ZKVMVerifier}; use mpcs::BasefoldDefault; use transcript::BasicTranscript; @@ -54,13 +54,16 @@ fn fibonacci_prove(c: &mut Criterion) { create_prover(backend.clone()), program.clone(), platform.clone(), - Shards::default(), + MultiProver::default(), &Vec::from(&hints), &[], max_steps, Checkpoint::Complete, ); - let proof = result.proof.expect("PrepSanityCheck do not provide proof"); + let proof = result + .proofs + .expect("PrepSanityCheck do not provide proof") + .remove(0); let vk = result.vk.expect("PrepSanityCheck do not provide verifier"); println!("e2e proof {}", proof); @@ -92,7 +95,7 @@ fn fibonacci_prove(c: &mut Criterion) { create_prover(backend.clone()), program.clone(), platform.clone(), - Shards::default(), + MultiProver::default(), &Vec::from(&hints), &[], max_steps, diff --git a/ceno_zkvm/benches/fibonacci_witness.rs b/ceno_zkvm/benches/fibonacci_witness.rs index d942743db..02b479da6 100644 --- a/ceno_zkvm/benches/fibonacci_witness.rs +++ 
b/ceno_zkvm/benches/fibonacci_witness.rs @@ -9,7 +9,7 @@ use std::{fs, path::PathBuf, time::Duration}; mod alloc; use criterion::*; -use ceno_zkvm::e2e::Shards; +use ceno_zkvm::e2e::MultiProver; use ff_ext::BabyBearExt4; use gkr_iop::cpu::default_backend_config; use mpcs::BasefoldDefault; @@ -66,7 +66,7 @@ fn fibonacci_witness(c: &mut Criterion) { create_prover(backend.clone()), program.clone(), platform.clone(), - Shards::default(), + MultiProver::default(), &Vec::from(&hints), &[], max_steps, diff --git a/ceno_zkvm/benches/is_prime.rs b/ceno_zkvm/benches/is_prime.rs index 6d66ff859..2620c7ecb 100644 --- a/ceno_zkvm/benches/is_prime.rs +++ b/ceno_zkvm/benches/is_prime.rs @@ -8,7 +8,7 @@ use ceno_zkvm::{ scheme::{create_backend, create_prover}, }; mod alloc; -use ceno_zkvm::e2e::Shards; +use ceno_zkvm::e2e::MultiProver; use criterion::*; use ff_ext::BabyBearExt4; use gkr_iop::cpu::default_backend_config; @@ -63,7 +63,7 @@ fn is_prime_1(c: &mut Criterion) { create_prover(backend.clone()), program.clone(), platform.clone(), - Shards::default(), + MultiProver::default(), &hints, &[], max_steps, diff --git a/ceno_zkvm/benches/keccak.rs b/ceno_zkvm/benches/keccak.rs index 19011d460..9ab4eed1a 100644 --- a/ceno_zkvm/benches/keccak.rs +++ b/ceno_zkvm/benches/keccak.rs @@ -8,7 +8,7 @@ use ceno_zkvm::{ scheme::{create_backend, create_prover}, }; mod alloc; -use ceno_zkvm::{e2e::Shards, scheme::verifier::ZKVMVerifier}; +use ceno_zkvm::{e2e::MultiProver, scheme::verifier::ZKVMVerifier}; use criterion::*; use ff_ext::BabyBearExt4; use gkr_iop::cpu::default_backend_config; @@ -51,13 +51,16 @@ fn keccak_prove(c: &mut Criterion) { create_prover(backend.clone()), program.clone(), platform.clone(), - Shards::default(), + MultiProver::default(), &Vec::from(&hints), &[], max_steps, Checkpoint::Complete, ); - let proof = result.proof.expect("PrepSanityCheck do not provide proof"); + let proof = result + .proofs + .expect("PrepSanityCheck do not provide proof") + .remove(0); let vk = 
result.vk.expect("PrepSanityCheck do not provide verifier"); println!("e2e proof {}", proof); @@ -86,7 +89,7 @@ fn keccak_prove(c: &mut Criterion) { create_prover(backend.clone()), program.clone(), platform.clone(), - Shards::default(), + MultiProver::default(), &Vec::from(&hints), &[], max_steps, diff --git a/ceno_zkvm/benches/quadratic_sorting.rs b/ceno_zkvm/benches/quadratic_sorting.rs index 93389c388..b4d0a66f5 100644 --- a/ceno_zkvm/benches/quadratic_sorting.rs +++ b/ceno_zkvm/benches/quadratic_sorting.rs @@ -8,7 +8,7 @@ use ceno_zkvm::{ scheme::{create_backend, create_prover}, }; mod alloc; -use ceno_zkvm::e2e::Shards; +use ceno_zkvm::e2e::MultiProver; use criterion::*; use ff_ext::BabyBearExt4; use gkr_iop::cpu::default_backend_config; @@ -64,7 +64,7 @@ fn quadratic_sorting_1(c: &mut Criterion) { create_prover(backend.clone()), program.clone(), platform.clone(), - Shards::default(), + MultiProver::default(), &hints, &[], max_steps, diff --git a/ceno_zkvm/benches/riscv_add.rs b/ceno_zkvm/benches/riscv_add.rs index 9d8cc22e8..20b7da4cd 100644 --- a/ceno_zkvm/benches/riscv_add.rs +++ b/ceno_zkvm/benches/riscv_add.rs @@ -55,7 +55,7 @@ fn bench_add(c: &mut Criterion) { let pk = zkvm_cs .clone() - .key_gen::(pp, vp, zkvm_fixed_traces) + .key_gen::(pp, vp, 0, zkvm_fixed_traces) .expect("keygen failed"); let (max_num_variables, security_level) = default_backend_config(); @@ -111,6 +111,7 @@ fn bench_add(c: &mut Criterion) { witness: polys, structural_witness: vec![], public_input: vec![], + pub_io_evals: vec![], num_instances: vec![num_instances], has_ecc_ops: false, }; diff --git a/ceno_zkvm/src/bin/e2e.rs b/ceno_zkvm/src/bin/e2e.rs index b9d47d09e..e59db55ff 100644 --- a/ceno_zkvm/src/bin/e2e.rs +++ b/ceno_zkvm/src/bin/e2e.rs @@ -4,8 +4,8 @@ use ceno_host::{CenoStdin, memory_from_file}; use ceno_zkvm::print_allocated_bytes; use ceno_zkvm::{ e2e::{ - Checkpoint, FieldType, PcsKind, Preset, Shards, run_e2e_with_checkpoint, setup_platform, - setup_platform_debug, 
verify, + Checkpoint, FieldType, MultiProver, PcsKind, Preset, run_e2e_with_checkpoint, + setup_platform, setup_platform_debug, verify, }, scheme::{ ZKVMProof, constants::MAX_NUM_VARIABLES, create_backend, create_prover, hal::ProverDevice, @@ -109,13 +109,21 @@ struct Args { #[arg(short, long, value_enum, default_value_t = SecurityLevel::default())] security_level: SecurityLevel, - // shard id + // prover id #[arg(long, default_value = "0")] - shard_id: u32, + prover_id: u32, - // number of total shards + // number of available prover. #[arg(long, default_value = "1")] - max_num_shards: u32, + num_provers: u32, + + // min cycle per shard + #[arg(long, default_value = "16777216")] // 16777216 = 2^24 + min_cycle_per_shard: u64, + + // max cycle per shard + #[arg(long, default_value = "536870912")] // 536870912 = 2^29 + max_cycle_per_shard: u64, } fn main() { @@ -248,7 +256,12 @@ fn main() { .unwrap_or_default(); let max_steps = args.max_steps.unwrap_or(usize::MAX); - let shards = Shards::new(args.shard_id as usize, args.max_num_shards as usize); + let multi_prover = MultiProver::new( + args.prover_id as usize, + args.num_provers as usize, + args.min_cycle_per_shard, + args.max_cycle_per_shard, + ); match (args.pcs, args.field) { (PcsKind::Basefold, FieldType::Goldilocks) => { @@ -258,7 +271,7 @@ fn main() { prover, program, platform, - shards, + multi_prover, &hints, &public_io, max_steps, @@ -274,7 +287,7 @@ fn main() { prover, program, platform, - shards, + multi_prover, &hints, &public_io, max_steps, @@ -290,7 +303,7 @@ fn main() { prover, program, platform, - shards, + multi_prover, &hints, &public_io, max_steps, @@ -306,7 +319,7 @@ fn main() { prover, program, platform, - shards, + multi_prover, &hints, &public_io, max_steps, @@ -333,7 +346,7 @@ fn run_inner< pd: PD, program: Program, platform: Platform, - shards: Shards, + multi_prover: MultiProver, hints: &[u32], public_io: &[u32], max_steps: usize, @@ -342,23 +355,30 @@ fn run_inner< checkpoint: Checkpoint, ) 
{ let result = run_e2e_with_checkpoint::( - pd, program, platform, shards, hints, public_io, max_steps, checkpoint, + pd, + program, + platform, + multi_prover, + hints, + public_io, + max_steps, + checkpoint, ); - let zkvm_proof = result - .proof + let zkvm_proofs = result + .proofs .expect("PrepSanityCheck should yield zkvm_proof."); let vk = result.vk.expect("PrepSanityCheck should yield vk."); - let proof_bytes = bincode::serialize(&zkvm_proof).unwrap(); + let proof_bytes = bincode::serialize(&zkvm_proofs).unwrap(); fs::write(&proof_file, proof_bytes).unwrap(); let vk_bytes = bincode::serialize(&vk).unwrap(); fs::write(&vk_file, vk_bytes).unwrap(); if checkpoint > Checkpoint::PrepVerify { let verifier = ZKVMVerifier::new(vk); - verify(&zkvm_proof, &verifier).expect("Verification failed"); - soundness_test(zkvm_proof, &verifier); + verify(zkvm_proofs.clone(), &verifier).expect("Verification failed"); + soundness_test(zkvm_proofs.first().cloned().unwrap(), &verifier); } } diff --git a/ceno_zkvm/src/chip_handler/general.rs b/ceno_zkvm/src/chip_handler/general.rs index 9ea76595a..7822a9d5d 100644 --- a/ceno_zkvm/src/chip_handler/general.rs +++ b/ceno_zkvm/src/chip_handler/general.rs @@ -4,8 +4,8 @@ use gkr_iop::{error::CircuitBuilderError, tables::LookupTable}; use crate::{ circuit_builder::CircuitBuilder, instructions::riscv::constants::{ - END_CYCLE_IDX, END_PC_IDX, END_SHARD_ID_IDX, EXIT_CODE_IDX, GLOBAL_RW_SUM_IDX, - INIT_CYCLE_IDX, INIT_PC_IDX, PUBLIC_IO_IDX, UINT_LIMBS, + END_CYCLE_IDX, END_PC_IDX, EXIT_CODE_IDX, GLOBAL_RW_SUM_IDX, INIT_CYCLE_IDX, INIT_PC_IDX, + PUBLIC_IO_IDX, SHARD_ID_IDX, UINT_LIMBS, }, scheme::constants::SEPTIC_EXTENSION_DEGREE, tables::InsnRecord, @@ -42,53 +42,46 @@ impl<'a, E: ExtensionField> InstFetch for CircuitBuilder<'a, E> { impl<'a, E: ExtensionField> PublicIOQuery for CircuitBuilder<'a, E> { fn query_exit_code(&mut self) -> Result<[Instance; UINT_LIMBS], CircuitBuilderError> { Ok([ - self.cs.query_instance(|| "exit_code_low", 
EXIT_CODE_IDX)?, - self.cs - .query_instance(|| "exit_code_high", EXIT_CODE_IDX + 1)?, + self.cs.query_instance(EXIT_CODE_IDX)?, + self.cs.query_instance(EXIT_CODE_IDX + 1)?, ]) } fn query_init_pc(&mut self) -> Result { - self.cs.query_instance(|| "init_pc", INIT_PC_IDX) + self.cs.query_instance(INIT_PC_IDX) } fn query_init_cycle(&mut self) -> Result { - self.cs.query_instance(|| "init_cycle", INIT_CYCLE_IDX) + self.cs.query_instance(INIT_CYCLE_IDX) } fn query_end_pc(&mut self) -> Result { - self.cs.query_instance(|| "end_pc", END_PC_IDX) + self.cs.query_instance(END_PC_IDX) } fn query_end_cycle(&mut self) -> Result { - self.cs.query_instance(|| "end_cycle", END_CYCLE_IDX) + self.cs.query_instance(END_CYCLE_IDX) } fn query_shard_id(&mut self) -> Result { - self.cs.query_instance(|| "shard_id", END_SHARD_ID_IDX) + self.cs.query_instance(SHARD_ID_IDX) } fn query_public_io(&mut self) -> Result<[Instance; UINT_LIMBS], CircuitBuilderError> { Ok([ - self.cs.query_instance(|| "public_io_low", PUBLIC_IO_IDX)?, - self.cs - .query_instance(|| "public_io_high", PUBLIC_IO_IDX + 1)?, + self.cs.query_instance_for_openings(PUBLIC_IO_IDX)?, + self.cs.query_instance_for_openings(PUBLIC_IO_IDX + 1)?, ]) } fn query_global_rw_sum(&mut self) -> Result, CircuitBuilderError> { let x = (0..SEPTIC_EXTENSION_DEGREE) - .map(|i| { - self.cs - .query_instance(|| format!("global_rw_sum_x_{}", i), GLOBAL_RW_SUM_IDX + i) - }) + .map(|i| self.cs.query_instance(GLOBAL_RW_SUM_IDX + i)) .collect::, CircuitBuilderError>>()?; let y = (0..SEPTIC_EXTENSION_DEGREE) .map(|i| { - self.cs.query_instance( - || format!("global_rw_sum_y_{}", i), - GLOBAL_RW_SUM_IDX + SEPTIC_EXTENSION_DEGREE + i, - ) + self.cs + .query_instance(GLOBAL_RW_SUM_IDX + SEPTIC_EXTENSION_DEGREE + i) }) .collect::, CircuitBuilderError>>()?; diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index 62f3e425f..75ad40e7a 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -1,6 +1,9 @@ use crate::{ error::ZKVMError, - 
instructions::riscv::{DummyExtraConfig, MemPadder, MmuConfig, Rv32imConfig}, + instructions::{ + global::GlobalChip, + riscv::{DummyExtraConfig, MemPadder, MmuConfig, Rv32imConfig}, + }, scheme::{ PublicValues, ZKVMProof, constants::SEPTIC_EXTENSION_DEGREE, @@ -14,7 +17,9 @@ use crate::{ ProgramParams, ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMProvingKey, ZKVMVerifyingKey, ZKVMWitnesses, }, - tables::{MemFinalRecord, MemInitRecord, ProgramTableCircuit, ProgramTableConfig}, + tables::{ + MemFinalRecord, MemInitRecord, ProgramTableCircuit, ProgramTableConfig, TableCircuit, + }, }; use ceno_emul::{ Addr, ByteAddr, CENO_PLATFORM, Cycle, EmuContext, InsnKind, IterAddresses, NextCycleAccess, @@ -23,7 +28,7 @@ use ceno_emul::{ }; use clap::ValueEnum; use either::Either; -use ff_ext::ExtensionField; +use ff_ext::{ExtensionField, SmallField}; #[cfg(debug_assertions)] use ff_ext::{Instrumented, PoseidonField}; use gkr_iop::{RAMType, hal::ProverBackend}; @@ -31,15 +36,18 @@ use itertools::{Itertools, MinMaxResult, chain}; use mpcs::{PolynomialCommitmentScheme, SecurityLevel}; use multilinear_extensions::util::max_usable_threads; use rayon::iter::{IntoParallelIterator, ParallelIterator}; +use rustc_hash::FxHashSet; use serde::Serialize; use std::{ - borrow::Cow, collections::{BTreeMap, BTreeSet, HashMap, HashSet}, sync::Arc, }; use transcript::BasicTranscript as Transcript; use witness::next_pow2_instance_padding; +pub const DEFAULT_MIN_CYCLE_PER_SHARDS: Cycle = 1 << 24; +pub const DEFAULT_MAX_CYCLE_PER_SHARDS: Cycle = 1 << 27; + /// The polynomial commitment scheme kind #[derive( Default, @@ -99,7 +107,7 @@ pub struct EmulationResult<'a> { pub all_records: Vec, pub final_mem_state: FinalMemState, pub pi: PublicValues, - pub shard_ctx: ShardContext<'a>, + pub shard_ctxs: Vec>, } pub struct RAMRecord { @@ -119,43 +127,50 @@ pub struct RAMRecord { } #[derive(Clone, Debug)] -pub struct Shards { - pub shard_id: usize, - pub max_num_shards: usize, +pub struct MultiProver { + pub 
prover_id: usize, + pub max_provers: usize, + pub min_cycle_per_shard: Cycle, + pub max_cycle_per_shard: Cycle, } -impl Shards { - pub fn new(shard_id: usize, max_num_shards: usize) -> Self { - assert!(shard_id < max_num_shards); +impl MultiProver { + pub fn new( + prover_id: usize, + max_provers: usize, + min_cycle_per_shard: Cycle, + max_cycle_per_shard: Cycle, + ) -> Self { + assert!(prover_id < max_provers); Self { - shard_id, - max_num_shards, + prover_id, + max_provers, + min_cycle_per_shard, + max_cycle_per_shard, } } - - pub fn is_first_shard(&self) -> bool { - self.shard_id == 0 - } - - pub fn is_last_shard(&self) -> bool { - self.shard_id == self.max_num_shards - 1 - } } -impl Default for Shards { +impl Default for MultiProver { fn default() -> Self { Self { - shard_id: 0, - max_num_shards: 1, + prover_id: 0, + max_provers: 1, + min_cycle_per_shard: DEFAULT_MIN_CYCLE_PER_SHARDS, + max_cycle_per_shard: DEFAULT_MAX_CYCLE_PER_SHARDS, } } } pub struct ShardContext<'a> { - shards: Shards, + shard_id: usize, + num_shards: usize, max_cycle: Cycle, // TODO optimize this map as it's super huge - addr_future_accesses: Cow<'a, NextCycleAccess>, + addr_future_accesses: Arc, + // this is only updated in first shard + addr_accessed_thread_based_first_shard: + Either>, &'a mut FxHashSet>, read_thread_based_record_storage: Either>, &'a mut BTreeMap>, write_thread_based_record_storage: @@ -168,9 +183,16 @@ impl<'a> Default for ShardContext<'a> { fn default() -> Self { let max_threads = max_usable_threads(); Self { - shards: Shards::default(), + shard_id: 0, + num_shards: 1, max_cycle: Cycle::default(), - addr_future_accesses: Cow::Owned(Default::default()), + addr_future_accesses: Arc::new(Default::default()), + addr_accessed_thread_based_first_shard: Either::Left( + (0..max_threads) + .into_par_iter() + .map(|_| Default::default()) + .collect::>(), + ), read_thread_based_record_storage: Either::Left( (0..max_threads) .into_par_iter() @@ -189,72 +211,160 @@ impl<'a> 
Default for ShardContext<'a> { } } +/// `prover_id` and `num_provers` in MultiProver are exposed as arguments +/// to specify the number of physical provers in a cluster, +/// each mark with a prover_id. +/// The overall trace data is divided into shards, which are distributed evenly among the provers. +/// The number of shards are in general agnostic to number of provers. +/// Each prover is assigned n shard where n can be even empty +/// +/// Shard distribution follows a balanced allocation strategy +/// for example, if there are 10 shards and 3 provers, +/// the shard counts will be distributed as 3, 3, and 4, ensuring an even workload across all provers. impl<'a> ShardContext<'a> { pub fn new( - shards: Shards, + multi_prover: MultiProver, executed_instructions: usize, addr_future_accesses: NextCycleAccess, - ) -> Self { - // current strategy: at least each shard deal with one instruction - let max_num_shards = shards.max_num_shards.min(executed_instructions); + ) -> Vec { + let min_cycle_per_shard = multi_prover.min_cycle_per_shard; + let max_cycle_per_shard = multi_prover.max_cycle_per_shard; assert!( - shards.shard_id < max_num_shards, - "implement mechanism to skip current shard proof" + min_cycle_per_shard < max_cycle_per_shard, + "invalid input: min_cycle_per_shard {min_cycle_per_shard} >= max_cycle_per_shard {max_cycle_per_shard}" ); - let subcycle_per_insn = Tracer::SUBCYCLES_PER_INSN as usize; let max_threads = max_usable_threads(); - let expected_inst_per_shard = executed_instructions.div_ceil(max_num_shards); - let max_cycle = (executed_instructions + 1) * subcycle_per_insn; // cycle start from subcycle_per_insn - let cur_shard_cycle_range = (shards.shard_id * expected_inst_per_shard * subcycle_per_insn - + subcycle_per_insn) - ..((shards.shard_id + 1) * expected_inst_per_shard * subcycle_per_insn - + subcycle_per_insn) - .min(max_cycle); - - ShardContext { - shards, - max_cycle: max_cycle as Cycle, - addr_future_accesses: 
Cow::Owned(addr_future_accesses), - // TODO with_capacity optimisation - read_thread_based_record_storage: Either::Left( - (0..max_threads) - .into_par_iter() - .map(|_| BTreeMap::new()) - .collect::>(), - ), - // TODO with_capacity optimisation - write_thread_based_record_storage: Either::Left( - (0..max_threads) - .into_par_iter() - .map(|_| BTreeMap::new()) - .collect::>(), - ), - cur_shard_cycle_range, - expected_inst_per_shard, + + // strategies + // 0. set cur_num_shards = num_provers + // 1. split instructions evenly by cur_num_shards + // 2. stop if min_inst <= shard instructions < max_inst + // 3.1 if shard instructions >= max_inst, update cur_num_shards += 1 then goes to 1 + // 3.2 if shard instructions < min_inst, update cur_num_shards -= 1 then goes to 1 + const MAX_ITER: usize = 1000; + let mut num_shards = multi_prover.max_provers; + let mut last_shard_count = None; + let mut expected_inst_per_shard = 0; + for _ in 0..MAX_ITER { + expected_inst_per_shard = executed_instructions.div_ceil(num_shards); + let expected_cycle_per_shard = expected_inst_per_shard * subcycle_per_insn; + if (min_cycle_per_shard as usize..max_cycle_per_shard as usize) + .contains(&expected_cycle_per_shard) + { + break; + } + + if expected_cycle_per_shard >= max_cycle_per_shard as usize { + num_shards += 1; + } else if expected_cycle_per_shard < min_cycle_per_shard as usize { + if num_shards == 1 { + break; + } + num_shards -= 1; + } + + // Detect oscillation (no progress) + if let Some(last_shard_count) = last_shard_count + && last_shard_count == num_shards + { + panic!( + "no convergence detected: shard count stuck at {num_shards}, \ + per-shard={expected_inst_per_shard}" + ); + } + + last_shard_count = Some(num_shards); } + + // generated shards belong to this prover id + let prover_id_shards_mapping = + Self::distribute_shards_into_provers(num_shards, multi_prover.max_provers); + assert!(multi_prover.prover_id < prover_id_shards_mapping.len()); + + let max_cycle = 
(executed_instructions + 1) * subcycle_per_insn; // cycle start from subcycle_per_insn + let addr_future_accesses = Arc::new(addr_future_accesses); + + // sum for all shards before prover id + let start = prover_id_shards_mapping + .iter() + .take(multi_prover.prover_id) + .sum::(); + // length of shards belong to prover id + let shard_len = prover_id_shards_mapping[multi_prover.prover_id]; + tracing::info!( + "total num_shards {num_shards}, num_shards belong to this prover: {shard_len}, multi-prover {:?}", + multi_prover + ); + let end = start + shard_len; + (start..end) + .map(|shard_id| { + let cur_shard_cycle_range = (shard_id * expected_inst_per_shard * subcycle_per_insn + + subcycle_per_insn) + ..((shard_id + 1) * expected_inst_per_shard * subcycle_per_insn + + subcycle_per_insn) + .min(max_cycle); + ShardContext { + shard_id, + num_shards, + max_cycle: max_cycle as Cycle, + addr_future_accesses: addr_future_accesses.clone(), + addr_accessed_thread_based_first_shard: Either::Left( + (0..max_threads) + .into_par_iter() + .map(|_| Default::default()) + .collect::>(), + ), + // TODO with_capacity optimisation + read_thread_based_record_storage: Either::Left( + (0..max_threads) + .into_par_iter() + .map(|_| BTreeMap::new()) + .collect::>(), + ), + // TODO with_capacity optimisation + write_thread_based_record_storage: Either::Left( + (0..max_threads) + .into_par_iter() + .map(|_| BTreeMap::new()) + .collect::>(), + ), + cur_shard_cycle_range, + expected_inst_per_shard, + } + }) + .collect_vec() } pub fn get_forked(&mut self) -> Vec> { match ( &mut self.read_thread_based_record_storage, &mut self.write_thread_based_record_storage, + &mut self.addr_accessed_thread_based_first_shard, ) { ( Either::Left(read_thread_based_record_storage), Either::Left(write_thread_based_record_storage), + Either::Left(addr_accessed_thread_based_first_shard), ) => read_thread_based_record_storage .iter_mut() .zip(write_thread_based_record_storage.iter_mut()) - .map(|(read, write)| 
ShardContext { - shards: self.shards.clone(), - max_cycle: self.max_cycle, - addr_future_accesses: Cow::Borrowed(self.addr_future_accesses.as_ref()), - read_thread_based_record_storage: Either::Right(read), - write_thread_based_record_storage: Either::Right(write), - cur_shard_cycle_range: self.cur_shard_cycle_range.clone(), - expected_inst_per_shard: self.expected_inst_per_shard, - }) + .zip(addr_accessed_thread_based_first_shard.iter_mut()) + .map( + |((read, write), addr_accessed_thread_based_first_shard)| ShardContext { + shard_id: self.shard_id, + num_shards: self.num_shards, + max_cycle: self.max_cycle, + addr_future_accesses: self.addr_future_accesses.clone(), + addr_accessed_thread_based_first_shard: Either::Right( + addr_accessed_thread_based_first_shard, + ), + read_thread_based_record_storage: Either::Right(read), + write_thread_based_record_storage: Either::Right(write), + cur_shard_cycle_range: self.cur_shard_cycle_range.clone(), + expected_inst_per_shard: self.expected_inst_per_shard, + }, + ) .collect_vec(), _ => panic!("invalid type"), } @@ -276,12 +386,12 @@ impl<'a> ShardContext<'a> { #[inline(always)] pub fn is_first_shard(&self) -> bool { - self.shards.shard_id == 0 + self.shard_id == 0 } #[inline(always)] pub fn is_last_shard(&self) -> bool { - self.shards.shard_id == self.shards.max_num_shards - 1 + self.shard_id == self.num_shards - 1 } #[inline(always)] @@ -289,6 +399,16 @@ impl<'a> ShardContext<'a> { self.cur_shard_cycle_range.contains(&(cycle as usize)) } + #[inline(always)] + pub fn before_current_shard_cycle(&self, cycle: Cycle) -> bool { + (cycle as usize) < self.cur_shard_cycle_range.start + } + + #[inline(always)] + pub fn after_current_shard_cycle(&self, cycle: Cycle) -> bool { + (cycle as usize) >= self.cur_shard_cycle_range.end + } + #[inline(always)] pub fn extract_prev_shard_id(&self, cycle: Cycle) -> usize { let subcycle_per_insn = Tracer::SUBCYCLES_PER_INSN; @@ -316,6 +436,23 @@ impl<'a> ShardContext<'a> { 
(self.cur_shard_cycle_range.start as Cycle) - Tracer::SUBCYCLES_PER_INSN } + #[inline(always)] + pub fn find_future_next_access(&self, cycle: Cycle, addr: WordAddr) -> Option { + self.addr_future_accesses + .get(cycle as usize) + .and_then(|res| { + if res.len() == 1 { + Some(res[0].1) + } else if res.len() > 1 { + res.iter() + .find(|(m_addr, _)| *m_addr == addr) + .map(|(_, cycle)| *cycle) + } else { + None + } + }) + } + #[inline(always)] #[allow(clippy::too_many_arguments)] pub fn send( @@ -330,7 +467,7 @@ impl<'a> ShardContext<'a> { ) { // check read from external mem bus // exclude first shard - if prev_cycle < self.cur_shard_cycle_range.start as Cycle + if self.before_current_shard_cycle(prev_cycle) && self.is_current_shard_cycle(cycle) && !self.is_first_shard() { @@ -357,21 +494,8 @@ impl<'a> ShardContext<'a> { } // check write to external mem bus - if let Some(future_touch_cycle) = - self.addr_future_accesses - .get(cycle as usize) - .and_then(|res| { - if res.len() == 1 { - Some(res[0].1) - } else if res.len() > 1 { - res.iter() - .find(|(m_addr, _)| *m_addr == addr) - .map(|(_, cycle)| *cycle) - } else { - None - } - }) - && future_touch_cycle >= self.cur_shard_cycle_range.end as Cycle + if let Some(future_touch_cycle) = self.find_future_next_access(cycle, addr) + && self.after_current_shard_cycle(future_touch_cycle) && self.is_current_shard_cycle(cycle) { let shard_cycle = self.aligned_current_ts(cycle); @@ -391,10 +515,93 @@ impl<'a> ShardContext<'a> { shard_cycle, prev_value, value, - shard_id: self.shards.shard_id, + shard_id: self.shard_id, }, ); } + + if self.is_first_shard() { + let addr_accessed = self + .addr_accessed_thread_based_first_shard + .as_mut() + .right() + .expect("illegal type"); + addr_accessed.insert(addr); + } + } + + /// merge map from different thread, which keep the largest cycle when matched same address + pub fn get_addr_accessed_first_shard(&self) -> FxHashSet { + let mut merged = FxHashSet::default(); + let 
addr_accessed_thread_based_first_shard = + match &self.addr_accessed_thread_based_first_shard { + Either::Left(addr_accessed_thread_based_first_shard) => { + addr_accessed_thread_based_first_shard + } + Either::Right(_) => panic!("invalid type"), + }; + + for s in addr_accessed_thread_based_first_shard { + merged.extend(s); + } + merged + } + + /// Splits a total count `num_shards` into up to `num_provers` non-empty parts, distributing as evenly as possible. + /// + /// # Behavior + /// + /// - If `num_shards == 0` or `num_provers == 0`, returns an empty vector `[]`. + /// - If `num_shards <= num_provers`, each part will have size `1`, and the total number of parts equals `num_shards`. + /// - Otherwise, divides `num_shards` evenly across `num_provers` parts so that: + /// - The first `num_shards % num_provers` parts get `base + 1` elements, + /// - The rest get `base` elements, + /// where `base = num_shards / num_provers`. + /// + /// This ensures that: + /// - Every part is non-zero in size. + /// - The sum of all parts equals `num_shards`. + /// - The distribution is as balanced as possible (difference <= 1). + /// + /// # Examples + /// + /// ``` + /// # fn main() { + /// use ceno_zkvm::e2e::ShardContext; + /// assert_eq!(ShardContext::distribute_shards_into_provers(3, 2), vec![2, 1]); + /// assert_eq!(ShardContext::distribute_shards_into_provers(4, 2), vec![2, 2]); + /// assert_eq!(ShardContext::distribute_shards_into_provers(5, 2), vec![3, 2]); + /// assert_eq!(ShardContext::distribute_shards_into_provers(10, 3), vec![4, 3, 3]); + /// + /// // When n <= m, each item gets its own shard. 
+ /// assert_eq!(ShardContext::distribute_shards_into_provers(1, 2), vec![1]); + /// assert_eq!(ShardContext::distribute_shards_into_provers(2, 3), vec![1, 1]); + /// assert_eq!(ShardContext::distribute_shards_into_provers(3, 4), vec![1, 1, 1]); + /// + /// // Edge cases + /// assert_eq!(ShardContext::distribute_shards_into_provers(0, 3), Vec::::new()); + /// assert_eq!(ShardContext::distribute_shards_into_provers(5, 0), Vec::::new()); + /// # } + /// ``` + /// # Returns + /// + /// A `Vec` representing the size of each part, whose total sum equals `n`. + pub fn distribute_shards_into_provers(num_shards: usize, num_provers: usize) -> Vec { + if num_shards == 0 || num_provers == 0 { + return vec![]; + } + + // If there are more shards than items, just give each item its own shard + if num_shards <= num_provers { + return vec![1; num_shards]; + } + + let base = num_shards / num_provers; + let remainder = num_shards % num_provers; + + (0..num_provers) + .map(|i| if i < remainder { base + 1 } else { base }) + .collect() } } @@ -403,7 +610,7 @@ pub fn emulate_program<'a>( max_steps: usize, init_mem_state: &InitMemState, platform: &Platform, - shards: &Shards, + multi_prover: &MultiProver, ) -> EmulationResult<'a> { let InitMemState { mem: mem_init, @@ -461,7 +668,7 @@ pub fn emulate_program<'a>( Tracer::SUBCYCLES_PER_INSN, vm.get_pc().into(), end_cycle, - shards.shard_id as u32, + multi_prover.prover_id as u32, io_init.iter().map(|rec| rec.value).collect_vec(), vec![0; SEPTIC_EXTENSION_DEGREE * 2], // point_at_infinity ); @@ -477,6 +684,7 @@ pub fn emulate_program<'a>( ram_type: RAMType::Register, addr: rec.addr, value: vm.peek_register(index), + init_value: rec.value, cycle: *final_access.get(&vma).unwrap_or(&0), } } else { @@ -485,6 +693,7 @@ pub fn emulate_program<'a>( ram_type: RAMType::Register, addr: rec.addr, value: 0, + init_value: 0, cycle: 0, } } @@ -500,6 +709,7 @@ pub fn emulate_program<'a>( ram_type: RAMType::Memory, addr: rec.addr, value: 
vm.peek_memory(vma), + init_value: rec.value, cycle: *final_access.get(&vma).unwrap_or(&0), } }) @@ -512,6 +722,7 @@ pub fn emulate_program<'a>( ram_type: RAMType::Memory, addr: rec.addr, value: rec.value, + init_value: rec.value, cycle: *final_access.get(&rec.addr.into()).unwrap_or(&0), }) .collect_vec(); @@ -523,6 +734,7 @@ pub fn emulate_program<'a>( ram_type: RAMType::Memory, addr: rec.addr, value: rec.value, + init_value: rec.value, cycle: *final_access.get(&rec.addr.into()).unwrap_or(&0), }) .collect_vec(); @@ -541,6 +753,7 @@ pub fn emulate_program<'a>( ram_type: RAMType::Memory, addr: byte_addr.0, value: vm.peek_memory(vma), + init_value: 0, cycle: *final_access.get(&vma).unwrap_or(&0), } }) @@ -565,6 +778,7 @@ pub fn emulate_program<'a>( ram_type: RAMType::Memory, addr: byte_addr.0, value: vm.peek_memory(vma), + init_value: 0, cycle: *final_access.get(&vma).unwrap_or(&0), } }) @@ -584,13 +798,17 @@ pub fn emulate_program<'a>( ), ); - let shard_ctx = ShardContext::new(shards.clone(), insts, vm.take_tracer().next_accesses()); + let shard_ctxs = ShardContext::new( + multi_prover.clone(), + insts, + vm.take_tracer().next_accesses(), + ); EmulationResult { pi, exit_code, all_records, - shard_ctx, + shard_ctxs, final_mem_state: FinalMemState { reg: reg_final, io: io_final, @@ -765,67 +983,165 @@ pub fn generate_fixed_traces( zkvm_fixed_traces } -pub fn generate_witness( +pub fn generate_witness<'a, E: ExtensionField>( system_config: &ConstraintSystemConfig, - mut emul_result: EmulationResult, + mut emul_result: EmulationResult<'a>, program: &Program, -) -> ZKVMWitnesses { - let mut zkvm_witness = ZKVMWitnesses::default(); - // assign opcode circuits - let dummy_records = system_config - .config - .assign_opcode_circuit( - &system_config.zkvm_cs, - &mut emul_result.shard_ctx, - &mut zkvm_witness, - emul_result.all_records, - ) - .unwrap(); - system_config - .dummy_config - .assign_opcode_circuit( - &system_config.zkvm_cs, - &mut emul_result.shard_ctx, - &mut 
zkvm_witness, - dummy_records, - ) - .unwrap(); - zkvm_witness.finalize_lk_multiplicities(); - - // assign table circuits - system_config - .config - .assign_table_circuit(&system_config.zkvm_cs, &mut zkvm_witness) - .unwrap(); - system_config - .mmu_config - .assign_table_circuit( - &system_config.zkvm_cs, - &emul_result.shard_ctx, - &mut zkvm_witness, - &emul_result.final_mem_state.reg, - &emul_result.final_mem_state.mem, - &emul_result - .final_mem_state - .io - .iter() - .map(|rec| rec.cycle) - .collect_vec(), - &emul_result.final_mem_state.hints, - &emul_result.final_mem_state.stack, - &emul_result.final_mem_state.heap, - ) - .unwrap(); - // assign program circuit - zkvm_witness - .assign_table_circuit::>( - &system_config.zkvm_cs, - &system_config.prog_config, - program, - ) - .unwrap(); +) -> impl Iterator, ShardContext<'a>, PublicValues)> { + let shard_ctxs = std::mem::take(&mut emul_result.shard_ctxs); + assert!(!shard_ctxs.is_empty()); + let mut all_records = std::mem::take(&mut emul_result.all_records); + assert!(!all_records.is_empty()); + + tracing::debug!( + "first shard cycle range {:?}", + shard_ctxs[0].cur_shard_cycle_range + ); + // clean up all records before first shard start cycle, as it's not belong to current prover + let start = all_records.iter().position(|step| { + shard_ctxs[0] + .cur_shard_cycle_range + .contains(&(step.cycle() as usize)) + }); + + if let Some(start) = start { + tracing::debug!("drop {} records as not belong to current shard", start); + // Drop everything before `start` efficiently + let tail = all_records.split_off(start); + all_records = tail; + } - zkvm_witness + let pi = std::mem::take(&mut emul_result.pi); + shard_ctxs.into_iter().map(move |mut shard_ctx| { + // assume public io clone low cost + let mut pi = pi.clone(); + let n = all_records + .iter() + .take_while(|step| shard_ctx.is_current_shard_cycle(step.cycle())) + .count(); + let mut filtered_steps = all_records.split_off(n); // moves pointer boundary, no mem 
shift + std::mem::swap(&mut all_records, &mut filtered_steps); + + tracing::debug!("{}th shard collect {n} steps", shard_ctx.shard_id); + let current_shard_offset_cycle = shard_ctx.current_shard_offset_cycle(); + let current_shard_end_cycle = filtered_steps.last().unwrap().cycle() + + Tracer::SUBCYCLES_PER_INSN + - current_shard_offset_cycle; + let current_shard_init_pc = if shard_ctx.is_first_shard() { + program.entry + } else { + filtered_steps[0].pc().before.0 + }; + let current_shard_end_pc = filtered_steps.last().unwrap().pc().after.0; + + let mut zkvm_witness = ZKVMWitnesses::default(); + // assign opcode circuits + let dummy_records = system_config + .config + .assign_opcode_circuit( + &system_config.zkvm_cs, + &mut shard_ctx, + &mut zkvm_witness, + filtered_steps, + ) + .unwrap(); + system_config + .dummy_config + .assign_opcode_circuit( + &system_config.zkvm_cs, + &mut shard_ctx, + &mut zkvm_witness, + dummy_records, + ) + .unwrap(); + zkvm_witness.finalize_lk_multiplicities(); + + // assign table circuits + system_config + .config + .assign_table_circuit(&system_config.zkvm_cs, &mut zkvm_witness) + .unwrap(); + + if shard_ctx.is_first_shard() { + // assign init table on first shard + system_config + .mmu_config + .assign_init_table_circuit( + &system_config.zkvm_cs, + &mut zkvm_witness, + &emul_result.final_mem_state.reg, + &emul_result.final_mem_state.mem, + &emul_result.final_mem_state.io, + &emul_result.final_mem_state.hints, + &emul_result.final_mem_state.stack, + &emul_result.final_mem_state.heap, + ) + .unwrap(); + } else { + // empty assignment + system_config + .mmu_config + .assign_init_table_circuit( + &system_config.zkvm_cs, + &mut zkvm_witness, + &[], + &[], + &[], + &[], + &[], + &[], + ) + .unwrap(); + } + + // assign continuation circuit + system_config + .mmu_config + .assign_continuation_circuit( + &system_config.zkvm_cs, + &shard_ctx, + &mut zkvm_witness, + &emul_result.final_mem_state.reg, + &emul_result.final_mem_state.mem, + 
&emul_result.final_mem_state.io, + &emul_result.final_mem_state.hints, + &emul_result.final_mem_state.stack, + &emul_result.final_mem_state.heap, + ) + .unwrap(); + + // assign program circuit + zkvm_witness + .assign_table_circuit::>( + &system_config.zkvm_cs, + &system_config.prog_config, + program, + ) + .unwrap(); + + pi.init_pc = current_shard_init_pc; + pi.init_cycle = Tracer::SUBCYCLES_PER_INSN; + pi.shard_id = shard_ctx.shard_id as u32; + pi.end_pc = current_shard_end_pc; + pi.end_cycle = current_shard_end_cycle; + // set shard ram bus expected output to pi + let global_chip_withess = zkvm_witness.get_table_witness(&GlobalChip::::name()); + if let Some(global_chip_withess) = global_chip_withess + && global_chip_withess[0].num_instances() > 0 + { + for (f, v) in GlobalChip::::extract_ec_sum( + &system_config.mmu_config.ram_bus_circuit, + &global_chip_withess[0], + ) + .into_iter() + .zip_eq(pi.global_sum.as_mut_slice()) + { + *v = f.to_canonical_u64() as u32; + } + } + + (zkvm_witness, shard_ctx, pi) + }) } // Encodes useful early return points of the e2e pipeline @@ -846,7 +1162,7 @@ pub type IntermediateState = (Option>, Option { pub program: Arc, pub platform: Platform, - pub shards: Shards, + pub multi_prover: MultiProver, pub static_addrs: Vec, pub pubio_len: usize, pub system_config: ConstraintSystemConfig<'a, E>, @@ -858,7 +1174,7 @@ pub struct E2EProgramCtx<'a, E: ExtensionField> { /// end-to-end pipeline result, stopping at a certain checkpoint pub struct E2ECheckpointResult> { /// The proof generated by the pipeline, if any - pub proof: Option>, + pub proofs: Option>>, /// The verifying key generated by the pipeline, if any pub vk: Option>, /// The next step to run after the checkpoint @@ -877,7 +1193,7 @@ impl> E2ECheckpointResult< pub fn setup_program<'a, E: ExtensionField>( program: Program, platform: Platform, - shards: Shards, + multi_prover: MultiProver, ) -> E2EProgramCtx<'a, E> { let static_addrs = init_static_addrs(&program); let pubio_len 
= platform.public_io.iter_addresses().len(); @@ -903,7 +1219,7 @@ pub fn setup_program<'a, E: ExtensionField>( E2EProgramCtx { program: Arc::new(program), platform, - shards, + multi_prover, static_addrs, pubio_len, system_config, @@ -926,7 +1242,12 @@ impl E2EProgramCtx<'_, E> { .system_config .zkvm_cs .clone() - .key_gen::(pp.clone(), vp.clone(), self.zkvm_fixed_traces.clone()) + .key_gen::( + pp.clone(), + vp.clone(), + self.program.entry, + self.zkvm_fixed_traces.clone(), + ) .expect("keygen failed"); let vk = pk.get_vk_slow(); (pk, vk) @@ -946,6 +1267,7 @@ impl E2EProgramCtx<'_, E> { .key_gen::( pb.get_pp().clone(), pb.get_vp().clone(), + self.program.entry, self.zkvm_fixed_traces.clone(), ) .expect("keygen failed"); @@ -996,14 +1318,14 @@ pub fn run_e2e_with_checkpoint< device: PD, program: Program, platform: Platform, - shards: Shards, + multi_prover: MultiProver, hints: &[u32], public_io: &[u32], max_steps: usize, checkpoint: Checkpoint, ) -> E2ECheckpointResult { let start = std::time::Instant::now(); - let ctx = setup_program::(program, platform, shards); + let ctx = setup_program::(program, platform, multi_prover); tracing::debug!("setup_program done in {:?}", start.elapsed()); // Keygen @@ -1019,7 +1341,7 @@ pub fn run_e2e_with_checkpoint< let is_mock_proving = std::env::var("MOCK_PROVING").is_ok(); if let Checkpoint::PrepE2EProving = checkpoint { return E2ECheckpointResult { - proof: None, + proofs: None, vk: Some(vk), next_step: Some(Box::new(move || { _ = run_e2e_proof::( @@ -1041,17 +1363,16 @@ pub fn run_e2e_with_checkpoint< max_steps, &init_full_mem, &ctx.platform, - &ctx.shards, + &ctx.multi_prover, ); tracing::debug!("emulate done in {:?}", start.elapsed()); // Clone some emul_result fields before consuming - let pi = emul_result.pi.clone(); let exit_code = emul_result.exit_code; if let Checkpoint::PrepWitnessGen = checkpoint { return E2ECheckpointResult { - proof: None, + proofs: None, vk: Some(vk), next_step: Some(Box::new(move || { // When we 
run e2e and halt before generate_witness, this implies we are going to @@ -1062,48 +1383,58 @@ pub fn run_e2e_with_checkpoint< }; } - let zkvm_witness = generate_witness(&ctx.system_config, emul_result, &ctx.program); + let prover = ZKVMProver::new(pk, device); - let mut prover = ZKVMProver::new(pk, device); + let zkvm_witness = generate_witness(&ctx.system_config, emul_result, &ctx.program); - if is_mock_proving { - MockProver::assert_satisfied_full( - &ctx.system_config.zkvm_cs, - ctx.zkvm_fixed_traces.clone(), - &zkvm_witness, - &pi, - &ctx.program, - ); - tracing::info!("Mock proving passed"); - } + let zkvm_proofs = zkvm_witness + .map(|(zkvm_witness, shard_ctx, pi)| { + if is_mock_proving { + MockProver::assert_satisfied_full( + &shard_ctx, + &ctx.system_config.zkvm_cs, + ctx.zkvm_fixed_traces.clone(), + &zkvm_witness, + &pi, + &ctx.program, + ); + tracing::info!("Mock proving passed"); + } - // Run proof phase - let transcript = Transcript::new(b"riscv"); - let start = std::time::Instant::now(); - let zkvm_proof = prover - .create_proof(zkvm_witness, pi, transcript) - .expect("create_proof failed"); - tracing::debug!("proof created in {:?}", start.elapsed()); - tracing::info!("e2e proof stat: {}", zkvm_proof); + // Run proof phase + let transcript = Transcript::new(b"riscv"); + let start = std::time::Instant::now(); + let zkvm_proof = prover + .create_proof(&shard_ctx, zkvm_witness, pi, transcript) + .expect("create_proof failed"); + tracing::debug!( + "{}th shard proof created in {:?}", + shard_ctx.shard_id, + start.elapsed() + ); + tracing::info!("e2e proof stat: {}", zkvm_proof); + zkvm_proof + }) + .collect_vec(); let verifier = ZKVMVerifier::new(vk.clone()); if let Checkpoint::PrepVerify = checkpoint { return E2ECheckpointResult { - proof: Some(zkvm_proof.clone()), + proofs: Some(zkvm_proofs.clone()), vk: Some(vk), next_step: Some(Box::new(move || { - run_e2e_verify(&verifier, zkvm_proof, exit_code, max_steps) + run_e2e_verify(&verifier, zkvm_proofs, 
exit_code, max_steps) })), }; } let start = std::time::Instant::now(); - run_e2e_verify(&verifier, zkvm_proof.clone(), exit_code, max_steps); + run_e2e_verify(&verifier, zkvm_proofs.clone(), exit_code, max_steps); tracing::debug!("verified in {:?}", start.elapsed()); E2ECheckpointResult { - proof: Some(zkvm_proof), + proofs: Some(zkvm_proofs), vk: Some(vk), next_step: None, } @@ -1123,52 +1454,59 @@ pub fn run_e2e_proof< pk: ZKVMProvingKey, max_steps: usize, is_mock_proving: bool, -) -> ZKVMProof { +) -> Vec> { // Emulate program let emul_result = emulate_program( ctx.program.clone(), max_steps, init_full_mem, &ctx.platform, - &ctx.shards, + &ctx.multi_prover, ); - // clone pi before consuming - let pi = emul_result.pi.clone(); - // Generate witness let zkvm_witness = generate_witness(&ctx.system_config, emul_result, &ctx.program); // proving - let mut prover = ZKVMProver::new(pk, device); - - if is_mock_proving { - MockProver::assert_satisfied_full( - &ctx.system_config.zkvm_cs, - ctx.zkvm_fixed_traces.clone(), - &zkvm_witness, - &pi, - &ctx.program, - ); - tracing::info!("Mock proving passed"); - } + let prover = ZKVMProver::new(pk, device); + + zkvm_witness + .map(|(zkvm_witness, shard_ctx, pi)| { + if is_mock_proving { + if shard_ctx.num_shards > 1 { + todo!("support mock proving on more than 1 shard") + } + MockProver::assert_satisfied_full( + &shard_ctx, + &ctx.system_config.zkvm_cs, + ctx.zkvm_fixed_traces.clone(), + &zkvm_witness, + &pi, + &ctx.program, + ); + tracing::info!("Mock proving passed"); + } - let transcript = Transcript::new(b"riscv"); - prover - .create_proof(zkvm_witness, pi, transcript) - .expect("create_proof failed") + let transcript = Transcript::new(b"riscv"); + prover + .create_proof(&shard_ctx, zkvm_witness, pi, transcript) + .expect("create_proof failed") + }) + .collect_vec() } pub fn run_e2e_verify>( verifier: &ZKVMVerifier, - zkvm_proof: ZKVMProof, + zkvm_proofs: Vec>, exit_code: Option, max_steps: usize, ) { - let transcript = 
Transcript::new(b"riscv"); + let transcripts = (0..zkvm_proofs.len()) + .map(|_| Transcript::new(b"riscv")) + .collect_vec(); assert!( verifier - .verify_proof_halt(zkvm_proof, transcript, exit_code.is_some()) + .verify_proofs_halt(zkvm_proofs, transcripts, exit_code.is_some()) .expect("verify proof return with error"), ); match exit_code { @@ -1225,19 +1563,18 @@ fn format_segment(platform: &Platform, addr: u32) -> String { } pub fn verify + serde::Serialize>( - zkvm_proof: &ZKVMProof, + zkvm_proofs: Vec>, verifier: &ZKVMVerifier, ) -> Result<(), ZKVMError> { #[cfg(debug_assertions)] { Instrumented::<<::BaseField as PoseidonField>::P>::clear_metrics(); } - let transcript = Transcript::new(b"riscv"); - verifier.verify_proof_halt( - zkvm_proof.clone(), - transcript, - zkvm_proof.has_halt(&verifier.vk), - )?; + let transcripts = (0..zkvm_proofs.len()) + .map(|_| Transcript::new(b"riscv")) + .collect_vec(); + let has_halt = zkvm_proofs.last().unwrap().has_halt(&verifier.vk); + verifier.verify_proofs_halt(zkvm_proofs, transcripts, has_halt)?; // print verification statistics such as hash count #[cfg(debug_assertions)] { @@ -1249,3 +1586,94 @@ pub fn verify + serde::Ser } Ok(()) } + +#[cfg(test)] +mod tests { + use crate::e2e::{MultiProver, ShardContext}; + use ceno_emul::{Cycle, NextCycleAccess}; + + #[test] + fn test_single_prover_shard_ctx() { + for (name, max_cycle_per_shard, executed_instruction, expected_shard) in [ + ("1 shard", 1 << 6, (1 << 6) / 4 - 1, 1), + ( + "max inst + 10, split to 2 shard", + 1 << 6, + (1 << 6) / 4 + 10, + 2, + ), + ] { + test_single_shard_ctx_helper( + name, + max_cycle_per_shard, + executed_instruction, + expected_shard, + ); + } + } + + fn test_single_shard_ctx_helper( + name: &str, + max_cycle_per_shard: Cycle, + executed_instruction: usize, + expected_shard: usize, + ) { + let shard_ctx = ShardContext::new( + MultiProver::new(0, 1, 1 << 3, max_cycle_per_shard), + executed_instruction, + NextCycleAccess::default(), + ); + 
assert_eq!(shard_ctx.len(), expected_shard, "{name} test case failed"); + assert_eq!( + shard_ctx.first().unwrap().cur_shard_cycle_range.start, + 4, + "{name} test case failed" + ); + assert_eq!( + shard_ctx.last().unwrap().cur_shard_cycle_range.end, + executed_instruction * 4 + 4, + "{name} test case failed" + ); + if shard_ctx.len() > 1 { + for pair in shard_ctx.windows(2) { + assert_eq!( + pair[0].cur_shard_cycle_range.end, pair[1].cur_shard_cycle_range.start, + "{name} test case failed" + ); + } + } + } + + #[test] + fn test_multi_prover_shard_ctx() { + for (name, num_shards, num_prover, expected_num_shards_of_provers) in [ + ("2 provers", 7, 2, vec![4, 3]), + ("2 provers", 10, 3, vec![4, 3, 3]), + ] { + test_multi_shard_ctx_helper( + name, + num_shards, + num_prover, + expected_num_shards_of_provers, + ); + } + } + + fn test_multi_shard_ctx_helper( + name: &str, + num_shards: usize, + num_prover: usize, + expected_num_shards_of_provers: Vec, + ) { + let max_cycle_per_shard = (1 << 8) * 4; + let executed_instruction = (1 << 8) * num_shards - 10; // this will be split into num_shards + for (prover_id, expected_shard) in (0..num_prover).zip(expected_num_shards_of_provers) { + let shard_ctx = ShardContext::new( + MultiProver::new(prover_id, num_prover, 1 << 3, max_cycle_per_shard), + executed_instruction, + NextCycleAccess::default(), + ); + assert_eq!(shard_ctx.len(), expected_shard, "{name} test case failed"); + } + } +} diff --git a/ceno_zkvm/src/instructions.rs b/ceno_zkvm/src/instructions.rs index 65ef58d2b..7c00e03f6 100644 --- a/ceno_zkvm/src/instructions.rs +++ b/ceno_zkvm/src/instructions.rs @@ -3,7 +3,7 @@ use crate::{ tables::RMMCollections, witness::LkMultiplicity, }; use ceno_emul::StepRecord; -use ff_ext::{ExtensionField, FieldInto}; +use ff_ext::ExtensionField; use gkr_iop::{ chip::Chip, gkr::{GKRCircuit, layer::Layer}, @@ -11,13 +11,13 @@ use gkr_iop::{ utils::lk_multiplicity::Multiplicity, }; use itertools::Itertools; -use 
multilinear_extensions::{StructuralWitInType, ToExpr, WitIn, util::max_usable_threads}; +use multilinear_extensions::{ToExpr, util::max_usable_threads}; use p3::field::FieldAlgebra; use rayon::{ iter::{IndexedParallelIterator, ParallelIterator}, slice::ParallelSlice, }; -use witness::{InstancePaddingStrategy, RowMajorMatrix, set_val}; +use witness::{InstancePaddingStrategy, RowMajorMatrix}; pub mod global; pub mod riscv; @@ -48,15 +48,7 @@ pub trait Instruction { let zero_len = cb.cs.assert_zero_expressions.len() + cb.cs.assert_zero_sumcheck_expressions.len(); - let selector = cb.create_structural_witin( - || "selector", - StructuralWitInType::EqualDistanceSequence { - max_len: 0, - offset: 0, - multi_factor: 0, - descending: false, - }, - ); + let selector = cb.create_placeholder_structural_witin(|| "selector"); let selector_type = SelectorType::Prefix(selector.expr()); // all shared the same selector @@ -114,7 +106,6 @@ pub trait Instruction { // we can remove this one all opcode unittest migrate to call `build_gkr_iop_circuit` assert!(num_structural_witin == 0 || num_structural_witin == 1); let num_structural_witin = num_structural_witin.max(1); - let selector_witin = WitIn { id: 0 }; let nthreads = max_usable_threads(); let num_instance_per_batch = if steps.len() > 256 { @@ -148,7 +139,7 @@ pub trait Instruction { .zip_eq(structural_instance.chunks_mut(num_structural_witin)) .zip_eq(steps) .map(|((instance, structural_instance), step)| { - set_val!(structural_instance, selector_witin, E::BaseField::ONE); + *structural_instance.last_mut().unwrap() = E::BaseField::ONE; Self::assign_instance( config, &mut shard_ctx, diff --git a/ceno_zkvm/src/instructions/global.rs b/ceno_zkvm/src/instructions/global.rs index c98cb3634..33ae57b0f 100644 --- a/ceno_zkvm/src/instructions/global.rs +++ b/ceno_zkvm/src/instructions/global.rs @@ -22,12 +22,10 @@ use gkr_iop::{ selector::SelectorType, }; use itertools::{Itertools, chain}; -use multilinear_extensions::{ - Expression, 
StructuralWitInType::EqualDistanceSequence, ToExpr, WitIn, util::max_usable_threads, -}; +use multilinear_extensions::{Expression, ToExpr, WitIn, util::max_usable_threads}; use p3::{ field::{Field, FieldAlgebra}, - matrix::dense::RowMajorMatrix, + matrix::{Matrix, dense::RowMajorMatrix}, symmetric::Permutation, }; use rayon::{ @@ -370,6 +368,21 @@ impl GlobalChip { Ok(()) } + + pub fn extract_ec_sum( + config: &GlobalConfig, + rmm: &witness::RowMajorMatrix<::BaseField>, + ) -> Vec<::BaseField> { + assert!(rmm.height() >= 2); + let instance = &rmm[rmm.height() - 2]; + + config + .x + .iter() + .chain(config.y.iter()) + .map(|witin| instance[witin.id as usize]) + .collect_vec() + } } impl TableCircuit for GlobalChip { @@ -395,36 +408,9 @@ impl TableCircuit for GlobalChip { param: &ProgramParams, ) -> Result<(Self::TableConfig, Option>), crate::error::ZKVMError> { // create three selectors: selector_r, selector_w, selector_zero - let selector_r = cb.create_structural_witin( - || "selector_r", - // this is just a placeholder, the actural type is SelectorType::Prefix() - EqualDistanceSequence { - max_len: 0, - offset: 0, - multi_factor: 0, - descending: false, - }, - ); - let selector_w = cb.create_structural_witin( - || "selector_w", - // this is just a placeholder, the actural type is SelectorType::Prefix() - EqualDistanceSequence { - max_len: 0, - offset: 0, - multi_factor: 0, - descending: false, - }, - ); - let selector_zero = cb.create_structural_witin( - || "selector_zero", - // this is just a placeholder, the actural type is SelectorType::Prefix() - EqualDistanceSequence { - max_len: 0, - offset: 0, - multi_factor: 0, - descending: false, - }, - ); + let selector_r = cb.create_placeholder_structural_witin(|| "selector_r"); + let selector_w = cb.create_placeholder_structural_witin(|| "selector_w"); + let selector_zero = cb.create_placeholder_structural_witin(|| "selector_zero"); let config = Self::construct_circuit(cb, param)?; @@ -652,20 +638,23 @@ impl 
TableCircuit for GlobalChip { #[cfg(test)] mod tests { - use std::sync::Arc; - + use either::Either; use ff_ext::{BabyBearExt4, FromUniformBytes, PoseidonField}; use itertools::Itertools; use mpcs::{BasefoldDefault, PolynomialCommitmentScheme, SecurityLevel}; use p3::babybear::BabyBear; use rand::thread_rng; + use std::{ops::Index, sync::Arc}; use tracing_forest::{ForestLayer, util::LevelFilter}; use tracing_subscriber::{EnvFilter, Registry, layer::SubscriberExt, util::SubscriberInitExt}; use transcript::BasicTranscript; use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, - instructions::global::{GlobalChip, GlobalChipInput, GlobalRecord}, + instructions::{ + global::{GlobalChip, GlobalChipInput, GlobalRecord}, + riscv::constants::GLOBAL_RW_SUM_IDX, + }, scheme::{ PublicValues, create_backend, create_prover, hal::ProofInput, prover::ZKVMProver, septic_curve::SepticPoint, verifier::ZKVMVerifier, @@ -779,6 +768,17 @@ mod tests { ) .unwrap(); + // api extract ec sum from rmm witness + assert_eq!( + public_value + .to_vec::() + .into_iter() + .skip(GLOBAL_RW_SUM_IDX) + .flatten() + .collect_vec(), + GlobalChip::extract_ec_sum(&config, &witness[0]) + ); + let composed_cs = ComposedConstrainSystem { zkvm_v1_css: cs, gkr_circuit, @@ -801,11 +801,17 @@ mod tests { .into_iter() .map(|v| Arc::new(v.into_mle())) .collect_vec(); + let pub_io_evals = public_value + .to_vec::() + .into_iter() + .map(|v| Either::Right(E::from(*v.index(0)))) + .collect_vec(); let proof_input = ProofInput { witness: witness[0].to_mles().into_iter().map(Arc::new).collect(), structural_witness: witness[1].to_mles().into_iter().map(Arc::new).collect(), fixed: vec![], public_input: public_input_mles.clone(), + pub_io_evals, num_instances: vec![n_global_writes as usize, n_global_reads as usize], has_ecc_ops: true, }; @@ -827,12 +833,13 @@ mod tests { .iter() .map(|mle| mle.evaluate(&point[..mle.num_vars()])) .collect_vec(); - let vrf_point = verifier - .verify_opcode_proof( + let 
(vrf_point, _) = verifier + .verify_chip_proof( "global", &pk.vk, &proof, &pi_evals, + &public_value.to_vec::(), &mut transcript, 2, &PointAndEval::default(), diff --git a/ceno_zkvm/src/instructions/riscv/constants.rs b/ceno_zkvm/src/instructions/riscv/constants.rs index d98412b6f..e02b77f5c 100644 --- a/ceno_zkvm/src/instructions/riscv/constants.rs +++ b/ceno_zkvm/src/instructions/riscv/constants.rs @@ -9,7 +9,7 @@ pub const INIT_PC_IDX: usize = 2; pub const INIT_CYCLE_IDX: usize = 3; pub const END_PC_IDX: usize = 4; pub const END_CYCLE_IDX: usize = 5; -pub const END_SHARD_ID_IDX: usize = 6; +pub const SHARD_ID_IDX: usize = 6; pub const PUBLIC_IO_IDX: usize = 7; pub const GLOBAL_RW_SUM_IDX: usize = PUBLIC_IO_IDX + 2; diff --git a/ceno_zkvm/src/instructions/riscv/ecall/halt.rs b/ceno_zkvm/src/instructions/riscv/ecall/halt.rs index bf38a67c4..986a9e8cc 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/halt.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/halt.rs @@ -11,7 +11,7 @@ use crate::{ ecall_insn::EcallInstructionConfig, }, }, - structs::ProgramParams, + structs::{ProgramParams, RAMType}, witness::LkMultiplicity, }; use ceno_emul::{StepRecord, Tracer}; @@ -71,7 +71,7 @@ impl Instruction for HaltInstruction { fn assign_instance( config: &Self::InstructionConfig, - _shard_ctx: &mut ShardContext, + shard_ctx: &mut ShardContext, instance: &mut [E::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, @@ -87,23 +87,32 @@ impl Instruction for HaltInstruction { step.pc().after.0 ); + let current_shard_offset_cycle = shard_ctx.current_shard_offset_cycle(); + let shard_cycle = step.cycle() - current_shard_offset_cycle; + let rs2_prev_cycle = shard_ctx.aligned_prev_ts(step.rs2().unwrap().previous_cycle); // the access of X10 register is stored in rs2() - set_val!( - instance, - config.prev_x10_ts, - step.rs2().unwrap().previous_cycle + set_val!(instance, config.prev_x10_ts, rs2_prev_cycle); + + shard_ctx.send( + RAMType::Register, + 
step.rs2().unwrap().addr, + ceno_emul::Platform::reg_arg0() as u64, + step.cycle() + Tracer::SUBCYCLE_RS2, + step.rs2().unwrap().previous_cycle, + step.rs2().unwrap().value, + None, ); config.lt_x10_cfg.assign_instance( instance, lk_multiplicity, - step.rs2().unwrap().previous_cycle, - step.cycle() + Tracer::SUBCYCLE_RS2, + rs2_prev_cycle, + shard_cycle + Tracer::SUBCYCLE_RS2, )?; config .ecall_cfg - .assign_instance::(instance, lk_multiplicity, step)?; + .assign_instance::(instance, shard_ctx, lk_multiplicity, step)?; Ok(()) } diff --git a/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs b/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs index dccdf34a2..b0e060882 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs @@ -128,8 +128,7 @@ impl Instruction for KeccakInstruction { let (out_evals, mut chip) = layout.finalize(cb); - let layer = - Layer::from_circuit_builder(cb, "Rounds".to_string(), layout.n_challenges, out_evals); + let layer = Layer::from_circuit_builder(cb, Self::name(), layout.n_challenges, out_evals); chip.add_layer(layer); let circuit = chip.gkr_circuit(); diff --git a/ceno_zkvm/src/instructions/riscv/ecall_insn.rs b/ceno_zkvm/src/instructions/riscv/ecall_insn.rs index b7eb20ea4..b00bfb2bb 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall_insn.rs @@ -4,6 +4,7 @@ use crate::{ general::InstFetch, }, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, gadgets::AssertLtConfig, tables::InsnRecord, @@ -71,27 +72,29 @@ impl EcallInstructionConfig { pub fn assign_instance( &self, instance: &mut [E::BaseField], + shard_ctx: &mut ShardContext, lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { + let current_shard_offset_cycle = shard_ctx.current_shard_offset_cycle(); + let shard_prev_cycle = shard_ctx.aligned_prev_ts(step.rs1().unwrap().previous_cycle); + let shard_cycle = step.cycle() - 
current_shard_offset_cycle; set_val!(instance, self.pc, step.pc().before.0 as u64); - set_val!(instance, self.ts, step.cycle()); + set_val!(instance, self.ts, shard_cycle); lk_multiplicity.fetch(step.pc().before.0); // the access of X5 register is stored in rs1() - set_val!( - instance, - self.prev_x5_ts, - step.rs1().unwrap().previous_cycle - ); + set_val!(instance, self.prev_x5_ts, shard_prev_cycle); self.lt_x5_cfg.assign_instance( instance, lk_multiplicity, - step.rs1().unwrap().previous_cycle, - step.cycle() + Tracer::SUBCYCLE_RS1, + shard_prev_cycle, + shard_cycle + Tracer::SUBCYCLE_RS1, )?; + // skip shard_ctx.send() as ecall_halt is the last instruction + Ok(()) } } diff --git a/ceno_zkvm/src/instructions/riscv/rv32im.rs b/ceno_zkvm/src/instructions/riscv/rv32im.rs index 9957f2122..94dd5da39 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im.rs @@ -424,49 +424,39 @@ impl Rv32imConfig { let mut secp256k1_add_records = Vec::new(); let mut secp256k1_double_records = Vec::new(); let mut secp256k1_decompress_records = Vec::new(); - steps - .into_iter() - .filter_map(|step| { - if shard_ctx.is_current_shard_cycle(step.cycle()) { - Some(step) - } else { - None + + steps.into_iter().for_each(|record| { + let insn_kind = record.insn.kind; + match insn_kind { + // ecall / halt + InsnKind::ECALL if record.rs1().unwrap().value == Platform::ecall_halt() => { + halt_records.push(record); + } + InsnKind::ECALL if record.rs1().unwrap().value == KeccakSpec::CODE => { + keccak_records.push(record); + } + InsnKind::ECALL if record.rs1().unwrap().value == Bn254AddSpec::CODE => { + bn254_add_records.push(record); + } + InsnKind::ECALL if record.rs1().unwrap().value == Bn254DoubleSpec::CODE => { + bn254_double_records.push(record); } - }) - .for_each(|record| { - let insn_kind = record.insn.kind; - match insn_kind { - // ecall / halt - InsnKind::ECALL if record.rs1().unwrap().value == Platform::ecall_halt() => { - 
halt_records.push(record); - } - InsnKind::ECALL if record.rs1().unwrap().value == KeccakSpec::CODE => { - keccak_records.push(record); - } - InsnKind::ECALL if record.rs1().unwrap().value == Bn254AddSpec::CODE => { - bn254_add_records.push(record); - } - InsnKind::ECALL if record.rs1().unwrap().value == Bn254DoubleSpec::CODE => { - bn254_double_records.push(record); - } - InsnKind::ECALL if record.rs1().unwrap().value == Secp256k1AddSpec::CODE => { - secp256k1_add_records.push(record); - } - InsnKind::ECALL if record.rs1().unwrap().value == Secp256k1DoubleSpec::CODE => { - secp256k1_double_records.push(record); - } - InsnKind::ECALL - if record.rs1().unwrap().value == Secp256k1DecompressSpec::CODE => - { - secp256k1_decompress_records.push(record); - } - // other type of ecalls are handled by dummy ecall instruction - _ => { - // it's safe to unwrap as all_records are initialized with Vec::new() - all_records.get_mut(&insn_kind).unwrap().push(record); - } + InsnKind::ECALL if record.rs1().unwrap().value == Secp256k1AddSpec::CODE => { + secp256k1_add_records.push(record); } - }); + InsnKind::ECALL if record.rs1().unwrap().value == Secp256k1DoubleSpec::CODE => { + secp256k1_double_records.push(record); + } + InsnKind::ECALL if record.rs1().unwrap().value == Secp256k1DecompressSpec::CODE => { + secp256k1_decompress_records.push(record); + } + // other type of ecalls are handled by dummy ecall instruction + _ => { + // it's safe to unwrap as all_records are initialized with Vec::new() + all_records.get_mut(&insn_kind).unwrap().push(record); + } + } + }); for (insn_kind, (_, records)) in izip!(InsnKind::iter(), &all_records).sorted_by_key(|(_, (_, a))| Reverse(a.len())) diff --git a/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs b/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs index 900672a3d..c9b89552d 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs @@ -4,13 +4,13 @@ use crate::{ 
instructions::global::GlobalChip, structs::{ProgramParams, ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMWitnesses}, tables::{ - DynVolatileRamTable, HeapInitCircuit, HeapTable, HintsCircuit, LocalFinalCircuit, - MemFinalRecord, MemInitRecord, NonVolatileTable, PubIOCircuit, PubIOTable, RegTable, - RegTableInitCircuit, StackInitCircuit, StackTable, StaticMemInitCircuit, StaticMemTable, - TableCircuit, + DynVolatileRamTable, HeapInitCircuit, HeapTable, HintsInitCircuit, HintsTable, + LocalFinalCircuit, MemFinalRecord, MemInitRecord, NonVolatileTable, PubIOInitCircuit, + PubIOTable, RegTable, RegTableInitCircuit, StackInitCircuit, StackTable, + StaticMemInitCircuit, StaticMemTable, TableCircuit, }, }; -use ceno_emul::{Addr, Cycle, IterAddresses, WORD_SIZE, Word}; +use ceno_emul::{Addr, IterAddresses, WORD_SIZE, Word}; use ff_ext::ExtensionField; use itertools::{Itertools, chain}; use std::{collections::HashSet, iter::zip, ops::Range, sync::Arc}; @@ -22,9 +22,9 @@ pub struct MmuConfig<'a, E: ExtensionField> { /// Initialization of memory with static addresses. pub static_mem_init_config: as TableCircuit>::TableConfig, /// Initialization of public IO. - pub public_io_config: as TableCircuit>::TableConfig, + pub public_io_init_config: as TableCircuit>::TableConfig, /// Initialization of hints. - pub hints_config: as TableCircuit>::TableConfig, + pub hints_init_config: as TableCircuit>::TableConfig, /// Initialization of heap. pub heap_init_config: as TableCircuit>::TableConfig, /// Initialization of stack. 
@@ -42,9 +42,9 @@ impl MmuConfig<'_, E> { let static_mem_init_config = cs.register_table_circuit::>(); - let public_io_config = cs.register_table_circuit::>(); + let public_io_init_config = cs.register_table_circuit::>(); - let hints_config = cs.register_table_circuit::>(); + let hints_init_config = cs.register_table_circuit::>(); let stack_init_config = cs.register_table_circuit::>(); let heap_init_config = cs.register_table_circuit::>(); let local_final_circuit = cs.register_table_circuit::>(); @@ -53,8 +53,8 @@ impl MmuConfig<'_, E> { Self { reg_init_config, static_mem_init_config, - public_io_config, - hints_config, + public_io_init_config, + hints_init_config, stack_init_config, heap_init_config, local_final_circuit, @@ -90,8 +90,12 @@ impl MmuConfig<'_, E> { static_mem_init, ); - fixed.register_table_circuit::>(cs, &self.public_io_config, io_addrs); - fixed.register_table_circuit::>(cs, &self.hints_config, &()); + fixed.register_table_circuit::>( + cs, + &self.public_io_init_config, + io_addrs, + ); + fixed.register_table_circuit::>(cs, &self.hints_init_config, &()); fixed.register_table_circuit::>(cs, &self.stack_init_config, &()); fixed.register_table_circuit::>(cs, &self.heap_init_config, &()); fixed.register_table_circuit::>(cs, &self.local_final_circuit, &()); @@ -99,14 +103,13 @@ impl MmuConfig<'_, E> { } #[allow(clippy::too_many_arguments)] - pub fn assign_table_circuit( + pub fn assign_init_table_circuit( &self, cs: &ZKVMConstraintSystem, - shard_ctx: &ShardContext, witness: &mut ZKVMWitnesses, reg_final: &[MemFinalRecord], static_mem_final: &[MemFinalRecord], - io_cycles: &[Cycle], + io_final: &[MemFinalRecord], hints_final: &[MemFinalRecord], stack_final: &[MemFinalRecord], heap_final: &[MemFinalRecord], @@ -123,8 +126,16 @@ impl MmuConfig<'_, E> { static_mem_final, )?; - witness.assign_table_circuit::>(cs, &self.public_io_config, io_cycles)?; - witness.assign_table_circuit::>(cs, &self.hints_config, hints_final)?; + witness.assign_table_circuit::>( 
+ cs, + &self.public_io_init_config, + io_final, + )?; + witness.assign_table_circuit::>( + cs, + &self.hints_init_config, + hints_final, + )?; witness.assign_table_circuit::>( cs, &self.stack_init_config, @@ -135,8 +146,24 @@ impl MmuConfig<'_, E> { &self.heap_init_config, heap_final, )?; + Ok(()) + } + #[allow(clippy::too_many_arguments)] + pub fn assign_continuation_circuit( + &self, + cs: &ZKVMConstraintSystem, + shard_ctx: &ShardContext, + witness: &mut ZKVMWitnesses, + reg_final: &[MemFinalRecord], + static_mem_final: &[MemFinalRecord], + io_final: &[MemFinalRecord], + hints_final: &[MemFinalRecord], + stack_final: &[MemFinalRecord], + heap_final: &[MemFinalRecord], + ) -> Result<(), ZKVMError> { let all_records = vec![ + (InstancePaddingStrategy::Default, io_final), (InstancePaddingStrategy::Default, reg_final), (InstancePaddingStrategy::Default, static_mem_final), ( @@ -146,6 +173,13 @@ impl MmuConfig<'_, E> { }), stack_final, ), + ( + InstancePaddingStrategy::Custom({ + let params = cs.params.clone(); + Arc::new(move |row: u64, _: u64| HintsTable::addr(¶ms, row as usize) as u64) + }), + hints_final, + ), ( InstancePaddingStrategy::Custom({ let params = cs.params.clone(); @@ -163,8 +197,11 @@ impl MmuConfig<'_, E> { &self.local_final_circuit, &(shard_ctx, all_records.as_slice()), )?; - witness.assign_global_chip_circuit(cs, shard_ctx, &self.ram_bus_circuit)?; - + witness.assign_global_chip_circuit( + cs, + &(shard_ctx, all_records.as_slice()), + &self.ram_bus_circuit, + )?; Ok(()) } diff --git a/ceno_zkvm/src/keygen.rs b/ceno_zkvm/src/keygen.rs index 0ced182b8..716462927 100644 --- a/ceno_zkvm/src/keygen.rs +++ b/ceno_zkvm/src/keygen.rs @@ -12,10 +12,12 @@ impl ZKVMConstraintSystem { self, pp: PCS::ProverParam, vp: PCS::VerifierParam, + entry_pc: u32, mut vm_fixed_traces: ZKVMFixedTraces, ) -> Result, ZKVMError> { let mut vm_pk = ZKVMProvingKey::new(pp.clone(), vp); let mut fixed_traces = BTreeMap::new(); + let mut fixed_traces_no_omc_init = 
BTreeMap::new(); for (circuit_index, (c_name, cs)) in self.circuit_css.into_iter().enumerate() { // fixed_traces is optional @@ -29,6 +31,11 @@ impl ZKVMConstraintSystem { vm_pk .circuit_index_fixed_num_instances .insert(circuit_index, fixed_trace_rmm.num_instances()); + + if !cs.with_omc_init_only() { + fixed_traces_no_omc_init.insert(circuit_index, fixed_trace_rmm.clone()); + } + fixed_traces.insert(circuit_index, fixed_trace_rmm); } @@ -36,11 +43,13 @@ impl ZKVMConstraintSystem { assert!(vm_pk.circuit_pks.insert(c_name, circuit_pk).is_none()); } - vm_pk.commit_fixed(fixed_traces)?; + vm_pk.commit_fixed(fixed_traces, fixed_traces_no_omc_init)?; vm_pk.initial_global_state_expr = self.initial_global_state_expr; vm_pk.finalize_global_state_expr = self.finalize_global_state_expr; + vm_pk.set_program_entry_pc(entry_pc); + Ok(vm_pk) } } diff --git a/ceno_zkvm/src/precompiles/bitwise_keccakf.rs b/ceno_zkvm/src/precompiles/bitwise_keccakf.rs index 51bf0092a..e1b259501 100644 --- a/ceno_zkvm/src/precompiles/bitwise_keccakf.rs +++ b/ceno_zkvm/src/precompiles/bitwise_keccakf.rs @@ -918,6 +918,7 @@ pub fn run_keccakf + 'stat &[], &[], &[], + &[], ); exit_span!(span); @@ -1000,6 +1001,7 @@ pub fn run_keccakf + 'stat &out_evals, &[], &[], + &[], &mut verifier_transcript, &selector_ctxs, ) diff --git a/ceno_zkvm/src/precompiles/lookup_keccakf.rs b/ceno_zkvm/src/precompiles/lookup_keccakf.rs index bb105899d..7f8e2bd62 100644 --- a/ceno_zkvm/src/precompiles/lookup_keccakf.rs +++ b/ceno_zkvm/src/precompiles/lookup_keccakf.rs @@ -20,7 +20,7 @@ use gkr_iop::{ use itertools::{Itertools, iproduct, izip, zip_eq}; use mpcs::PolynomialCommitmentScheme; use multilinear_extensions::{ - Expression, StructuralWitIn, StructuralWitInType, ToExpr, WitIn, + Expression, StructuralWitIn, ToExpr, WitIn, mle::PointAndEval, util::{ceil_log2, max_usable_threads}, }; @@ -185,15 +185,7 @@ impl KeccakLayout { // cb.create_fixed(|| format!("keccak_fixed_{}", id)) // })), array::from_fn(|id| { - 
cb.create_structural_witin( - || format!("keccak_eq_{}", id), - StructuralWitInType::EqualDistanceSequence { - max_len: 0, - offset: 0, - multi_factor: 0, - descending: false, - }, - ) + cb.create_placeholder_structural_witin(|| format!("keccak_eq_{}", id)) }), ) }; @@ -992,8 +984,12 @@ pub fn setup_gkr_circuit() let (out_evals, mut chip) = layout.finalize(&mut cb); - let layer = - Layer::from_circuit_builder(&cb, "Rounds".to_string(), layout.n_challenges, out_evals); + let layer = Layer::from_circuit_builder( + &cb, + "lookup_keccak".to_string(), + layout.n_challenges, + out_evals, + ); chip.add_layer(layer); Ok(( @@ -1155,6 +1151,7 @@ pub fn run_faster_keccakf &structural_witness, &fixed, &[], + &[], &challenges, ); exit_span!(span); @@ -1265,6 +1262,7 @@ pub fn run_faster_keccakf gkr_proof.clone(), &out_evals, &[], + &[], &challenges, &mut verifier_transcript, &selector_ctxs, diff --git a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_add.rs b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_add.rs index 76df2b06a..6e2dfa62c 100644 --- a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_add.rs +++ b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_add.rs @@ -41,7 +41,7 @@ use gkr_iop::{ use itertools::{Itertools, izip}; use mpcs::PolynomialCommitmentScheme; use multilinear_extensions::{ - Expression, StructuralWitInType, ToExpr, WitIn, + Expression, ToExpr, WitIn, util::{ceil_log2, max_usable_threads}, }; use num::BigUint; @@ -132,15 +132,7 @@ impl WeierstrassAddAssignLayout { slope_times_p_x_minus_x: FieldOpCols::create(cb, || "slope_times_p_x_minus_x"), }; - let eq = cb.create_structural_witin( - || "weierstrass_add_eq", - StructuralWitInType::EqualDistanceSequence { - max_len: 0, - offset: 0, - multi_factor: 0, - descending: false, - }, - ); + let eq = cb.create_placeholder_structural_witin(|| "weierstrass_add_eq"); let sel = SelectorType::Prefix(eq.expr()); let selector_type_layout = SelectorTypeLayout { sel_mem_read: sel.clone(), @@ -713,6 +705,7 @@ 
pub fn run_weierstrass_add< &structural_witness, &fixed, &[], + &[], &challenges, ); exit_span!(span); @@ -786,6 +779,7 @@ pub fn run_weierstrass_add< gkr_proof.clone(), &out_evals, &[], + &[], &challenges, &mut verifier_transcript, &selector_ctxs, diff --git a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_decompress.rs b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_decompress.rs index 9f37a26c7..a88062d08 100644 --- a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_decompress.rs +++ b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_decompress.rs @@ -41,7 +41,7 @@ use gkr_iop::{ use itertools::{Itertools, izip}; use mpcs::PolynomialCommitmentScheme; use multilinear_extensions::{ - Expression, StructuralWitInType, ToExpr, WitIn, + Expression, ToExpr, WitIn, macros::{entered_span, exit_span}, util::{ceil_log2, max_usable_threads}, }; @@ -150,15 +150,7 @@ impl neg_y: FieldOpCols::create(cb, || "neg_y"), }; - let eq = cb.create_structural_witin( - || "weierstrass_decompress_eq", - StructuralWitInType::EqualDistanceSequence { - max_len: 0, - offset: 0, - multi_factor: 0, - descending: false, - }, - ); + let eq = cb.create_placeholder_structural_witin(|| "weierstrass_decompress_eq"); let sel = SelectorType::Prefix(eq.expr()); let selector_type_layout = SelectorTypeLayout { sel_mem_read: sel.clone(), @@ -693,6 +685,7 @@ pub fn run_weierstrass_decompress< &structural_witness, &fixed, &[], + &[], &challenges, ); exit_span!(span); @@ -766,6 +759,7 @@ pub fn run_weierstrass_decompress< gkr_proof.clone(), &out_evals, &[], + &[], &challenges, &mut verifier_transcript, &selector_ctxs, diff --git a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_double.rs b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_double.rs index 7f9a02997..a88f49eec 100644 --- a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_double.rs +++ b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_double.rs @@ -41,7 +41,7 @@ use gkr_iop::{ use itertools::{Itertools, izip}; use 
mpcs::PolynomialCommitmentScheme; use multilinear_extensions::{ - Expression, StructuralWitInType, ToExpr, WitIn, + Expression, ToExpr, WitIn, util::{ceil_log2, max_usable_threads}, }; use num::BigUint; @@ -134,15 +134,7 @@ impl slope_times_p_x_minus_x: FieldOpCols::create(cb, || "slope_times_p_x_minus_x"), }; - let eq = cb.create_structural_witin( - || "weierstrass_double_eq", - StructuralWitInType::EqualDistanceSequence { - max_len: 0, - offset: 0, - multi_factor: 0, - descending: false, - }, - ); + let eq = cb.create_placeholder_structural_witin(|| "weierstrass_double_eq"); let sel = SelectorType::Prefix(eq.expr()); let selector_type_layout = SelectorTypeLayout { sel_mem_read: sel.clone(), @@ -715,6 +707,7 @@ pub fn run_weierstrass_double< &structural_witness, &fixed, &[], + &[], &challenges, ); exit_span!(span); @@ -788,6 +781,7 @@ pub fn run_weierstrass_double< gkr_proof.clone(), &out_evals, &[], + &[], &challenges, &mut verifier_transcript, &selector_ctxs, diff --git a/ceno_zkvm/src/scheme.rs b/ceno_zkvm/src/scheme.rs index aa3928153..84693d78c 100644 --- a/ceno_zkvm/src/scheme.rs +++ b/ceno_zkvm/src/scheme.rs @@ -71,14 +71,14 @@ pub struct ZKVMChipProof { /// each field will be interpret to (constant) polynomial #[derive(Default, Clone, Debug)] pub struct PublicValues { - exit_code: u32, - init_pc: u32, - init_cycle: u64, - end_pc: u32, - end_cycle: u64, - shard_id: u32, - public_io: Vec, - global_sum: Vec, + pub exit_code: u32, + pub init_pc: u32, + pub init_cycle: u64, + pub end_pc: u32, + pub end_cycle: u64, + pub shard_id: u32, + pub public_io: Vec, + pub global_sum: Vec, } impl PublicValues { diff --git a/ceno_zkvm/src/scheme/cpu/mod.rs b/ceno_zkvm/src/scheme/cpu/mod.rs index cebe79899..290640b7a 100644 --- a/ceno_zkvm/src/scheme/cpu/mod.rs +++ b/ceno_zkvm/src/scheme/cpu/mod.rs @@ -1,17 +1,15 @@ use super::hal::{ - DeviceTransporter, MainSumcheckProver, OpeningProver, ProverDevice, TowerProver, TraceCommitter, + DeviceTransporter, MainSumcheckEvals, 
MainSumcheckProver, OpeningProver, ProverDevice, + TowerProver, TraceCommitter, }; use crate::{ - circuit_builder::ConstraintSystem, + e2e::ShardContext, error::ZKVMError, scheme::{ - constants::{NUM_FANIN, NUM_FANIN_LOGUP, SEPTIC_EXTENSION_DEGREE}, - hal::{DeviceProvingKey, EccQuarkProver, MainSumcheckEvals, ProofInput, TowerProverSpec}, + constants::{NUM_FANIN, SEPTIC_EXTENSION_DEGREE}, + hal::{DeviceProvingKey, EccQuarkProver, ProofInput, TowerProverSpec}, septic_curve::{SepticExtension, SepticPoint, SymbolicSepticExtension}, - utils::{ - infer_tower_logup_witness, infer_tower_product_witness, masked_mle_split_to_chunks, - wit_infer_by_expr, - }, + utils::{infer_tower_logup_witness, infer_tower_product_witness}, }, structs::{ComposedConstrainSystem, EccQuarkProof, PointAndEval, TowerProofs}, }; @@ -26,7 +24,7 @@ use gkr_iop::{ use itertools::{Itertools, chain}; use mpcs::{Point, PolynomialCommitmentScheme}; use multilinear_extensions::{ - Expression, Instance, WitnessId, + Expression, mle::{ArcMultilinearExtension, FieldType, IntoMLE, MultilinearExtension}, util::ceil_log2, virtual_poly::build_eq_x_r_vec, @@ -193,7 +191,6 @@ impl CpuEccProver { .zip_eq(alpha_pows_iter.by_ref().take(SEPTIC_EXTENSION_DEGREE)) .map(|(e, alpha)| e * Expression::Constant(Either::Right(*alpha))), ); - // zerocheck: 0 = s[1,b] * (x[b,0] - x[1,b]) - (y[b,0] + y[1,b]) with b != (1,...,1) exprs_add.extend( (s.clone() * (&x0 - &x3) - (&y0 + &y3)) @@ -214,7 +211,6 @@ impl CpuEccProver { .zip_eq(alpha_pows_iter.by_ref().take(SEPTIC_EXTENSION_DEGREE)) .map(|(e, alpha)| e * Expression::Constant(Either::Right(*alpha))), ); - // 0 = (y[1,b] - y[b,0]) exprs_bypass.extend( (&y3 - &y0) @@ -235,7 +231,6 @@ impl CpuEccProver { let rt = state.collect_raw_challenges(); let evals = state.get_mle_flatten_final_evaluations(); - assert_eq!(zerocheck_proof.extract_sum(), E::ZERO); // 7 for x[rt,0], x[rt,1], y[rt,0], y[rt,1], x[1,rt], y[1,rt], s[1,rt] assert_eq!(evals.len(), 2 + SEPTIC_EXTENSION_DEGREE * 7); 
@@ -272,6 +267,7 @@ impl CpuEccProver { assert_eq!(y3[i].evaluate(&rt), evals[SEPTIC_EXTENSION_DEGREE * 6 + i]); } } + assert_eq!(zerocheck_proof.extract_sum(), E::ZERO); EccQuarkProof { zerocheck_proof, @@ -509,7 +505,7 @@ impl> TraceCommitter> { fn commit_traces<'a>( - &mut self, + &self, traces: BTreeMap>, ) -> ( Vec>, @@ -554,8 +550,6 @@ impl> TowerProver, input: &ProofInput<'a, CpuBackend>, records: &'c [ArcMultilinearExtension<'b, E>], - is_padded: bool, - challenges: &[E; 2], ) -> ( Vec>>, Vec>>, @@ -568,13 +562,9 @@ impl> TowerProver> TowerProver>(); let w_set_last_layer = r_set_last_layer.split_off(r_set_wit.len()); let mut lk_numerator_last_layer = lk_n_wit .iter() .chain(lk_d_wit.iter()) - .enumerate() - .map(|(i, wit)| { - if is_padded { - wit.as_view_chunks(NUM_FANIN) - } else { - let default = if i < lk_n_wit.len() { - // For table circuit, the last layer's length is always two's power - // so the padding will not happen, therefore we can use any value here. - E::ONE - } else { - chip_record_alpha - }; - masked_mle_split_to_chunks( - wit, - num_instances_with_rotation, - NUM_FANIN_LOGUP, - default, - ) - } - }) + .map(|wit| wit.as_view_chunks(NUM_FANIN)) .collect::>(); let lk_denominator_last_layer = lk_numerator_last_layer.split_off(lk_n_wit.len()); exit_span!(span); @@ -770,8 +735,7 @@ impl> TowerProver, input: &ProofInput<'a, CpuBackend>, records: &'c [Arc>], - is_padded: bool, - challenges: &[E; 2], + _challenges: &[E; 2], transcript: &mut impl Transcript, ) -> TowerRelationOutput where @@ -781,7 +745,7 @@ impl> TowerProver> TowerProver> MainSumcheckProver> for CpuProver> { - #[allow(clippy::type_complexity)] - #[tracing::instrument(skip_all, name = "table_witness", fields(profiling_2), level = "trace")] - fn table_witness<'a>( - &self, - input: &ProofInput<'a, CpuBackend< as ProverBackend>::E, PCS>>, - cs: &ConstraintSystem< as ProverBackend>::E>, - challenges: &[ as ProverBackend>::E], - ) -> Vec as ProverBackend>::MultilinearPoly<'a>>> { - // 
main constraint: lookup denominator and numerator record witness inference - let span = entered_span!("witness_infer", profiling_2 = true); - let records: Vec> = cs - .r_table_expressions - .par_iter() - .map(|r| &r.expr) - .chain(cs.r_expressions.par_iter()) - .chain(cs.w_table_expressions.par_iter().map(|w| &w.expr)) - .chain(cs.w_expressions.par_iter()) - .chain( - cs.lk_table_expressions - .par_iter() - .map(|lk| &lk.multiplicity), - ) - .chain(cs.lk_table_expressions.par_iter().map(|lk| &lk.values)) - .chain(cs.lk_expressions.par_iter()) - .map(|expr| { - wit_infer_by_expr( - expr, - cs.num_witin, - cs.num_structural_witin, - cs.num_fixed as WitnessId, - &input.fixed, - &input.witness, - &input.structural_witness, - &input.public_input, - challenges, - ) - }) - .collect(); - exit_span!(span); - records - } - #[allow(clippy::type_complexity)] #[tracing::instrument( skip_all, @@ -874,124 +796,99 @@ impl> MainSumcheckProver 1" - ); - match mle.evaluations() { - FieldType::Base(smart_slice) => E::from(smart_slice[0]), - FieldType::Ext(smart_slice) => smart_slice[0], - _ => unreachable!(), - } - }) - .collect_vec(); - let selector_ctxs = if cs.ec_final_sum.is_empty() { - // it's not global chip - vec![ - SelectorContext { - offset: 0, - num_instances, - num_vars: num_var_with_rotation, - }; - gkr_circuit - .layers - .first() - .map(|layer| layer.out_sel_and_eval_exprs.len()) - .unwrap_or(0) - ] - } else { - // it's global chip - vec![ - SelectorContext { - offset: 0, - num_instances: input.num_instances[0], - num_vars: num_var_with_rotation, - }, - SelectorContext { - offset: input.num_instances[0], - num_instances: input.num_instances[1], - num_vars: num_var_with_rotation, - }, - SelectorContext { - offset: 0, - num_instances, - num_vars: num_var_with_rotation, - }, - ] - }; - let GKRProverOutput { - gkr_proof, - opening_evaluations, - } = gkr_circuit.prove::, CpuProver<_>>( - num_threads, - num_var_with_rotation, - gkr::GKRCircuitWitness { - layers: 
vec![LayerWitness( - chain!(&input.witness, &input.structural_witness, &input.fixed) - .cloned() - .collect_vec(), - )], + let Some(gkr_circuit) = gkr_circuit else { + panic!("empty gkr circuit") + }; + let pub_io_mles = cs + .instance_openings + .iter() + .map(|instance| input.public_input[instance.0].clone()) + .collect_vec(); + let selector_ctxs = if cs.ec_final_sum.is_empty() { + // it's not global chip + vec![ + SelectorContext { + offset: 0, + num_instances, + num_vars: num_var_with_rotation, + }; + gkr_circuit + .layers + .first() + .map(|layer| layer.out_sel_and_eval_exprs.len()) + .unwrap_or(0) + ] + } else { + // it's global chip + vec![ + SelectorContext { + offset: 0, + num_instances: input.num_instances[0], + num_vars: num_var_with_rotation, }, - // eval value doesnt matter as it wont be used by prover - &vec![PointAndEval::new(rt_tower, E::ZERO); gkr_circuit.final_out_evals.len()], - &pub_io_evals, - challenges, - transcript, - &selector_ctxs, - )?; - Ok(( - opening_evaluations[0].point.clone(), - MainSumcheckEvals { - wits_in_evals: opening_evaluations - .iter() - .take(cs.num_witin as usize) - .map(|Evaluation { value, .. }| value) - .copied() - .collect_vec(), - fixed_in_evals: opening_evaluations - .iter() - .skip((cs.num_witin + cs.num_structural_witin) as usize) - .take(cs.num_fixed) - .map(|Evaluation { value, .. 
}| value) - .copied() - .collect_vec(), + SelectorContext { + offset: input.num_instances[0], + num_instances: input.num_instances[1], + num_vars: num_var_with_rotation, }, - None, - Some(gkr_proof), - )) - } else { - let (wits_in_evals, fixed_in_evals, main_sumcheck_proof, rt) = { - let span = entered_span!("fixed::evals + witin::evals"); - let mut evals = input - .witness - .par_iter() - .chain(input.fixed.par_iter()) - .map(|poly| poly.evaluate(&rt_tower[..poly.num_vars()])) - .collect::>(); - let fixed_in_evals = evals.split_off(input.witness.len()); - let wits_in_evals = evals; - exit_span!(span); - (wits_in_evals, fixed_in_evals, None, rt_tower) - }; - - Ok(( - rt, - MainSumcheckEvals { - wits_in_evals, - fixed_in_evals, + SelectorContext { + offset: 0, + num_instances, + num_vars: num_var_with_rotation, }, - main_sumcheck_proof, - None, - )) - } + ] + }; + let GKRProverOutput { + gkr_proof, + opening_evaluations, + mut rt, + } = gkr_circuit.prove::, CpuProver<_>>( + num_threads, + num_var_with_rotation, + gkr::GKRCircuitWitness { + layers: vec![LayerWitness( + chain!( + &input.witness, + &input.fixed, + &pub_io_mles, + &input.structural_witness, + ) + .cloned() + .collect_vec(), + )], + }, + // eval value doesnt matter as it wont be used by prover + &vec![PointAndEval::new(rt_tower, E::ZERO); gkr_circuit.final_out_evals.len()], + &input + .pub_io_evals + .iter() + .map(|v| v.map_either(E::from, |v| v).into_inner()) + .collect_vec(), + challenges, + transcript, + &selector_ctxs, + )?; + assert_eq!(rt.len(), 1, "TODO support multi-layer gkr iop"); + Ok(( + rt.remove(0), + MainSumcheckEvals { + wits_in_evals: opening_evaluations + .iter() + .take(cs.num_witin as usize) + .map(|Evaluation { value, .. }| value) + .copied() + .collect_vec(), + fixed_in_evals: opening_evaluations + .iter() + .skip(cs.num_witin as usize) + .take(cs.num_fixed) + .map(|Evaluation { value, .. 
}| value) + .copied() + .collect_vec(), + }, + None, + Some(gkr_proof), + )) } } @@ -1045,6 +942,7 @@ impl> DeviceTransporter as ProverBackend>::E, @@ -1052,9 +950,13 @@ impl> DeviceTransporter, >, ) -> DeviceProvingKey<'_, CpuBackend> { - let pcs_data = pk.fixed_commit_wd.clone().unwrap(); - let fixed_mles = - PCS::get_arc_mle_witness_from_commitment(pk.fixed_commit_wd.as_ref().unwrap()); + let pcs_data = if shard_ctx.is_first_shard() { + pk.fixed_commit_wd.clone().unwrap() + } else { + pk.fixed_no_omc_init_commit_wd.clone().unwrap() + }; + + let fixed_mles = PCS::get_arc_mle_witness_from_commitment(pcs_data.as_ref()); DeviceProvingKey { pcs_data, diff --git a/ceno_zkvm/src/scheme/gpu/mod.rs b/ceno_zkvm/src/scheme/gpu/mod.rs index 07e5adb4d..be0a85b21 100644 --- a/ceno_zkvm/src/scheme/gpu/mod.rs +++ b/ceno_zkvm/src/scheme/gpu/mod.rs @@ -45,7 +45,7 @@ use gkr_iop::gpu::gpu_prover::*; pub struct GpuTowerProver; -use crate::scheme::constants::NUM_FANIN; +use crate::{e2e::ShardContext, scheme::constants::NUM_FANIN}; use gkr_iop::gpu::{ArcMultilinearExtensionGpu, MultilinearExtensionGpu}; // Extract out_evals from GPU-built tower witnesses @@ -410,8 +410,6 @@ impl> TowerProver, _input: &ProofInput<'a, GpuBackend>, _records: &'c [ArcMultilinearExtensionGpu<'b, E>], - _is_padded: bool, - _challenges: &[E; 2], ) -> ( Vec>>, Vec>>, @@ -436,7 +434,6 @@ impl> TowerProver, input: &ProofInput<'a, GpuBackend>, records: &'c [ArcMultilinearExtensionGpu<'b, E>], - _is_padded: bool, challenges: &[E; 2], transcript: &mut impl Transcript, ) -> TowerRelationOutput @@ -508,122 +505,6 @@ impl> TowerProver> MainSumcheckProver> for GpuProver> { - #[allow(clippy::type_complexity)] - #[tracing::instrument(skip_all, name = "table_witness", fields(profiling_2), level = "trace")] - fn table_witness<'a>( - &self, - input: &ProofInput<'a, GpuBackend>, - cs: &ConstraintSystem< as ProverBackend>::E>, - challenges: &[ as ProverBackend>::E], - ) -> Vec as ProverBackend>::MultilinearPoly<'a>>> { - 
assert!( - !cs.lk_table_expressions.is_empty() - || !cs.r_table_expressions.is_empty() - || !cs.w_table_expressions.is_empty(), - "assert table circuit" - ); - - assert!( - cs.r_table_expressions - .iter() - .zip_eq(cs.w_table_expressions.iter()) - .all(|(r, w)| r.table_spec.len == w.table_spec.len) - ); - - let span = entered_span!("preprocess", profiling_2 = true); - let layer_witin = input - .witness - .iter() - .chain(&input.structural_witness) - .chain(&input.fixed) - .chain(&input.public_input) - .map(|w| w.as_ref()) - .collect_vec(); - let num_vars = input.witness[0].num_vars(); - - // main constraint: lookup denominator and numerator record witness inference - let (num_non_zero_expr, term_coefficients, mle_indices_per_term, _) = cs - .r_table_expressions - .iter() - .map(|r| &r.expr) - .chain(cs.r_expressions.iter()) - .chain(cs.w_table_expressions.iter().map(|w| &w.expr)) - .chain(cs.w_expressions.iter()) - .chain(cs.lk_table_expressions.iter().map(|lk| &lk.multiplicity)) - .chain(cs.lk_table_expressions.iter().map(|lk| &lk.values)) - .chain(cs.lk_expressions.iter()) - .map(|expr| { - assert_eq!(expr.degree(), 1); - - let monomial_term = monomialize_expr_to_wit_terms( - expr, - cs.num_witin as WitnessId, - cs.num_structural_witin as WitnessId, - cs.num_fixed as WitnessId, - ); - - let (coeffs, indices, size_info) = extract_mle_relationships_from_monomial_terms( - &monomial_term, - &layer_witin, - &[], - challenges, - ); - let coeffs_gl64: Vec = unsafe { std::mem::transmute(coeffs) }; - (coeffs_gl64, indices, size_info) - }) - .fold( - (0, Vec::new(), Vec::new(), Vec::new()), - |(mut num_non_zero_expr, mut coeff_acc, mut indices_acc, mut size_acc), - (coeffs, indices, size_info)| { - num_non_zero_expr += 1; - coeff_acc.push(coeffs); - indices_acc.push(indices); - size_acc.push(size_info); - (num_non_zero_expr, coeff_acc, indices_acc, size_acc) - }, - ); - exit_span!(span); - - let span = entered_span!("witness_infer", profiling_2 = true); - let cuda_hal = 
get_cuda_hal().unwrap(); - let all_witins_gpu_gl64: Vec<&MultilinearExtensionGpu> = - unsafe { std::mem::transmute(layer_witin) }; - let all_witins_gpu_type_gl64 = all_witins_gpu_gl64.iter().map(|mle| &mle.mle).collect_vec(); - - // buffer for output witness from gpu - let mut next_witness_buf = (0..num_non_zero_expr) - .map(|_| { - cuda_hal - .alloc_ext_elems_on_device(1 << num_vars) - .map_err(|e| format!("Failed to allocate prod GPU buffer: {:?}", e)) - }) - .collect::, _>>() - .unwrap(); - - cuda_hal - .witness_infer - .wit_infer_by_monomial_expr( - &cuda_hal, - all_witins_gpu_type_gl64, - &term_coefficients, - &mle_indices_per_term, - &mut next_witness_buf, - ) - .unwrap(); - exit_span!(span); - - let next_mles = next_witness_buf - .into_iter() - .map(|buf| { - Arc::new(MultilinearExtensionGpu::from_ceno_gpu_ext( - GpuPolynomialExt::new(buf, num_vars), - )) - }) - .collect_vec(); - - next_mles - } - #[allow(clippy::type_complexity)] #[tracing::instrument( skip_all, @@ -657,90 +538,65 @@ impl> MainSumcheckProver 1" - ); - let mle_cpu = mle.inner_to_mle(); - match mle_cpu.evaluations() { - FieldType::Base(smart_slice) => E::from(smart_slice[0]), - FieldType::Ext(smart_slice) => smart_slice[0], - _ => unreachable!(), - } - }) - .collect_vec(); - let GKRProverOutput { - gkr_proof, - opening_evaluations, - } = gkr_circuit.prove::, GpuProver<_>>( - num_threads, - num_var_with_rotation, - gkr::GKRCircuitWitness { - layers: vec![LayerWitness( - chain!(&input.witness, &input.structural_witness, &input.fixed) - .cloned() - .collect_vec(), - )], - }, - // eval value doesnt matter as it wont be used by prover - &vec![PointAndEval::new(rt_tower, E::ZERO); gkr_circuit.final_out_evals.len()], - &pub_io_evals, - challenges, - transcript, - num_instances, - )?; - Ok(( - opening_evaluations[0].point.clone(), - MainSumcheckEvals { - wits_in_evals: opening_evaluations - .iter() - .take(cs.num_witin as usize) - .map(|Evaluation { value, .. 
}| value) - .copied() - .collect_vec(), - fixed_in_evals: opening_evaluations - .iter() - .skip((cs.num_witin + cs.num_structural_witin) as usize) - .take(cs.num_fixed) - .map(|Evaluation { value, .. }| value) - .copied() - .collect_vec(), - }, - None, - Some(gkr_proof), - )) - } else { - let span = entered_span!("fixed::evals + witin::evals"); - // In table proof, we always skip same point sumcheck for now - // as tower sumcheck batch product argument/logup in same length - let mut evals = input - .witness - .par_iter() - .chain(input.fixed.par_iter()) - .map(|poly| poly.evaluate(&rt_tower[..poly.num_vars()])) - .collect::>(); - let fixed_in_evals = evals.split_off(input.witness.len()); - let wits_in_evals = evals; - exit_span!(span); - - Ok(( - rt_tower, - MainSumcheckEvals { - wits_in_evals, - fixed_in_evals, - }, - None, - None, - )) - } + let Some(gkr_circuit) = gkr_circuit else { + panic!("empty gkr circuit") + }; + let pub_io_mles = cs + .instance_openings + .iter() + .map(|instance| input.public_input[instance.0].clone()) + .collect_vec(); + let GKRProverOutput { + gkr_proof, + opening_evaluations, + mut rt, + } = gkr_circuit.prove::, GpuProver<_>>( + num_threads, + num_var_with_rotation, + gkr::GKRCircuitWitness { + layers: vec![LayerWitness( + chain!( + &input.witness, + &input.fixed, + &pub_io_mles, + &input.structural_witness, + ) + .cloned() + .collect_vec(), + )], + }, + // eval value doesnt matter as it wont be used by prover + &vec![PointAndEval::new(rt_tower, E::ZERO); gkr_circuit.final_out_evals.len()], + &input + .pub_io_evals + .iter() + .map(|v| v.map_either(E::from, |v| v).into_inner()) + .collect_vec(), + challenges, + transcript, + num_instances, + )?; + assert_eq!(rt.len(), 1, "TODO support multi-layer gkr iop"); + Ok(( + rt.remove(0), + MainSumcheckEvals { + wits_in_evals: opening_evaluations + .iter() + .take(cs.num_witin as usize) + .map(|Evaluation { value, .. 
}| value) + .copied() + .collect_vec(), + fixed_in_evals: opening_evaluations + .iter() + .skip(cs.num_witin as usize) + .take(cs.num_fixed) + .map(|Evaluation { value, .. }| value) + .copied() + .collect_vec(), + }, + None, + Some(gkr_proof), + )) } } @@ -752,7 +608,7 @@ impl> OpeningProver as ProverBackend>::PcsData, fixed_data: Option as ProverBackend>::PcsData>>, points: Vec>, - mut evals: Vec>, // where each inner Vec = wit_evals + fixed_evals + mut evals: Vec>>, // where each inner Vec = wit_evals + fixed_evals transcript: &mut (impl Transcript + 'static), ) -> PCS::Proof { if std::any::TypeId::of::() != std::any::TypeId::of::() { @@ -850,6 +706,7 @@ impl> DeviceTransporter as ProverBackend>::E, @@ -857,7 +714,11 @@ impl> DeviceTransporter, >, ) -> DeviceProvingKey<'_, GpuBackend> { - let pcs_data_original = pk.fixed_commit_wd.clone().unwrap(); + let pcs_data_original = if shard_ctx.is_first_shard() { + pk.fixed_commit_wd.clone().unwrap() + } else { + pk.fixed_no_omc_init_commit_wd.clone().unwrap() + }; // assert pcs match let is_pcs_match = std::mem::size_of::>() diff --git a/ceno_zkvm/src/scheme/hal.rs b/ceno_zkvm/src/scheme/hal.rs index 44aa75c21..27cb9adc8 100644 --- a/ceno_zkvm/src/scheme/hal.rs +++ b/ceno_zkvm/src/scheme/hal.rs @@ -1,11 +1,10 @@ -use std::{collections::BTreeMap, sync::Arc}; - use crate::{ - circuit_builder::ConstraintSystem, + e2e::ShardContext, error::ZKVMError, scheme::cpu::TowerRelationOutput, structs::{ComposedConstrainSystem, EccQuarkProof, ZKVMProvingKey}, }; +use either::Either; use ff_ext::ExtensionField; use gkr_iop::{ gkr::GKRProof, @@ -13,6 +12,7 @@ use gkr_iop::{ }; use mpcs::{Point, PolynomialCommitmentScheme}; use multilinear_extensions::{mle::MultilinearExtension, util::ceil_log2}; +use std::{collections::BTreeMap, sync::Arc}; use sumcheck::structs::IOPProverMessage; use transcript::Transcript; use witness::next_pow2_instance_padding; @@ -38,6 +38,7 @@ pub struct ProofInput<'a, PB: ProverBackend> { pub structural_witness: 
Vec>>, pub fixed: Vec>>, pub public_input: Vec>>, + pub pub_io_evals: Vec::BaseField, PB::E>>, pub num_instances: Vec, pub has_ecc_ops: bool, } @@ -71,7 +72,7 @@ pub trait TraceCommitter { // the traces in the form of multilinear polynomials #[allow(clippy::type_complexity)] fn commit_traces<'a>( - &mut self, + &self, traces: BTreeMap::BaseField>>, ) -> ( Vec>, @@ -107,8 +108,6 @@ pub trait TowerProver { cs: &ComposedConstrainSystem, input: &ProofInput<'a, PB>, records: &'c [Arc>], - is_padded: bool, - challenge: &[PB::E; 2], ) -> ( Vec>>, Vec>, @@ -126,7 +125,6 @@ pub trait TowerProver { composed_cs: &ComposedConstrainSystem, input: &ProofInput<'a, PB>, records: &'c [Arc>], - is_padded: bool, challenges: &[PB::E; 2], transcript: &mut impl Transcript, ) -> TowerRelationOutput @@ -141,13 +139,6 @@ pub struct MainSumcheckEvals { } pub trait MainSumcheckProver { - fn table_witness<'a>( - &self, - input: &ProofInput<'a, PB>, - cs: &ConstraintSystem, - challenges: &[PB::E], - ) -> Vec>>; - // this prover aims to achieve two goals: // 1. 
the validity of last layer in the tower tree is reduced to // the validity of read/write/logup records through sumchecks; @@ -192,6 +183,7 @@ pub struct DeviceProvingKey<'a, PB: ProverBackend> { pub trait DeviceTransporter { fn transport_proving_key( &self, + shard_ctx: &ShardContext, proving_key: Arc>, ) -> DeviceProvingKey<'_, PB>; diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index edf7a63f1..b1e89cab5 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -2,6 +2,7 @@ use super::{PublicValues, utils::wit_infer_by_expr}; use crate::{ ROMType, circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::ShardContext, state::{GlobalState, StateCircuit}, structs::{ ComposedConstrainSystem, ProgramParams, RAMType, ZKVMConstraintSystem, ZKVMFixedTraces, @@ -39,6 +40,7 @@ use std::{ hash::Hash, io::{BufReader, ErrorKind}, marker::PhantomData, + ops::Index, sync::OnceLock, }; use strum::IntoEnumIterator; @@ -522,6 +524,7 @@ impl<'a, E: ExtensionField + Hash> MockProver { structural_witin, &[], &[], + &[], Some(challenge), lkm, ) @@ -533,7 +536,7 @@ impl<'a, E: ExtensionField + Hash> MockProver { program: &[ceno_emul::Instruction], lkm: Option>, ) -> Result<(), Vec>> { - Self::run_maybe_challenge(cb, &[], wits_in, &[], program, &[], None, lkm) + Self::run_maybe_challenge(cb, &[], wits_in, &[], program, &[], &[], None, lkm) } #[allow(clippy::too_many_arguments)] @@ -543,7 +546,8 @@ impl<'a, E: ExtensionField + Hash> MockProver { wits_in: &[ArcMultilinearExtension<'a, E>], structural_witin: &[ArcMultilinearExtension<'a, E>], program: &[ceno_emul::Instruction], - pi: &[ArcMultilinearExtension<'a, E>], + pi_mles: &[ArcMultilinearExtension<'a, E>], + pub_io_evals: &[Either], challenge: Option<[E; 2]>, lkm: Option>, ) -> Result<(), Vec>> { @@ -556,7 +560,8 @@ impl<'a, E: ExtensionField + Hash> MockProver { fixed, wits_in, structural_witin, - pi, + pi_mles, + pub_io_evals, 1, challenge, lkm, @@ 
-571,7 +576,8 @@ impl<'a, E: ExtensionField + Hash> MockProver { fixed: &[ArcMultilinearExtension<'a, E>], wits_in: &[ArcMultilinearExtension<'a, E>], structural_witin: &[ArcMultilinearExtension<'a, E>], - pi: &[ArcMultilinearExtension<'a, E>], + pi_mles: &[ArcMultilinearExtension<'a, E>], + pub_io_evals: &[Either], num_instances: usize, challenge: [E; 2], expected_lkm: Option>, @@ -579,6 +585,14 @@ impl<'a, E: ExtensionField + Hash> MockProver { let mut shared_lkm = LkMultiplicityRaw::::default(); let mut errors = vec![]; + let num_instance_padded = wits_in + .first() + .or_else(|| fixed.first()) + .or_else(|| pi_mles.first()) + .or_else(|| structural_witin.first()) + .map(|mle| mle.evaluations().len()) + .unwrap_or_else(|| next_pow2_instance_padding(num_instances)); + // Assert zero expressions for (expr, name) in cs .assert_zero_expressions @@ -603,9 +617,12 @@ impl<'a, E: ExtensionField + Hash> MockProver { structural_witin[zero_selector.selector_expr().id()].clone() } else { let mut selector = vec![E::BaseField::ONE; num_instances]; - selector.resize(wits_in[0].evaluations().len(), E::BaseField::ZERO); - MultilinearExtension::from_evaluation_vec_smart(wits_in[0].num_vars(), selector) - .into() + selector.resize(num_instance_padded, E::BaseField::ZERO); + MultilinearExtension::from_evaluation_vec_smart( + ceil_log2(num_instance_padded), + selector, + ) + .into() }; // require_equal does not always have the form of Expr::Sum as @@ -618,12 +635,13 @@ impl<'a, E: ExtensionField + Hash> MockProver { let left_evaluated = wit_infer_by_expr( left, cs.num_witin, - cs.num_structural_witin, cs.num_fixed as WitnessId, + cs.instance_openings.len(), fixed, wits_in, structural_witin, - pi, + pi_mles, + pub_io_evals, &challenge, ); let left_evaluated = @@ -632,12 +650,13 @@ impl<'a, E: ExtensionField + Hash> MockProver { let right_evaluated = wit_infer_by_expr( &right, cs.num_witin, - cs.num_structural_witin, cs.num_fixed as WitnessId, + cs.instance_openings.len(), fixed, 
wits_in, structural_witin, - pi, + pi_mles, + pub_io_evals, &challenge, ); let right_evaluated = @@ -663,12 +682,13 @@ impl<'a, E: ExtensionField + Hash> MockProver { let expr_evaluated = wit_infer_by_expr( expr, cs.num_witin, - cs.num_structural_witin, cs.num_fixed as WitnessId, + cs.instance_openings.len(), fixed, wits_in, structural_witin, - pi, + pi_mles, + pub_io_evals, &challenge, ); let expr_evaluated = @@ -691,26 +711,30 @@ impl<'a, E: ExtensionField + Hash> MockProver { structural_witin[lk_selector.selector_expr().id()].clone() } else { let mut selector = vec![E::BaseField::ONE; num_instances]; - selector.resize(wits_in[0].evaluations().len(), E::BaseField::ZERO); - MultilinearExtension::from_evaluation_vec_smart(wits_in[0].num_vars(), selector).into() + selector.resize(num_instance_padded, E::BaseField::ZERO); + MultilinearExtension::from_evaluation_vec_smart( + ceil_log2(num_instance_padded), + selector, + ) + .into() }; // Lookup expressions - for ((expr, name), (rom_type, _)) in cs - .lk_expressions - .iter() - .zip_eq(cs.lk_expressions_namespace_map.iter()) - .zip_eq(cs.lk_expressions_items_map.iter()) - { + for (expr, (name, (rom_type, _))) in cs.lk_expressions.iter().zip( + cs.lk_expressions_namespace_map + .iter() + .zip_eq(cs.lk_expressions_items_map.iter()), + ) { let expr_evaluated = wit_infer_by_expr( expr, cs.num_witin, - cs.num_structural_witin, cs.num_fixed as WitnessId, + cs.instance_openings.len(), fixed, wits_in, structural_witin, - pi, + pi_mles, + pub_io_evals, &challenge, ); let expr_evaluated = filter_mle_by_selector_mle(expr_evaluated, lk_selector.clone()); @@ -750,12 +774,13 @@ impl<'a, E: ExtensionField + Hash> MockProver { let arg_eval = wit_infer_by_expr( arg_expr, cs.num_witin, - cs.num_structural_witin, cs.num_fixed as WitnessId, + cs.instance_openings.len(), fixed, wits_in, structural_witin, - pi, + pi_mles, + pub_io_evals, &challenge, ); if arg_expr.is_constant() && arg_eval.evaluations.len() == 1 { @@ -944,6 +969,7 @@ Hints: 
} pub fn assert_satisfied_full( + shard_ctx: &ShardContext, cs: &ZKVMConstraintSystem, mut fixed_trace: ZKVMFixedTraces, witnesses: &ZKVMWitnesses, @@ -952,13 +978,12 @@ Hints: ) where E: LkMultiplicityKey, { - let instance = pi + let pub_io_evals = pi .to_vec::() - .concat() .into_iter() - .map(|i| E::from(i)) + .map(|v| Either::Right(E::from(*v.index(0)))) .collect_vec(); - let pi_mles = pi + let pi_mles: Vec> = pi .to_vec::() .into_mles() .into_iter() @@ -995,10 +1020,23 @@ Hints: // Process all circuits. for (circuit_name, composed_cs) in &cs.circuit_css { let ComposedConstrainSystem { - zkvm_v1_css: cs, - gkr_circuit, + zkvm_v1_css: cs, .. } = &composed_cs; - let is_opcode = gkr_circuit.is_some(); + let pi_mles = cs + .instance_openings + .iter() + .map(|instance| pi_mles[instance.0].clone()) + .collect_vec(); + + // skip init table on non-first shard + if composed_cs.with_omc_init_only() && !shard_ctx.is_first_shard() { + wit_mles.insert(circuit_name.clone(), vec![]); + structural_wit_mles.insert(circuit_name.clone(), vec![]); + fixed_mles.insert(circuit_name.clone(), vec![]); + num_instances.insert(circuit_name.clone(), 0); + continue; + } + let [witness, structural_witness] = witnesses .get_opcode_witness(circuit_name) .or_else(|| witnesses.get_table_witness(circuit_name)) @@ -1037,7 +1075,8 @@ Hints: .map_or(vec![], |fixed| { fixed.to_mles().into_iter().map(|f| f.into()).collect_vec() }); - if is_opcode { + // not lookup table + if cs.lk_table_expressions.is_empty() { tracing::info!( "Mock proving opcode {} with {} entries", circuit_name, @@ -1054,6 +1093,7 @@ Hints: &witness, &structural_witness, &pi_mles, + &pub_io_evals, num_rows, challenges, lkm_from_assignments, @@ -1079,12 +1119,13 @@ Hints: let lk_table = wit_infer_by_expr( &expr.values, cs.num_witin, - cs.num_structural_witin, cs.num_fixed as WitnessId, + cs.instance_openings.len(), &fixed, &witness, &structural_witness, &pi_mles, + &pub_io_evals, &challenges, ) .get_ext_field_vec() @@ -1093,12 
+1134,13 @@ Hints: let multiplicity = wit_infer_by_expr( &expr.multiplicity, cs.num_witin, - cs.num_structural_witin, cs.num_fixed as WitnessId, + cs.instance_openings.len(), &fixed, &witness, &structural_witness, &pi_mles, + &pub_io_evals, &challenges, ) .get_ext_field_vec() @@ -1151,6 +1193,11 @@ Hints: let fixed = fixed_mles.get(circuit_name).unwrap(); let witness = wit_mles.get(circuit_name).unwrap(); let structural_witness = structural_wit_mles.get(circuit_name).unwrap(); + let pi_mles = cs + .instance_openings + .iter() + .map(|instance| pi_mles[instance.0].clone()) + .collect_vec(); let num_rows = num_instances.get(circuit_name).unwrap(); if *num_rows == 0 { @@ -1183,24 +1230,26 @@ Hints: let ram_type_mle = wit_infer_by_expr( ram_type_expr, cs.num_witin, - cs.num_structural_witin, cs.num_fixed as WitnessId, + cs.instance_openings.len(), fixed, witness, structural_witness, &pi_mles, + &pub_io_evals, &challenges, ); let ram_type_vec = ram_type_mle.get_ext_field_vec(); let write_rlc_records = wit_infer_by_expr( w_rlc_expr, cs.num_witin, - cs.num_structural_witin, cs.num_fixed as WitnessId, + cs.instance_openings.len(), fixed, witness, structural_witness, &pi_mles, + &pub_io_evals, &challenges, ); let w_selector_vec = w_selector.get_base_field_vec(); @@ -1220,14 +1269,14 @@ Hints: assert_eq!( writes_within_expr_dedup.insert(record_rlc), true, - "within expression write duplicated on RAMType {:?} annotation {:?}", + "circuit name {circuit_name} within expression write duplicated on RAMType {:?} annotation {:?} on row {row}", $ram_type, annotation ); assert_eq!( writes.insert(record_rlc), true, - "crossing-chip write duplicated on RAMType {:?} annotation {:?}", + "circuit name {circuit_name} crossing-chip write duplicated on RAMType {:?} annotation {:?} on row {row}", $ram_type, annotation ); @@ -1250,6 +1299,11 @@ Hints: let fixed = fixed_mles.get(circuit_name).unwrap(); let witness = wit_mles.get(circuit_name).unwrap(); let structural_witness = 
structural_wit_mles.get(circuit_name).unwrap(); + let pi_mles = cs + .instance_openings + .iter() + .map(|instance| pi_mles[instance.0].clone()) + .collect_vec(); let num_rows = num_instances.get(circuit_name).unwrap(); if *num_rows == 0 { continue; @@ -1280,24 +1334,26 @@ Hints: let ram_type_mle = wit_infer_by_expr( ram_type_expr, cs.num_witin, - cs.num_structural_witin, cs.num_fixed as WitnessId, + cs.instance_openings.len(), fixed, witness, structural_witness, &pi_mles, + &pub_io_evals, &challenges, ); let ram_type_vec = ram_type_mle.get_ext_field_vec(); let read_records = wit_infer_by_expr( r_rlc_expr, cs.num_witin, - cs.num_structural_witin, cs.num_fixed as WitnessId, + cs.instance_openings.len(), fixed, witness, structural_witness, &pi_mles, + &pub_io_evals, &challenges, ); let r_selector_vec = r_selector.get_base_field_vec(); @@ -1319,12 +1375,13 @@ Hints: let v = wit_infer_by_expr( expr, cs.num_witin, - cs.num_structural_witin, cs.num_fixed as WitnessId, + cs.instance_openings.len(), fixed, witness, structural_witness, &pi_mles, + &pub_io_evals, &challenges, ); filter_mle_by_selector_mle(v, r_selector.clone()) @@ -1346,14 +1403,14 @@ Hints: assert_eq!( reads_within_expr_dedup.insert(record), true, - "within expression read duplicated on RAMType {:?} annotation {:?}", + "circuit name {circuit_name} within expression read duplicated on RAMType {:?} annotation {:?} on row {row}", $ram_type, annotation, ); assert_eq!( reads.insert(record), true, - "crossing-chip read duplicated on RAMType {:?} annotation {:?}", + "circuit name {circuit_name} crossing-chip read duplicated on RAMType {:?} annotation {:?} on row {row}", $ram_type, annotation, ); @@ -1459,14 +1516,34 @@ Hints: let (mut gs_rs, rs_grp_by_anno, mut gs_ws, ws_grp_by_anno, gs) = derive_ram_rws!(RAMType::GlobalState); gs_rs.insert( - eval_by_expr_with_instance(&[], &[], &[], &instance, &challenges, &gs_final) - .right() - .unwrap(), + eval_by_expr_with_instance( + &[], + &[], + &[], + &pub_io_evals + 
.iter() + .map(|v| v.right().unwrap()) + .collect_vec(), + &challenges, + &gs_final, + ) + .right() + .unwrap(), ); gs_ws.insert( - eval_by_expr_with_instance(&[], &[], &[], &instance, &challenges, &gs_init) - .right() - .unwrap(), + eval_by_expr_with_instance( + &[], + &[], + &[], + &pub_io_evals + .iter() + .map(|v| v.right().unwrap()) + .collect_vec(), + &challenges, + &gs_init, + ) + .right() + .unwrap(), ); // gs stores { (pc, timestamp) } diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 3ae856d28..577641fc1 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -10,6 +10,7 @@ use std::{ }; use crate::scheme::{constants::SEPTIC_EXTENSION_DEGREE, hal::MainSumcheckEvals}; +use either::Either; use gkr_iop::hal::MultilinearPolynomial; use itertools::Itertools; use mpcs::{Point, PolynomialCommitmentScheme}; @@ -28,6 +29,7 @@ use witness::RowMajorMatrix; use super::{PublicValues, ZKVMChipProof, ZKVMProof, hal::ProverDevice}; use crate::{ + e2e::ShardContext, error::ZKVMError, scheme::{hal::ProofInput, utils::build_main_witness}, structs::{ProvingKey, TowerProofs, ZKVMProvingKey, ZKVMWitnesses}, @@ -76,7 +78,8 @@ impl< level = "trace" )] pub fn create_proof( - &mut self, + &self, + shard_ctx: &ShardContext, witnesses: ZKVMWitnesses, pi: PublicValues, mut transcript: impl Transcript + 'static, @@ -97,7 +100,13 @@ impl< // commit to fixed commitment let span = entered_span!("commit_to_fixed_commit", profiling_1 = true); - if let Some(fixed_commit) = &self.pk.fixed_commit { + if let Some(fixed_commit) = &self.pk.fixed_commit + && shard_ctx.is_first_shard() + { + PCS::write_commitment(fixed_commit, &mut transcript).map_err(ZKVMError::PCSError)?; + } else if let Some(fixed_commit) = &self.pk.fixed_no_omc_init_commit + && !shard_ctx.is_first_shard() + { PCS::write_commitment(fixed_commit, &mut transcript).map_err(ZKVMError::PCSError)?; } exit_span!(span); @@ -116,6 +125,10 @@ impl< let mut 
circuit_name_num_instances_mapping = BTreeMap::new(); for (index, (circuit_name, ProvingKey { vk, .. })) in self.pk.circuit_pks.iter().enumerate() { + // skip omc init on >1 shard + if !shard_ctx.is_first_shard() && vk.get_cs().with_omc_init_only() { + continue; + } // num_instance from witness might include rotation if let Some(num_instance) = witnesses .num_instances @@ -191,7 +204,9 @@ impl< // transfer pk to device let transfer_pk_span = entered_span!("transfer pk to device", profiling_1 = true); - let device_pk = self.device.transport_proving_key(self.pk.clone()); + let device_pk = self + .device + .transport_proving_key(shard_ctx, self.pk.clone()); let mut fixed_mles = device_pk.fixed_mles; exit_span!(transfer_pk_span); @@ -215,6 +230,11 @@ impl< .cloned() .unwrap_or_default(); let cs = pk.get_cs(); + if !shard_ctx.is_first_shard() && cs.with_omc_init_only() { + assert!(num_instances.is_empty()); + // skip drain respective fixed because we use different set of fixed commitment + return Ok::<(Vec<_>, Vec>), ZKVMError>((points, evaluations)); + } if num_instances.is_empty() { // we need to drain respective fixed when num_instances is 0 if cs.num_fixed() > 0 { @@ -245,49 +265,32 @@ impl< fixed, structural_witness, public_input: public_input.clone(), + pub_io_evals: pi_evals.iter().map(|p| Either::Right(*p)).collect(), num_instances: num_instances.clone(), has_ecc_ops: cs.has_ecc_ops(), }; - if cs.is_opcode_circuit() { - let (opcode_proof, _, input_opening_point) = self.create_chip_proof( - circuit_name, - pk, - input, - &mut transcript, - &challenges, - )?; - tracing::trace!( - "generated proof for opcode {} with num_instances={:?}", - circuit_name, - num_instances - ); + let (opcode_proof, pi_in_evals, input_opening_point) = + self.create_chip_proof(circuit_name, pk, input, &mut transcript, &challenges)?; + tracing::trace!( + "generated proof for opcode {} with num_instances={:?}", + circuit_name, + num_instances + ); + if cs.num_witin() > 0 || cs.num_fixed() > 
0 { points.push(input_opening_point); - evaluations.push(vec![opcode_proof.wits_in_evals.clone()]); - chip_proofs.insert(index, opcode_proof); + evaluations.push(vec![ + opcode_proof.wits_in_evals.clone(), + opcode_proof.fixed_in_evals.clone(), + ]); } else { - let (table_proof, pi_in_evals, input_opening_point) = self.create_chip_proof( - circuit_name, - pk, - input, - &mut transcript, - &challenges, - )?; - if cs.num_witin() > 0 || cs.num_fixed() > 0 { - points.push(input_opening_point); - evaluations.push(vec![ - table_proof.wits_in_evals.clone(), - table_proof.fixed_in_evals.clone(), - ]); - } else { - assert!(table_proof.wits_in_evals.is_empty()); - assert!(table_proof.fixed_in_evals.is_empty()); - } - chip_proofs.insert(index, table_proof); - for (idx, eval) in pi_in_evals { - pi_evals[idx] = eval; - } - }; + assert!(opcode_proof.wits_in_evals.is_empty()); + assert!(opcode_proof.fixed_in_evals.is_empty()); + } + chip_proofs.insert(index, opcode_proof); + for (idx, eval) in pi_in_evals { + pi_evals[idx] = eval; + } Ok((points, evaluations)) }, )?; @@ -336,6 +339,7 @@ impl< // run ecc quark prover let ecc_proof = if !cs.zkvm_v1_css.ec_final_sum.is_empty() { + let span = entered_span!("run_ecc_final_sum", profiling_2 = true); let ec_point_exprs = &cs.zkvm_v1_css.ec_point_exprs; assert_eq!(ec_point_exprs.len(), SEPTIC_EXTENSION_DEGREE * 2); let mut xs_ys = ec_point_exprs @@ -356,27 +360,28 @@ impl< _ => unreachable!("slope's expression must be WitIn"), }) .collect_vec(); - Some(self.device.prove_ec_sum_quark( + let ecc_proof = Some(self.device.prove_ec_sum_quark( input.num_instances(), xs, ys, slopes, transcript, - )?) 
+ )?); + exit_span!(span); + ecc_proof } else { None }; // build main witness - let (records, is_padded) = - build_main_witness::(&self.device, cs, &input, challenges); + let records = build_main_witness::(cs, &input, challenges); let span = entered_span!("prove_tower_relation", profiling_2 = true); // prove the product and logup sum relation between layers in tower // (internally calls build_tower_witness) let (rt_tower, tower_proof, lk_out_evals, w_out_evals, r_out_evals) = self .device - .prove_tower_relation(cs, &input, &records, is_padded, challenges, transcript); + .prove_tower_relation(cs, &input, &records, challenges, transcript); exit_span!(span); assert_eq!( @@ -407,9 +412,9 @@ impl< // evaluate pi if there is instance query let mut pi_in_evals: HashMap = HashMap::new(); - if !cs.instance_name_map().is_empty() { + if !cs.instance_openings().is_empty() { let span = entered_span!("pi::evals"); - for &Instance(idx) in cs.instance_name_map().keys() { + for &Instance(idx) in cs.instance_openings() { let poly = &input.public_input[idx]; pi_in_evals.insert( idx, diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index 73355017c..be85360f9 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -11,9 +11,7 @@ use crate::{ create_backend, create_prover, hal::{ProofInput, TowerProverSpec}, }, - structs::{ - PointAndEval, ProgramParams, RAMType, ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMWitnesses, - }, + structs::{ProgramParams, RAMType, ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMWitnesses}, tables::ProgramTableCircuit, witness::{LkMultiplicity, set_val}, }; @@ -37,12 +35,15 @@ use ff_ext::{Instrumented, PoseidonField}; use super::{ PublicValues, - constants::{MAX_NUM_VARIABLES, NUM_FANIN}, + constants::MAX_NUM_VARIABLES, prover::ZKVMProver, utils::infer_tower_product_witness, verifier::{TowerVerify, ZKVMVerifier}, }; -use crate::{e2e::ShardContext, tables::DynamicRangeTableCircuit}; +use crate::{ + 
e2e::ShardContext, scheme::constants::NUM_FANIN, structs::PointAndEval, + tables::DynamicRangeTableCircuit, +}; use itertools::Itertools; use mpcs::{ PolynomialCommitmentScheme, SecurityLevel, SecurityLevel::Conjecture100bits, WhirDefault, @@ -132,6 +133,7 @@ fn test_rw_lk_expression_combination() { .key_gen::( device.backend.pp.clone(), device.backend.vp.clone(), + 0, zkvm_fixed_traces, ) .unwrap(); @@ -198,6 +200,7 @@ fn test_rw_lk_expression_combination() { witness: wits_in, structural_witness: structural_in, public_input: vec![], + pub_io_evals: vec![], num_instances: vec![num_instances], has_ecc_ops: false, }; @@ -227,11 +230,12 @@ fn test_rw_lk_expression_combination() { Instrumented::<<::BaseField as PoseidonField>::P>::clear_metrics(); } verifier - .verify_opcode_proof( + .verify_chip_proof( name.as_str(), verifier.vk.circuit_vks.get(&name).unwrap(), &proof, &[], + &[], &mut v_transcript, NUM_FANIN, &PointAndEval::default(), @@ -307,7 +311,7 @@ fn test_single_add_instance_e2e() { let pk = zkvm_cs .clone() - .key_gen::(pp, vp, zkvm_fixed_traces) + .key_gen::(pp, vp, program.entry, zkvm_fixed_traces) .expect("keygen failed"); let vk = pk.get_vk_slow(); @@ -340,7 +344,7 @@ fn test_single_add_instance_e2e() { let (max_num_variables, security_level) = default_backend_config(); let backend = create_backend::(max_num_variables, security_level); let device = create_prover(backend); - let mut prover = ZKVMProver::new(pk, device); + let prover = ZKVMProver::new(pk, device); let verifier = ZKVMVerifier::new(vk); let mut zkvm_witness = ZKVMWitnesses::default(); // assign opcode circuits @@ -375,7 +379,7 @@ fn test_single_add_instance_e2e() { let pi = PublicValues::new(0, 0, 0, 0, 0, 0, vec![0], vec![0; 14]); let transcript = BasicTranscript::new(b"riscv"); let zkvm_proof = prover - .create_proof(zkvm_witness, pi, transcript) + .create_proof(&shard_ctx, zkvm_witness, pi, transcript) .expect("create_proof failed"); println!("encoded zkvm proof {}", &zkvm_proof,); diff 
--git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index cfa88175f..2426c6f50 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -1,7 +1,7 @@ use crate::{ scheme::{ constants::MIN_PAR_SIZE, - hal::{MainSumcheckProver, ProofInput, ProverDevice}, + hal::{ProofInput, ProverDevice}, }, structs::ComposedConstrainSystem, }; @@ -29,50 +29,6 @@ use rayon::{ use std::{iter, sync::Arc}; use witness::next_pow2_instance_padding; -// first computes the masked mle'[j] = mle[j] if j < num_instance, else default -// then split it into `num_parts` smaller mles -pub(crate) fn masked_mle_split_to_chunks<'a, 'b, E: ExtensionField>( - mle: &'a MultilinearExtension<'a, E>, - num_instance: usize, - num_chunks: usize, - default: E, -) -> Vec> { - assert!(num_chunks.is_power_of_two()); - assert!( - num_instance <= mle.evaluations().len(), - "num_instance {num_instance} > {}", - mle.evaluations().len() - ); - - // TODO: when mle.len() is two's power, we should avoid the clone - (0..num_chunks) - .into_par_iter() - .map(|part_idx| { - let n = mle.evaluations().len() / num_chunks; - - match mle.evaluations() { - FieldType::Ext(evals) => (part_idx * n..(part_idx + 1) * n) - .into_par_iter() - .with_min_len(64) - .map(|i| if i < num_instance { evals[i] } else { default }) - .collect::>() - .into_mle(), - FieldType::Base(evals) => (part_idx * n..(part_idx + 1) * n) - .map(|i| { - if i < num_instance { - E::from(evals[i]) - } else { - default - } - }) - .collect::>() - .into_mle(), - _ => unreachable!(), - } - }) - .collect::>() -} - /// interleaving multiple mles into mles, and num_limbs indicate number of final limbs vector /// e.g input [[1,2],[3,4],[5,6],[7,8]], num_limbs=2,log2_per_instance_size=3 /// output [[1,3,5,7,0,0,0,0],[2,4,6,8,0,0,0,0]] @@ -349,76 +305,76 @@ pub fn build_main_witness< PB: ProverBackend + 'static, PD: ProverDevice, >( - device: &PD, composed_cs: &ComposedConstrainSystem, input: &ProofInput<'a, PB>, challenges: &[E; 
2], -) -> (Vec>>, bool) { - let (mles, is_padded) = { - let ComposedConstrainSystem { - zkvm_v1_css: cs, - gkr_circuit, - } = composed_cs; - let log2_num_instances = input.log2_num_instances(); - let num_var_with_rotation = log2_num_instances + composed_cs.rotation_vars().unwrap_or(0); - - // sanity check - assert_eq!(input.witness.len(), cs.num_witin as usize); - - // structural witness can be empty. In this case they are `eq`, and will be filled later - assert!( - input.structural_witness.len() == cs.num_structural_witin as usize - || input.structural_witness.is_empty(), - ); - assert_eq!(input.fixed.len(), cs.num_fixed); +) -> Vec>> { + let ComposedConstrainSystem { + zkvm_v1_css: cs, + gkr_circuit, + } = composed_cs; + let log2_num_instances = input.log2_num_instances(); + let num_var_with_rotation = log2_num_instances + composed_cs.rotation_vars().unwrap_or(0); + + // sanity check + assert_eq!(input.witness.len(), cs.num_witin as usize); + + // structural witness can be empty. In this case they are `eq`, and will be filled later + assert!( + input.structural_witness.len() == cs.num_structural_witin as usize + || input.structural_witness.is_empty(), + ); + assert_eq!(input.fixed.len(), cs.num_fixed); - // check all witness size are power of 2 + // check all witness size are power of 2 + assert!( + input + .witness + .iter() + .all(|v| { v.evaluations_len() == 1 << num_var_with_rotation }) + ); + + if !input.structural_witness.is_empty() { assert!( input - .witness + .structural_witness .iter() .all(|v| { v.evaluations_len() == 1 << num_var_with_rotation }) ); + } - if !input.structural_witness.is_empty() { - assert!( - input - .structural_witness - .iter() - .all(|v| { v.evaluations_len() == 1 << num_var_with_rotation }) - ); - } - - if let Some(gkr_circuit) = gkr_circuit { - // circuit must have at least one read/write/lookup - assert!( - cs.r_expressions.len() - + cs.w_expressions.len() - + cs.lk_expressions.len() - + cs.r_table_expressions.len() - + 
cs.w_table_expressions.len() - + cs.lk_table_expressions.len() - > 0, - "assert circuit" - ); - - let (_, gkr_circuit_out) = gkr_witness::( - gkr_circuit, - &input.witness, - &input.structural_witness, - &input.fixed, - &input.public_input, - challenges, - ); - (gkr_circuit_out.0.0, true) - } else { - ( - >::table_witness(device, input, cs, challenges), - input.num_instances() > 1 && input.num_instances().is_power_of_two(), - ) - } + let Some(gkr_circuit) = gkr_circuit else { + panic!("empty gkr-iop") }; - (mles, is_padded) + + // circuit must have at least one read/write/lookup + assert!( + cs.r_expressions.len() + + cs.w_expressions.len() + + cs.lk_expressions.len() + + cs.r_table_expressions.len() + + cs.w_table_expressions.len() + + cs.lk_table_expressions.len() + > 0, + "assert circuit" + ); + + let pub_io_mles = cs + .instance_openings + .iter() + .map(|instance| input.public_input[instance.0].clone()) + .collect_vec(); + + let (_, gkr_circuit_out) = gkr_witness::( + gkr_circuit, + &input.witness, + &input.structural_witness, + &input.fixed, + &pub_io_mles, + &input.pub_io_evals, + challenges, + ); + gkr_circuit_out.0.0 } pub fn gkr_witness< @@ -432,7 +388,8 @@ pub fn gkr_witness< phase1_witness_group: &[Arc>], structural_witness: &[Arc>], fixed: &[Arc>], - pub_io: &[Arc>], + pub_io_mles: &[Arc>], + pub_io_evals: &[Either], challenges: &[E], ) -> (GKRCircuitWitness<'b, PB>, GKRCircuitOutput<'b, PB>) { // layer order from output to input @@ -446,20 +403,35 @@ pub fn gkr_witness< first_layer .in_eval_expr .iter() - .take(phase1_witness_group.len()) - .enumerate() - .for_each(|(index, witin)| { - witness_mle_flatten[*witin] = Some(phase1_witness_group[index].clone()); + .take(first_layer.n_witin) + .zip_eq(phase1_witness_group.iter()) + .for_each(|(index, witin_mle)| { + witness_mle_flatten[*index] = Some(witin_mle.clone()); }); - // TODO process fixed (and probably short) mle - assert_eq!( - first_layer.in_eval_expr.len(), - phase1_witness_group.len(), - "TODO 
process fixed (and probably short) mle" - ); - // XXX currently fixed poly not support in layers > 1 + first_layer + .in_eval_expr + .iter() + .skip(first_layer.n_witin) + .take(first_layer.n_fixed) + .zip_eq(fixed.iter()) + .for_each(|(index, fixed_mle)| { + witness_mle_flatten[*index] = Some(fixed_mle.clone()); + }); + + first_layer + .in_eval_expr + .iter() + .skip(first_layer.n_witin + first_layer.n_fixed) + .take(first_layer.n_instance) + .zip_eq(pub_io_mles.iter()) + .for_each(|(index, pubio_mle)| { + witness_mle_flatten[*index] = Some(pubio_mle.clone()); + }); + // XXX currently fixed poly not support in layers > 1 + // TODO process fixed (and probably short) mle + // // first_layer // .in_eval_expr // .par_iter() @@ -505,15 +477,22 @@ pub fn gkr_witness< } else { Either::Right(iter::empty()) }) - .chain(fixed.iter().cloned()) .collect_vec(); + assert_eq!( + current_layer_wits.len(), + layer.n_witin + + layer.n_fixed + + layer.n_instance + + if i == 0 { layer.n_structural_witin } else { 0 } + ); + // infer current layer output let current_layer_output: Vec>> = >::layer_witness( layer, ¤t_layer_wits, - pub_io, + pub_io_evals, challenges, ); layer_wits.push(LayerWitness::new(current_layer_wits, vec![])); @@ -525,7 +504,6 @@ pub fn gkr_witness< .flat_map(|(_, out_eval)| out_eval) .zip_eq(¤t_layer_output) .for_each(|(out_eval, out_mle)| match out_eval { - // note: Linear (x - b)/a has been done and encode in expression EvalExpression::Single(out) | EvalExpression::Linear(out, _, _) => { witness_mle_flatten[*out] = Some(out_mle.clone()); } diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index 4ed5a89e9..ad767f155 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -1,33 +1,32 @@ use either::Either; use ff_ext::ExtensionField; -use std::marker::PhantomData; +use std::{iter, marker::PhantomData}; #[cfg(debug_assertions)] use ff_ext::{Instrumented, PoseidonField}; +use super::{ZKVMChipProof, 
ZKVMProof}; use crate::{ error::ZKVMError, + instructions::riscv::constants::{END_PC_IDX, INIT_CYCLE_IDX, INIT_PC_IDX, SHARD_ID_IDX}, scheme::{ - constants::{NUM_FANIN, NUM_FANIN_LOGUP, SEPTIC_EXTENSION_DEGREE}, - septic_curve::SepticExtension, + constants::{NUM_FANIN, SEPTIC_EXTENSION_DEGREE}, + septic_curve::{SepticExtension, SepticPoint}, }, structs::{ ComposedConstrainSystem, EccQuarkProof, PointAndEval, TowerProofs, VerifyingKey, ZKVMVerifyingKey, }, - utils::{ - eval_inner_repeated_incremental_vec, eval_outer_repeated_incremental_vec, - eval_stacked_constant_vec, eval_stacked_wellform_address_vec, eval_wellform_address_vec, - }, }; +use ceno_emul::Tracer; use gkr_iop::{ - gkr::GKRClaims, + self, selector::{SelectorContext, SelectorType}, }; use itertools::{Itertools, chain, interleave, izip}; use mpcs::{Point, PolynomialCommitmentScheme}; use multilinear_extensions::{ - Expression, Instance, StructuralWitIn, StructuralWitInType, + Expression, StructuralWitIn, StructuralWitInType::StackedConstantSequence, mle::IntoMLE, util::ceil_log2, @@ -42,8 +41,6 @@ use sumcheck::{ use transcript::{ForkableTranscript, Transcript}; use witness::next_pow2_instance_padding; -use super::{ZKVMChipProof, ZKVMProof}; - pub struct ZKVMVerifier> { pub vk: ZKVMVerifyingKey, } @@ -67,29 +64,81 @@ impl> ZKVMVerifier self.verify_proof_halt(vm_proof, transcript, true) } + #[tracing::instrument(skip_all, name = "verify_proofs")] + pub fn verify_proofs( + &self, + vm_proofs: Vec>, + transcripts: Vec>, + ) -> Result { + self.verify_proofs_halt(vm_proofs, transcripts, true) + } + /// Verify a trace from start to optional halt. pub fn verify_proof_halt( &self, vm_proof: ZKVMProof, transcript: impl ForkableTranscript, - _expect_halt: bool, + expect_halt: bool, + ) -> Result { + self.verify_proofs_halt(vec![vm_proof], vec![transcript], expect_halt) + } + + /// Verify a trace from start to optional halt. 
+ pub fn verify_proofs_halt( + &self, + vm_proofs: Vec>, + transcripts: Vec>, + expect_halt: bool, ) -> Result { - // require ecall/halt proof to exist, depending whether we expect a halt. - // let has_halt = vm_proof.has_halt(&self.vk); - // if has_halt != expect_halt { - // return Err(ZKVMError::VerifyError( - // format!("ecall/halt mismatch: expected {expect_halt} != {has_halt}",).into(), - // )); - // } - - self.verify_proof_validity(vm_proof, transcript) + assert!(!vm_proofs.is_empty()); + let num_proofs = vm_proofs.len(); + let (_end_pc, shard_ec_sum) = vm_proofs + .into_iter() + .zip_eq(transcripts) + // optionally halt on last chunk + .zip_eq(iter::repeat_n(false, num_proofs - 1).chain(iter::once(expect_halt))) + .enumerate() + .try_fold((None, SepticPoint::::default()), |(prev_pc, mut shard_ec_sum), (shard_id, ((vm_proof, transcript), expect_halt))| { + // require ecall/halt proof to exist, depend on whether we expect a halt. + let has_halt = vm_proof.has_halt(&self.vk); + if has_halt != expect_halt { + return Err(ZKVMError::VerifyError( + format!( + "{shard_id}th proof ecall/halt mismatch: expected {expect_halt} != {has_halt}", + ) + .into(), + )); + } + // each shard set init cycle = Tracer::SUBCYCLES_PER_INSN + // to satisfy initial reads for all prev_cycle = 0 < init_cycle + assert_eq!(vm_proof.pi_evals[INIT_CYCLE_IDX], E::from_canonical_u64(Tracer::SUBCYCLES_PER_INSN)); + // check init_pc match prev end_pc + if let Some(prev_pc) = prev_pc { + assert_eq!(vm_proof.pi_evals[INIT_PC_IDX], prev_pc); + } else { + // first chunk, check program entry + assert_eq!(vm_proof.pi_evals[INIT_PC_IDX], E::from_canonical_u32(self.vk.entry_pc)); + } + let end_pc = vm_proof.pi_evals[END_PC_IDX]; + // add to global shard ec + shard_ec_sum = shard_ec_sum + self.verify_proof_validity(shard_id, vm_proof, transcript)?; + Ok((Some(end_pc), shard_ec_sum)) + })?; + // check shard ec_sum is_infinity + if !shard_ec_sum.is_infinity { + return Err(ZKVMError::VerifyError( + 
"shard_ec_sum is not infinity".into(), + )); + } + Ok(true) } fn verify_proof_validity( &self, + shard_id: usize, vm_proof: ZKVMProof, mut transcript: impl ForkableTranscript, - ) -> Result { + ) -> Result, ZKVMError> { // main invariant between opcode circuits and table circuits let mut prod_r = E::ONE; let mut prod_w = E::ONE; @@ -103,7 +152,7 @@ impl> ZKVMVerifier if *chip_idx >= self.vk.circuit_vks.len() { return Err(ZKVMError::VKNotFound( format!( - "chip index {chip_idx} not found in vk set [0..{})", + "{shard_id}th shard chip index {chip_idx} not found in vk set [0..{})", self.vk.circuit_vks.len() ) .into(), @@ -118,6 +167,12 @@ impl> ZKVMVerifier .iter() .for_each(|v| v.iter().for_each(|v| transcript.append_field_element(v))); + // check shard id + assert_eq!( + vm_proof.raw_pi[SHARD_ID_IDX], + vec![E::BaseField::from_canonical_usize(shard_id)] + ); + // verify constant poly(s) evaluation result match // we can evaluate at this moment because constant always evaluate to same value // non-constant poly(s) will be verified in respective (table) proof accordingly @@ -126,7 +181,7 @@ impl> ZKVMVerifier .try_for_each(|(i, (raw, eval))| { if raw.len() == 1 && E::from(raw[0]) != *eval { Err(ZKVMError::VerifyError( - format!("pub input on index {i} mismatch {raw:?} != {eval:?}").into(), + format!("{shard_id}th shard pub input on index {i} mismatch {raw:?} != {eval:?}").into(), )) } else { Ok(()) @@ -135,7 +190,13 @@ impl> ZKVMVerifier // write fixed commitment to transcript // TODO check soundness if there is no fixed_commit but got fixed proof? 
- if let Some(fixed_commit) = self.vk.fixed_commit.as_ref() { + if let Some(fixed_commit) = self.vk.fixed_commit.as_ref() + && shard_id == 0 + { + PCS::write_commitment(fixed_commit, &mut transcript).map_err(ZKVMError::PCSError)?; + } else if let Some(fixed_commit) = self.vk.fixed_no_omc_init_commit.as_ref() + && shard_id > 0 + { PCS::write_commitment(fixed_commit, &mut transcript).map_err(ZKVMError::PCSError)?; } @@ -163,32 +224,43 @@ impl> ZKVMVerifier transcript.read_challenge().elements, transcript.read_challenge().elements, ]; - tracing::trace!("challenges in verifier: {:?}", challenges); + tracing::trace!( + "{shard_id}th shard challenges in verifier: {:?}", + challenges + ); let dummy_table_item = challenges[0]; let mut dummy_table_item_multiplicity = 0; let point_eval = PointAndEval::default(); let mut witin_openings = Vec::with_capacity(vm_proof.chip_proofs.len()); let mut fixed_openings = Vec::with_capacity(vm_proof.chip_proofs.len()); + let mut shard_ec_sum = SepticPoint::::default(); for (index, proof) in &vm_proof.chip_proofs { let num_instance: usize = proof.num_instances.iter().sum(); assert!(num_instance > 0); let circuit_name = &self.vk.circuit_index_to_name[index]; let circuit_vk = &self.vk.circuit_vks[circuit_name]; + if shard_id > 0 && circuit_vk.get_cs().with_omc_init_only() { + return Err(ZKVMError::InvalidProof( + format!("{shard_id}th shard non-first shard got omc dynamic table init",) + .into(), + )); + } + // check chip proof is well-formed if proof.wits_in_evals.len() != circuit_vk.get_cs().num_witin() || proof.fixed_in_evals.len() != circuit_vk.get_cs().num_fixed() { return Err(ZKVMError::InvalidProof( format!( - "witness/fixed evaluations length mismatch: ({}, {}) != ({}, {})", + "{shard_id}th shard witness/fixed evaluations length mismatch: ({}, {}) != ({}, {})", proof.wits_in_evals.len(), proof.fixed_in_evals.len(), circuit_vk.get_cs().num_witin(), circuit_vk.get_cs().num_fixed(), ) - .into(), + .into(), )); } if 
proof.r_out_evals.len() != circuit_vk.get_cs().num_reads() @@ -196,19 +268,19 @@ impl> ZKVMVerifier { return Err(ZKVMError::InvalidProof( format!( - "read/write evaluations length mismatch: ({}, {}) != ({}, {})", + "{shard_id}th shard read/write evaluations length mismatch: ({}, {}) != ({}, {})", proof.r_out_evals.len(), proof.w_out_evals.len(), circuit_vk.get_cs().num_reads(), circuit_vk.get_cs().num_writes(), ) - .into(), + .into(), )); } if proof.lk_out_evals.len() != circuit_vk.get_cs().num_lks() { return Err(ZKVMError::InvalidProof( format!( - "lookup evaluations length mismatch: {} != {}", + "{shard_id}th shard lookup evaluations length mismatch: {} != {}", proof.lk_out_evals.len(), circuit_vk.get_cs().num_lks(), ) @@ -226,7 +298,9 @@ impl> ZKVMVerifier .sum::(); transcript.append_field_element(&E::BaseField::from_canonical_u64(*index as u64)); - let input_opening_point = if circuit_vk.get_cs().is_opcode_circuit() { + if circuit_vk.get_cs().is_with_lk_table() { + logup_sum -= chip_logup_sum; + } else { // getting the number of dummy padding item that we used in this opcode circuit let num_lks = circuit_vk.get_cs().num_lks(); // each padding instance contribute to (2^rotation_vars) dummy lookup padding @@ -240,30 +314,18 @@ impl> ZKVMVerifier num_lks * (num_padded_instance + num_instance_non_selected); logup_sum += chip_logup_sum; - self.verify_opcode_proof( - circuit_name, - circuit_vk, - proof, - pi_evals, - &mut transcript, - NUM_FANIN, - &point_eval, - &challenges, - )? - } else { - logup_sum -= chip_logup_sum; - self.verify_table_proof( - circuit_name, - circuit_vk, - proof, - &vm_proof.raw_pi, - &vm_proof.pi_evals, - &mut transcript, - NUM_FANIN_LOGUP, - &point_eval, - &challenges, - )? 
}; + let (input_opening_point, chip_shard_ec_sum) = self.verify_chip_proof( + circuit_name, + circuit_vk, + proof, + pi_evals, + &vm_proof.raw_pi, + &mut transcript, + NUM_FANIN, + &point_eval, + &challenges, + )?; if circuit_vk.get_cs().num_witin() > 0 { witin_openings.push(( input_opening_point.len(), @@ -278,18 +340,17 @@ impl> ZKVMVerifier } prod_w *= proof.w_out_evals.iter().flatten().copied().product::(); prod_r *= proof.r_out_evals.iter().flatten().copied().product::(); - tracing::debug!("verified proof for circuit {}", circuit_name); + tracing::debug!( + "{shard_id}th shard verified proof for circuit {}", + circuit_name + ); + if let Some(chip_shard_ec_sum) = chip_shard_ec_sum { + shard_ec_sum = shard_ec_sum + chip_shard_ec_sum; + } } logup_sum -= E::from_canonical_u64(dummy_table_item_multiplicity as u64) * dummy_table_item.inverse(); - // check logup relation across all proofs - if logup_sum != E::ZERO { - return Err(ZKVMError::VerifyError( - format!("logup_sum({:?}) != 0", logup_sum).into(), - )); - } - #[cfg(debug_assertions)] { Instrumented::<<::BaseField as PoseidonField>::P>::log_label( @@ -299,7 +360,14 @@ impl> ZKVMVerifier // verify mpcs let mut rounds = vec![(vm_proof.witin_commit.clone(), witin_openings)]; - if let Some(fixed_commit) = self.vk.fixed_commit.as_ref() { + + if let Some(fixed_commit) = self.vk.fixed_commit.as_ref() + && shard_id == 0 + { + rounds.push((fixed_commit.clone(), fixed_openings)); + } else if let Some(fixed_commit) = self.vk.fixed_no_omc_init_commit.as_ref() + && shard_id > 0 + { rounds.push((fixed_commit.clone(), fixed_openings)); } PCS::batch_verify( @@ -332,28 +400,38 @@ impl> ZKVMVerifier .right() .unwrap(); prod_r *= finalize_global_state; - // check rw_set equality across all proofs + + // check rw_set equality of shard proof if prod_r != prod_w { - return Err(ZKVMError::VerifyError("prod_r != prod_w".into())); + return Err(ZKVMError::VerifyError( + format!("{shard_id}th prod_r != prod_w").into(), + )); } - Ok(true) 
+ // check logup sum of shard proof + if logup_sum != E::ZERO { + return Err(ZKVMError::VerifyError( + format!("{shard_id}th logup_sum({:?}) != 0", logup_sum).into(), + )); + } + + Ok(shard_ec_sum) } - // TODO: unify `verify_opcode_proof` and `verify_table_proof` /// verify proof and return input opening point - #[allow(clippy::too_many_arguments)] - pub fn verify_opcode_proof( + #[allow(clippy::too_many_arguments, clippy::type_complexity)] + pub fn verify_chip_proof( &self, _name: &str, circuit_vk: &VerifyingKey, proof: &ZKVMChipProof, pi: &[E], + raw_pi: &[Vec], transcript: &mut impl Transcript, num_product_fanin: usize, _out_evals: &PointAndEval, challenges: &[E; 2], // derive challenge from PCS - ) -> Result, ZKVMError> { + ) -> Result<(Point, Option>), ZKVMError> { let composed_cs = circuit_vk.get_cs(); let ComposedConstrainSystem { zkvm_v1_css: cs, @@ -363,7 +441,7 @@ impl> ZKVMVerifier let (r_counts_per_instance, w_counts_per_instance, lk_counts_per_instance) = ( cs.r_expressions.len() + cs.r_table_expressions.len(), cs.w_expressions.len() + cs.w_table_expressions.len(), - cs.lk_expressions.len() + cs.lk_table_expressions.len() * 2, + cs.lk_expressions.len() + cs.lk_table_expressions.len(), ); let num_batched = r_counts_per_instance + w_counts_per_instance + lk_counts_per_instance; @@ -376,38 +454,78 @@ impl> ZKVMVerifier } let num_var_with_rotation = log2_num_instances + composed_cs.rotation_vars().unwrap_or(0); + // constrain log2_num_instances within max length + cs.r_table_expressions + .iter() + .chain(&cs.w_table_expressions) + .for_each(|set_table_expr| { + // iterate through structural witins and collect max round. + let num_vars = set_table_expr + .table_spec + .len + .map(ceil_log2) + .unwrap_or_else(|| { + set_table_expr + .table_spec + .structural_witins + .iter() + .map(|StructuralWitIn { witin_type, .. 
}| { + let hint_num_vars = log2_num_instances; + assert!((1 << hint_num_vars) <= witin_type.max_len()); + hint_num_vars + }) + .max() + .unwrap_or(log2_num_instances) + }); + assert_eq!(num_vars, log2_num_instances); + }); + cs.lk_table_expressions.iter().for_each(|l| { + // iterate through structural witins and collect max round. + let num_vars = l.table_spec.len.map(ceil_log2).unwrap_or_else(|| { + l.table_spec + .structural_witins + .iter() + .map(|StructuralWitIn { witin_type, .. }| { + let hint_num_vars = log2_num_instances; + assert!((1 << hint_num_vars) <= witin_type.max_len()); + hint_num_vars + }) + .max() + .unwrap_or(log2_num_instances) + }); + assert_eq!(num_vars, log2_num_instances); + }); + // verify ecc proof if exists - if composed_cs.has_ecc_ops() { + let shard_ec_sum: Option> = if composed_cs.has_ecc_ops() { tracing::debug!("verifying ecc proof..."); assert!(proof.ecc_proof.is_some()); let ecc_proof = proof.ecc_proof.as_ref().unwrap(); - // TODO: enable this - // let xy = cs - // .ec_final_sum - // .iter() - // .map(|expr| { - // eval_by_expr_with_instance(&[], &[], &[], pi, challenges, &expr) - // .right() - // .and_then(|v| v.as_base()) - // .unwrap() - // }) - // .collect_vec(); - // let x: SepticExtension = xy[0..SEPTIC_EXTENSION_DEGREE].into(); - // let y: SepticExtension = xy[SEPTIC_EXTENSION_DEGREE..].into(); - - // assert_eq!( - // SepticPoint { - // x, - // y, - // is_infinity: false, - // }, - // ecc_proof.sum - // ); - // assert ec sum in public input matches that in ecc proof + let expected_septic_xy = cs + .ec_final_sum + .iter() + .map(|expr| { + eval_by_expr_with_instance(&[], &[], &[], pi, challenges, expr) + .right() + .and_then(|v| v.as_base()) + .unwrap() + }) + .collect_vec(); + let expected_septic_x: SepticExtension = + expected_septic_xy[0..SEPTIC_EXTENSION_DEGREE].into(); + let expected_septic_y: SepticExtension = + expected_septic_xy[SEPTIC_EXTENSION_DEGREE..].into(); + + assert_eq!(&ecc_proof.sum.x, &expected_septic_x); + 
assert_eq!(&ecc_proof.sum.y, &expected_septic_y); + assert!(!ecc_proof.sum.is_infinity); EccVerifier::verify_ecc_proof(ecc_proof, transcript)?; tracing::debug!("ecc proof verified."); - } + Some(ecc_proof.sum.clone()) + } else { + None + }; // verify and reduce product tower sumcheck let tower_proofs = &proof.tower_proof; @@ -426,18 +544,20 @@ impl> ZKVMVerifier transcript, )?; - // verify LogUp witness nominator p(x) ?= constant vector 1 - logup_p_evals - .iter() - .try_for_each(|PointAndEval { eval, .. }| { - if *eval != E::ONE { - Err(ZKVMError::VerifyError( - "Lookup table witness p(x) != constant 1".into(), - )) - } else { - Ok(()) - } - })?; + if cs.lk_table_expressions.is_empty() { + // verify LogUp witness nominator p(x) ?= constant vector 1 + logup_p_evals + .iter() + .try_for_each(|PointAndEval { eval, .. }| { + if *eval != E::ONE { + Err(ZKVMError::VerifyError( + "Lookup table witness p(x) != constant 1".into(), + )) + } else { + Ok(()) + } + })?; + } debug_assert!( chain!(&record_evals, &logup_p_evals, &logup_q_evals) @@ -445,13 +565,24 @@ impl> ZKVMVerifier .all_equal() ); - // verify zero statement (degree > 1) + sel sumcheck let num_rw_records = r_counts_per_instance + w_counts_per_instance; debug_assert_eq!(record_evals.len(), num_rw_records); debug_assert_eq!(logup_p_evals.len(), lk_counts_per_instance); debug_assert_eq!(logup_q_evals.len(), lk_counts_per_instance); + let evals = record_evals + .iter() + // append p_evals if there got lk table expressions + .chain(if cs.lk_table_expressions.is_empty() { + Either::Left(iter::empty()) + } else { + Either::Right(logup_p_evals.iter()) + }) + .chain(&logup_q_evals) + .cloned() + .collect_vec(); + let gkr_circuit = gkr_circuit.as_ref().unwrap(); let selector_ctxs = if cs.ec_final_sum.is_empty() { assert_eq!(proof.num_instances.len(), 1); @@ -491,238 +622,17 @@ impl> ZKVMVerifier }, ] }; - let GKRClaims(opening_evaluations) = gkr_circuit.verify( + let (_, rt) = gkr_circuit.verify( num_var_with_rotation, 
proof.gkr_iop_proof.clone().unwrap(), - &chain!(record_evals, logup_q_evals).collect_vec(), + &evals, pi, + raw_pi, challenges, transcript, &selector_ctxs, )?; - Ok(opening_evaluations[0].point.clone()) - } - - #[allow(clippy::too_many_arguments)] - pub fn verify_table_proof( - &self, - name: &str, - circuit_vk: &VerifyingKey, - proof: &ZKVMChipProof, - raw_pi: &[Vec], - pi: &[E], - transcript: &mut impl Transcript, - num_logup_fanin: usize, - _out_evals: &PointAndEval, - challenges: &[E; 2], - ) -> Result, ZKVMError> { - let ComposedConstrainSystem { - zkvm_v1_css: cs, .. - } = circuit_vk.get_cs(); - let with_rw = !cs.r_table_expressions.is_empty() && !cs.w_table_expressions.is_empty(); - if with_rw { - debug_assert!( - cs.r_table_expressions - .iter() - .zip_eq(cs.w_table_expressions.iter()) - .all(|(r, w)| r.table_spec.len == w.table_spec.len) - ); - } - let num_instances = proof.num_instances.iter().sum(); - let log2_num_instances = next_pow2_instance_padding(num_instances).ilog2() as usize; - - // verify and reduce product tower sumcheck - let tower_proofs = &proof.tower_proof; - - // NOTE: for all structural witness within same constrain system should got same hints num variable via `log2_num_instances` - let expected_rounds = interleave(&cs.r_table_expressions, &cs.w_table_expressions) - .map(|set_table_expr| { - // iterate through structural witins and collect max round. - let num_vars = set_table_expr - .table_spec - .len - .map(ceil_log2) - .unwrap_or_else(|| { - set_table_expr - .table_spec - .structural_witins - .iter() - .map(|StructuralWitIn { witin_type, .. }| { - let hint_num_vars = log2_num_instances; - assert!((1 << hint_num_vars) <= witin_type.max_len()); - hint_num_vars - }) - .max() - .unwrap() - }); - assert_eq!(num_vars, log2_num_instances); - num_vars - }) - .chain(cs.lk_table_expressions.iter().map(|l| { - // iterate through structural witins and collect max round. 
- let num_vars = l.table_spec.len.map(ceil_log2).unwrap_or_else(|| { - l.table_spec - .structural_witins - .iter() - .map(|StructuralWitIn { witin_type, .. }| { - let hint_num_vars = log2_num_instances; - assert!((1 << hint_num_vars) <= witin_type.max_len()); - hint_num_vars - }) - .max() - .unwrap() - }); - assert_eq!(num_vars, log2_num_instances); - num_vars - })) - .collect_vec(); - - let (rt_tower, prod_point_and_eval, logup_p_point_and_eval, logup_q_point_and_eval) = - TowerVerify::verify( - interleave(&proof.r_out_evals, &proof.w_out_evals) - .map(|eval| eval.to_vec()) - .collect_vec(), - proof - .lk_out_evals - .iter() - .map(|eval| eval.to_vec()) - .collect_vec(), - tower_proofs, - expected_rounds, - num_logup_fanin, - transcript, - )?; - - // TODO: return error instead of panic - assert_eq!( - logup_q_point_and_eval.len(), - cs.lk_table_expressions.len(), - "[lk_q_record] mismatch length" - ); - assert_eq!( - logup_p_point_and_eval.len(), - cs.lk_table_expressions.len(), - "[lk_p_record] mismatch length" - ); - assert_eq!( - prod_point_and_eval.len(), - cs.r_table_expressions.len() + cs.w_table_expressions.len(), - "[prod_record] mismatch length" - ); - - // TODO differentiate `ram_bus` via cs - let is_shard_ram_bus_circuit = false; - - let input_opening_point = if !is_shard_ram_bus_circuit { - // evaluate the evaluation of structural mles at input_opening_point by verifier - let structural_evals = if with_rw { - // only iterate r set, as read/write set round should match - Either::Left(cs.r_table_expressions.iter()) - } else { - Either::Right(cs.r_table_expressions.iter().chain(&cs.w_table_expressions)) - } - .map(|set_table_expr| &set_table_expr.table_spec) - .chain(cs.lk_table_expressions.iter().map(|r| &r.table_spec)) - .flat_map(|table_spec| { - table_spec - .structural_witins - .iter() - .map(|structural_witin| match structural_witin.witin_type { - StructuralWitInType::EqualDistanceSequence { - offset, - multi_factor, - descending, - .. 
- } => eval_wellform_address_vec( - offset as u64, - multi_factor as u64, - &rt_tower, - descending, - ), - StructuralWitInType::StackedIncrementalSequence { .. } => { - eval_stacked_wellform_address_vec(&rt_tower) - } - StructuralWitInType::StackedConstantSequence { .. } => { - eval_stacked_constant_vec(&rt_tower) - } - StructuralWitInType::InnerRepeatingIncrementalSequence { k, .. } => { - eval_inner_repeated_incremental_vec(k as u64, &rt_tower) - } - StructuralWitInType::OuterRepeatingIncrementalSequence { k, .. } => { - eval_outer_repeated_incremental_vec(k as u64, &rt_tower) - } - }) - .collect_vec() - }) - .collect_vec(); - - // verify records (degree = 1) statement, thus no sumcheck - let expected_evals = interleave( - &cs.r_table_expressions, // r - &cs.w_table_expressions, // w - ) - .map(|rw| &rw.expr) - .chain( - cs.lk_table_expressions - .iter() - .flat_map(|lk| vec![&lk.multiplicity, &lk.values]), // p, q - ) - .map(|expr| { - eval_by_expr_with_instance( - &proof.fixed_in_evals, - &proof.wits_in_evals, - &structural_evals, - pi, - challenges, - expr, - ) - .right() - .unwrap() - }) - .collect_vec(); - for (expected_eval, eval) in expected_evals.iter().zip( - prod_point_and_eval - .into_iter() - .chain( - logup_p_point_and_eval - .into_iter() - .zip_eq(logup_q_point_and_eval) - .flat_map(|(p_point_and_eval, q_point_and_eval)| { - [p_point_and_eval, q_point_and_eval] - }), - ) - .map(|point_and_eval| point_and_eval.eval), - ) { - if expected_eval != &eval { - return Err(ZKVMError::VerifyError( - format!("table {name} evaluation mismatch {expected_eval:?} != {eval:?}") - .into(), - )); - } - } - rt_tower - } else { - unimplemented!("shard ram bus circuit go here"); - }; - - // assume public io is tiny vector, so we evaluate it directly without PCS - for &Instance(idx) in cs.instance_name_map.keys() { - let poly = raw_pi[idx].to_vec().into_mle(); - let expected_eval = poly.evaluate(&input_opening_point[..poly.num_vars()]); - let eval = pi[idx]; - if 
expected_eval != eval { - return Err(ZKVMError::VerifyError( - format!("pub input on index {idx} mismatch {expected_eval:?} != {eval:?}") - .into(), - )); - } - tracing::trace!( - "[table {name}] verified public inputs on index {idx} with point {:?}", - input_opening_point - ); - } - - Ok(input_opening_point) + Ok((rt, shard_ec_sum)) } } @@ -1042,9 +952,7 @@ impl EccVerifier { // this value doesn't matter, as we only need structural id StackedConstantSequence { max_value: 0 }, )); - let mut sel_evals = vec![E::ZERO]; - sel_add_expr.evaluate( - &mut sel_evals, + let Some((expected_sel_add, _)) = sel_add_expr.evaluate( &out_rt, &rt, &SelectorContext { @@ -1052,9 +960,9 @@ impl EccVerifier { num_instances: proof.num_instances, num_vars, }, - 0, - ); - let expected_sel_add = sel_evals[0]; + ) else { + unreachable!() + }; if proof.evals[0] != expected_sel_add { return Err(ZKVMError::VerifyError( diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index 79661d728..c10d0d843 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -8,22 +8,23 @@ use crate::{ }, scheme::septic_curve::SepticPoint, state::StateCircuit, - tables::{RMMCollections, TableCircuit}, + tables::{MemFinalRecord, RMMCollections, TableCircuit}, }; -use ceno_emul::{CENO_PLATFORM, Platform, StepRecord}; +use ceno_emul::{CENO_PLATFORM, Platform, RegIdx, StepRecord, WordAddr}; use ff_ext::{ExtensionField, PoseidonField}; use gkr_iop::{gkr::GKRCircuit, tables::LookupTable, utils::lk_multiplicity::Multiplicity}; use itertools::Itertools; use mpcs::{Point, PolynomialCommitmentScheme}; use multilinear_extensions::{Expression, Instance}; -use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; +use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator}; +use rustc_hash::FxHashSet; use serde::{Deserialize, Serialize, de::DeserializeOwned}; use std::{ collections::{BTreeMap, HashMap}, sync::Arc, }; use sumcheck::structs::{IOPProof, IOPProverMessage}; -use 
witness::RowMajorMatrix; +use witness::{InstancePaddingStrategy, RowMajorMatrix}; /// proof that the sum of N=2^n EC points is equal to `sum` /// in one layer instead of GKR layered circuit approach @@ -152,17 +153,15 @@ impl ComposedConstrainSystem { self.zkvm_v1_css.w_expressions.len() + self.zkvm_v1_css.w_table_expressions.len() } + pub fn instance_openings(&self) -> &[Instance] { + &self.zkvm_v1_css.instance_openings + } pub fn has_ecc_ops(&self) -> bool { !self.zkvm_v1_css.ec_final_sum.is_empty() } - pub fn instance_name_map(&self) -> &HashMap { - &self.zkvm_v1_css.instance_name_map - } - - pub fn is_opcode_circuit(&self) -> bool { - // TODO: is global chip opcode circuit?? - self.gkr_circuit.is_some() || self.has_ecc_ops() + pub fn is_with_lk_table(&self) -> bool { + !self.zkvm_v1_css.lk_table_expressions.is_empty() } /// return number of lookup operation @@ -185,6 +184,10 @@ impl ComposedConstrainSystem { .as_ref() .map(|param| param.rotation_cyclic_subgroup_size) } + + pub fn with_omc_init_only(&self) -> bool { + self.zkvm_v1_css.with_omc_init_only + } } #[derive(Clone)] @@ -429,10 +432,68 @@ impl ZKVMWitnesses { pub fn assign_global_chip_circuit( &mut self, cs: &ZKVMConstraintSystem, - shard_ctx: &ShardContext, + // shard_ctx: &ShardContext, + (shard_ctx, final_mem): &( + &ShardContext, + &[(InstancePaddingStrategy, &[MemFinalRecord])], + ), config: & as TableCircuit>::TableConfig, ) -> Result<(), ZKVMError> { let perm = ::get_default_perm(); + let waddr_first_access = if shard_ctx.is_first_shard() { + shard_ctx.get_addr_accessed_first_shard() + } else { + FxHashSet::default() + }; + + let non_first_shard_records = if shard_ctx.is_first_shard() { + final_mem + .par_iter() + .flat_map_iter(|(_, final_mem)| { + final_mem.iter().filter_map(|mem_record| { + // prepare global writes record for those record which not accessed in first record + // but access in future shard + let (waddr, addr): (WordAddr, u32) = match mem_record.ram_type { + RAMType::Register => 
( + Platform::register_vma(mem_record.addr as RegIdx).into(), + mem_record.addr, + ), + RAMType::Memory => (mem_record.addr.into(), mem_record.addr), + _ => unimplemented!(), + }; + if !waddr_first_access.contains(&waddr) + && shard_ctx.after_current_shard_cycle(mem_record.cycle) + { + let global_write = GlobalRecord { + addr: match mem_record.ram_type { + RAMType::Register => addr, + RAMType::Memory => waddr.into(), + _ => unimplemented!(), + }, + ram_type: mem_record.ram_type, + // fill initial value to cancel initial record + value: mem_record.init_value, + shard: 0, + local_clk: 0, + global_clk: 0, + is_to_write_set: true, + }; + let ec_point: GlobalPoint = global_write.to_ec_point(&perm); + Some(GlobalChipInput { + record: global_write, + ec_point, + }) + } else { + None + } + }) + }) + .collect() + } else { + vec![] + }; + let non_first_shard_records_len = non_first_shard_records.len(); + let global_input = shard_ctx .write_records() .par_iter() @@ -447,6 +508,7 @@ impl ZKVMWitnesses { } }) }) + .chain(non_first_shard_records.into_par_iter()) .chain( shard_ctx .read_records() @@ -464,6 +526,7 @@ impl ZKVMWitnesses { }), ) .collect::>(); + assert!(self.combined_lk_mlt.is_some()); let cs = cs.get_cs(&GlobalChip::::name()).unwrap(); let witness = GlobalChip::assign_instances( @@ -484,7 +547,8 @@ impl ZKVMWitnesses { .write_records() .iter() .map(|records| records.len()) - .sum(), + .sum::() + + non_first_shard_records_len, // global read -> local write shard_ctx .read_records() @@ -528,10 +592,26 @@ impl ZKVMWitnesses { pub struct ZKVMProvingKey> { pub pp: PCS::ProverParam, pub vp: PCS::VerifierParam, + // entry program counter + pub entry_pc: u32, // pk for opcode and table circuits pub circuit_pks: BTreeMap>, + + // Fixed commitments are separated into two groups: + // + // 1. `fixed_commit_*` + // - Used by the *main circuit* for offline memory check (OMC) table initialization. + // - This initialization occurs **only in the first shard** (`shard_id = 0`). 
+ // + // 2. `fixed_no_omc_init_commit_*` + // - Used by subsequent shards (`shard_id > 0`), which **omit** OMC table initialization. + // - All circuit components related to OMC init are skipped in these shards. pub fixed_commit_wd: Option>::CommitmentWithWitness>>, pub fixed_commit: Option<>::Commitment>, + pub fixed_no_omc_init_commit_wd: + Option>::CommitmentWithWitness>>, + pub fixed_no_omc_init_commit: Option<>::Commitment>, + pub circuit_index_fixed_num_instances: BTreeMap, // expression for global state in/out @@ -544,18 +624,22 @@ impl> ZKVMProvingKey::BaseField>>, + fixed_traces_no_omc_init: BTreeMap::BaseField>>, ) -> Result<(), ZKVMError> { if !fixed_traces.is_empty() { let fixed_commit_wd = @@ -568,20 +652,40 @@ impl> ZKVMProvingKey> ZKVMProvingKey { pub fn get_vk_slow(&self) -> ZKVMVerifyingKey { ZKVMVerifyingKey { vp: self.vp.clone(), + entry_pc: self.entry_pc, circuit_vks: self .circuit_pks .iter() .map(|(name, pk)| (name.clone(), pk.vk.clone())) .collect(), fixed_commit: self.fixed_commit.clone(), + fixed_no_omc_init_commit: self.fixed_no_omc_init_commit.clone(), // expression for global state in/out initial_global_state_expr: self.initial_global_state_expr.clone(), finalize_global_state_expr: self.finalize_global_state_expr.clone(), @@ -602,9 +706,12 @@ impl> ZKVMProvingKey> { pub vp: PCS::VerifierParam, + // entry program counter + pub entry_pc: u32, // vk for opcode and table circuits pub circuit_vks: BTreeMap>, pub fixed_commit: Option<>::Commitment>, + pub fixed_no_omc_init_commit: Option<>::Commitment>, // expression for global state in/out pub initial_global_state_expr: Expression, pub finalize_global_state_expr: Expression, diff --git a/ceno_zkvm/src/tables/mod.rs b/ceno_zkvm/src/tables/mod.rs index d55a6a907..fd8fa2278 100644 --- a/ceno_zkvm/src/tables/mod.rs +++ b/ceno_zkvm/src/tables/mod.rs @@ -1,6 +1,12 @@ use crate::{circuit_builder::CircuitBuilder, error::ZKVMError, structs::ProgramParams}; use ff_ext::ExtensionField; -use 
gkr_iop::gkr::GKRCircuit; +use gkr_iop::{ + chip::Chip, + gkr::{GKRCircuit, layer::Layer}, + selector::SelectorType, +}; +use itertools::Itertools; +use multilinear_extensions::ToExpr; use std::collections::HashMap; use witness::RowMajorMatrix; @@ -36,7 +42,43 @@ pub trait TableCircuit { param: &ProgramParams, ) -> Result<(Self::TableConfig, Option>), ZKVMError> { let config = Self::construct_circuit(cb, param)?; - Ok((config, None)) + let r_table_len = cb.cs.r_table_expressions.len(); + let w_table_len = cb.cs.w_table_expressions.len(); + let lk_table_len = cb.cs.lk_table_expressions.len() * 2; + + let selector = cb.create_placeholder_structural_witin(|| "selector"); + let selector_type = SelectorType::Whole(selector.expr()); + + // all shared the same selector + let (out_evals, mut chip) = ( + [ + // r_record + (0..r_table_len).collect_vec(), + // w_record + (r_table_len..r_table_len + w_table_len).collect_vec(), + // lk_record + (r_table_len + w_table_len..r_table_len + w_table_len + lk_table_len).collect_vec(), + // zero_record + vec![], + ], + Chip::new_from_cb(cb, 0), + ); + + // register selector to legacy constrain system + if r_table_len > 0 { + cb.cs.r_selector = Some(selector_type.clone()); + } + if w_table_len > 0 { + cb.cs.w_selector = Some(selector_type.clone()); + } + if lk_table_len > 0 { + cb.cs.lk_selector = Some(selector_type.clone()); + } + + let layer = Layer::from_circuit_builder(cb, Self::name(), 0, out_evals); + chip.add_layer(layer); + + Ok((config, Some(chip.gkr_circuit()))) } fn generate_fixed_traces( diff --git a/ceno_zkvm/src/tables/ops/ops_circuit.rs b/ceno_zkvm/src/tables/ops/ops_circuit.rs index d98e05360..7939a1d81 100644 --- a/ceno_zkvm/src/tables/ops/ops_circuit.rs +++ b/ceno_zkvm/src/tables/ops/ops_circuit.rs @@ -22,7 +22,7 @@ impl TableCircuit for OpsTableCircuit type WitnessInput = (); fn name() -> String { - format!("OPS_{:?}", OP::ROM_TYPE) + format!("{:?}_OPS_ROM_TABLE", OP::ROM_TYPE) } fn construct_circuit( diff --git 
a/ceno_zkvm/src/tables/ops/ops_impl.rs b/ceno_zkvm/src/tables/ops/ops_impl.rs index 2f365142e..72b80a548 100644 --- a/ceno_zkvm/src/tables/ops/ops_impl.rs +++ b/ceno_zkvm/src/tables/ops/ops_impl.rs @@ -73,19 +73,31 @@ impl OpTableConfig { multiplicity: &HashMap, length: usize, ) -> Result, CircuitBuilderError> { - assert_eq!(num_structural_witin, 0); + assert_eq!(num_structural_witin, 1); + let num_structural_witin = num_structural_witin.max(1); + let mut witness = RowMajorMatrix::::new(length, num_witin, InstancePaddingStrategy::Default); + let mut structural_witness = RowMajorMatrix::::new( + length, + num_structural_witin, + InstancePaddingStrategy::Default, + ); let mut mlts = vec![0; length]; for (idx, mlt) in multiplicity { mlts[*idx as usize] = *mlt; } - witness.par_rows_mut().zip(mlts).for_each(|(row, mlt)| { - set_val!(row, self.mlt, F::from_v(mlt as u64)); - }); + witness + .par_rows_mut() + .zip_eq(structural_witness.par_rows_mut()) + .zip(mlts) + .for_each(|((row, structural_row), mlt)| { + set_val!(row, self.mlt, F::from_v(mlt as u64)); + *structural_row.last_mut().unwrap() = F::ONE; + }); - Ok([witness, RowMajorMatrix::empty()]) + Ok([witness, structural_witness]) } } diff --git a/ceno_zkvm/src/tables/program.rs b/ceno_zkvm/src/tables/program.rs index 833663e74..a71818040 100644 --- a/ceno_zkvm/src/tables/program.rs +++ b/ceno_zkvm/src/tables/program.rs @@ -16,7 +16,9 @@ use multilinear_extensions::{Expression, Fixed, ToExpr, WitIn}; use p3::field::FieldAlgebra; use rayon::iter::{IndexedParallelIterator, ParallelIterator}; use std::{collections::HashMap, marker::PhantomData}; -use witness::{InstancePaddingStrategy, RowMajorMatrix, set_fixed_val, set_val}; +use witness::{ + InstancePaddingStrategy, RowMajorMatrix, next_pow2_instance_padding, set_fixed_val, set_val, +}; /// This structure establishes the order of the fields in instruction records, common to the program table and circuit fetches. 
#[cfg(not(feature = "u16limb_circuit"))] @@ -269,9 +271,11 @@ impl TableCircuit for ProgramTableCircuit { multiplicity: &[HashMap], program: &Program, ) -> Result, ZKVMError> { + assert!(!program.instructions.is_empty()); + assert!(num_structural_witin == 0 || num_structural_witin == 1); let multiplicity = &multiplicity[ROMType::Instruction as usize]; - let mut prog_mlt = vec![0_usize; program.instructions.len()]; + let mut prog_mlt = vec![0_usize; next_pow2_instance_padding(program.instructions.len())]; for (pc, mlt) in multiplicity { let i = (*pc as usize - program.base_address as usize) / WORD_SIZE; prog_mlt[i] = *mlt; @@ -279,18 +283,28 @@ impl TableCircuit for ProgramTableCircuit { let mut witness = RowMajorMatrix::::new( config.program_size, - num_witin + num_structural_witin, + num_witin, InstancePaddingStrategy::Default, ); - witness.par_rows_mut().zip(prog_mlt).for_each(|(row, mlt)| { - set_val!( - row, - config.mlt, - E::BaseField::from_canonical_u64(mlt as u64) - ); - }); + let mut structural_witness = RowMajorMatrix::::new( + config.program_size, + 1, + InstancePaddingStrategy::Default, + ); + witness + .par_rows_mut() + .zip_eq(structural_witness.par_rows_mut()) + .zip(prog_mlt) + .for_each(|((row, structural_row), mlt)| { + set_val!( + row, + config.mlt, + E::BaseField::from_canonical_u64(mlt as u64) + ); + *structural_row.last_mut().unwrap() = E::BaseField::ONE; + }); - Ok([witness, RowMajorMatrix::empty()]) + Ok([witness, structural_witness]) } } diff --git a/ceno_zkvm/src/tables/ram.rs b/ceno_zkvm/src/tables/ram.rs index 6075b0440..e2bdc4dec 100644 --- a/ceno_zkvm/src/tables/ram.rs +++ b/ceno_zkvm/src/tables/ram.rs @@ -1,5 +1,5 @@ use ceno_emul::{Addr, VMState, WORD_SIZE}; -use ram_circuit::{DynVolatileRamCircuit, NonVolatileRamCircuit, PubIORamCircuit}; +use ram_circuit::{DynVolatileRamCircuit, NonVolatileRamCircuit, PubIORamInitCircuit}; use crate::{ instructions::riscv::constants::UINT_LIMBS, @@ -10,9 +10,7 @@ mod ram_circuit; mod ram_impl; use 
crate::tables::ram::{ ram_circuit::{LocalFinalRamCircuit, RamBusCircuit}, - ram_impl::{ - DynVolatileRamTableConfig, DynVolatileRamTableInitConfig, NonVolatileInitTableConfig, - }, + ram_impl::{DynVolatileRamTableInitConfig, NonVolatileInitTableConfig}, }; pub use ram_circuit::{DynVolatileRamTable, MemFinalRecord, MemInitRecord, NonVolatileTable}; @@ -96,8 +94,8 @@ impl DynVolatileRamTable for HintsTable { "HintsTable" } } -pub type HintsCircuit = - DynVolatileRamCircuit>; +pub type HintsInitCircuit = + DynVolatileRamCircuit>; /// RegTable, fix size without offset #[derive(Clone)] @@ -157,6 +155,6 @@ impl NonVolatileTable for PubIOTable { } } -pub type PubIOCircuit = PubIORamCircuit; +pub type PubIOInitCircuit = PubIORamInitCircuit; pub type LocalFinalCircuit<'a, E> = LocalFinalRamCircuit<'a, UINT_LIMBS, E>; pub type RBCircuit<'a, E> = RamBusCircuit<'a, UINT_LIMBS, E>; diff --git a/ceno_zkvm/src/tables/ram/ram_circuit.rs b/ceno_zkvm/src/tables/ram/ram_circuit.rs index 344a8d891..458a17301 100644 --- a/ceno_zkvm/src/tables/ram/ram_circuit.rs +++ b/ceno_zkvm/src/tables/ram/ram_circuit.rs @@ -1,7 +1,7 @@ use std::{collections::HashMap, marker::PhantomData}; use super::ram_impl::{ - LocalFinalRAMTableConfig, NonVolatileTableConfigTrait, PubIOTableConfig, RAMBusConfig, + LocalFinalRAMTableConfig, NonVolatileTableConfigTrait, PubIOTableInitConfig, RAMBusConfig, }; use crate::{ circuit_builder::CircuitBuilder, @@ -19,7 +19,7 @@ use gkr_iop::{ selector::SelectorType, }; use itertools::Itertools; -use multilinear_extensions::{StructuralWitInType, ToExpr}; +use multilinear_extensions::ToExpr; use witness::{InstancePaddingStrategy, RowMajorMatrix}; #[derive(Clone, Debug)] @@ -34,6 +34,10 @@ pub struct MemFinalRecord { pub addr: Addr, pub cycle: Cycle, pub value: Word, + // initial state value + // same as `value` for read-only table + // probably different for rw table + pub init_value: Word, } impl GetAddr for MemInitRecord { @@ -126,14 +130,14 @@ impl< /// This circuit does 
not and cannot decide whether the memory is mutable or not. /// It supports LOAD where the program reads the public input, /// or STORE where the memory content must equal the public input after execution. -pub struct PubIORamCircuit(PhantomData<(E, R)>); +pub struct PubIORamInitCircuit(PhantomData<(E, R)>); impl TableCircuit - for PubIORamCircuit + for PubIORamInitCircuit { - type TableConfig = PubIOTableConfig; + type TableConfig = PubIOTableInitConfig; type FixedInput = [Addr]; - type WitnessInput = [Cycle]; + type WitnessInput = [MemFinalRecord]; fn name() -> String { format!("RAM_{:?}_{}", NVRAM::RAM_TYPE, NVRAM::name()) @@ -143,6 +147,7 @@ impl TableCirc cb: &mut CircuitBuilder, params: &ProgramParams, ) -> Result { + cb.set_omc_init_only(); Ok(cb.namespace( || Self::name(), |cb| Self::TableConfig::construct_circuit(cb, params), @@ -163,10 +168,10 @@ impl TableCirc num_witin: usize, num_structural_witin: usize, _multiplicity: &[HashMap], - final_cycles: &[Cycle], + final_mem: &[MemFinalRecord], ) -> Result, ZKVMError> { // assume returned table is well-formed including padding - Ok(config.assign_instances(num_witin, num_structural_witin, final_cycles)?) + Ok(config.assign_instances(num_witin, num_structural_witin, final_mem)?) 
} } @@ -239,7 +244,7 @@ impl< type WitnessInput = [MemFinalRecord]; fn name() -> String { - format!("RAM_{:?}_{}", DVRAM::RAM_TYPE, DVRAM::name()) + format!("{}_{:?}_RAM", DVRAM::name(), DVRAM::RAM_TYPE,) } fn construct_circuit( @@ -310,16 +315,7 @@ impl<'a, E: ExtensionField, const V_LIMBS: usize> TableCircuit let config = Self::construct_circuit(cb, param)?; let r_table_len = cb.cs.r_table_expressions.len(); - let selector = cb.create_structural_witin( - || "selector", - StructuralWitInType::EqualDistanceSequence { - // TODO determin proper size of max length - max_len: u32::MAX as usize, - offset: 0, - multi_factor: 0, - descending: false, - }, - ); + let selector = cb.create_placeholder_structural_witin(|| "selector"); let selector_type = SelectorType::Prefix(selector.expr()); // all shared the same selector @@ -340,7 +336,7 @@ impl<'a, E: ExtensionField, const V_LIMBS: usize> TableCircuit // register selector to legacy constrain system cb.cs.r_selector = Some(selector_type.clone()); - let layer = Layer::from_circuit_builder(cb, "Rounds".to_string(), 0, out_evals); + let layer = Layer::from_circuit_builder(cb, Self::name(), 0, out_evals); chip.add_layer(layer); Ok((config, Some(chip.gkr_circuit()))) diff --git a/ceno_zkvm/src/tables/ram/ram_impl.rs b/ceno_zkvm/src/tables/ram/ram_impl.rs index 554c71235..a73e176de 100644 --- a/ceno_zkvm/src/tables/ram/ram_impl.rs +++ b/ceno_zkvm/src/tables/ram/ram_impl.rs @@ -1,10 +1,11 @@ -use ceno_emul::{Addr, Cycle, WORD_SIZE}; +use ceno_emul::{Addr, WORD_SIZE}; use either::Either; use ff_ext::{ExtensionField, SmallField}; use gkr_iop::error::CircuitBuilderError; use itertools::Itertools; use rayon::iter::{ - IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator, + IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, + IntoParallelRefMutIterator, ParallelExtend, ParallelIterator, }; use std::marker::PhantomData; use witness::{ @@ -71,6 +72,7 @@ impl 
NonVolatileTableConfigTrait< cb: &mut CircuitBuilder, params: &ProgramParams, ) -> Result { + cb.set_omc_init_only(); assert!(NVRAM::WRITABLE); let init_v = (0..NVRAM::V_LIMBS) .map(|i| cb.create_fixed(|| format!("init_v_limb_{i}"))) @@ -144,29 +146,37 @@ impl NonVolatileTableConfigTrait< /// TODO consider taking RowMajorMatrix as argument to save allocations. fn assign_instances( - _config: &Self::Config, + config: &Self::Config, _num_witin: usize, num_structural_witin: usize, - _final_mem: &[MemFinalRecord], + final_mem: &[MemFinalRecord], ) -> Result<[RowMajorMatrix; 2], CircuitBuilderError> { - assert_eq!(num_structural_witin, 0); - Ok([RowMajorMatrix::empty(), RowMajorMatrix::empty()]) + if final_mem.is_empty() { + return Ok([RowMajorMatrix::empty(), RowMajorMatrix::empty()]); + } + assert!(num_structural_witin == 0 || num_structural_witin == 1); + let mut value = Vec::with_capacity(NVRAM::len(&config.params)); + value.par_extend( + (0..NVRAM::len(&config.params)) + .into_par_iter() + .map(|_| F::ONE), + ); + let structural_witness = + RowMajorMatrix::::new_by_values(value, 1, InstancePaddingStrategy::Default); + Ok([RowMajorMatrix::empty(), structural_witness]) } } /// define public io /// init value set by instance #[derive(Clone, Debug)] -pub struct PubIOTableConfig { +pub struct PubIOTableInitConfig { addr: Fixed, - - final_cycle: WitIn, - phantom: PhantomData, params: ProgramParams, } -impl PubIOTableConfig { +impl PubIOTableInitConfig { pub fn construct_circuit( cb: &mut CircuitBuilder, params: &ProgramParams, @@ -175,8 +185,6 @@ impl PubIOTableConfig { let init_v = cb.query_public_io()?; let addr = cb.create_fixed(|| "addr"); - let final_cycle = cb.create_witin(|| "final_cycle"); - let init_table = [ vec![(NVRAM::RAM_TYPE as usize).into()], vec![Expression::Fixed(addr)], @@ -185,15 +193,6 @@ impl PubIOTableConfig { ] .concat(); - let final_table = [ - // a v t - vec![(NVRAM::RAM_TYPE as usize).into()], - vec![Expression::Fixed(addr)], - 
init_v.iter().map(|v| v.expr_as_instance()).collect_vec(), - vec![final_cycle.expr()], - ] - .concat(); - cb.w_table_record( || "init_table", NVRAM::RAM_TYPE, @@ -203,19 +202,9 @@ impl PubIOTableConfig { }, init_table, )?; - cb.r_table_record( - || "final_table", - NVRAM::RAM_TYPE, - SetTableSpec { - len: Some(NVRAM::len(params)), - structural_witins: vec![], - }, - final_table, - )?; Ok(Self { addr, - final_cycle, phantom: PhantomData, params: params.clone(), }) @@ -248,181 +237,23 @@ impl PubIOTableConfig { /// TODO consider taking RowMajorMatrix as argument to save allocations. pub fn assign_instances( &self, - num_witin: usize, - num_structural_witin: usize, - final_cycles: &[Cycle], - ) -> Result<[RowMajorMatrix; 2], CircuitBuilderError> { - assert_eq!(num_structural_witin, 0); - let mut final_table = RowMajorMatrix::::new( - NVRAM::len(&self.params), - num_witin, - InstancePaddingStrategy::Default, - ); - - final_table - .par_rows_mut() - .zip_eq(final_cycles) - .for_each(|(row, &cycle)| { - set_val!(row, self.final_cycle, cycle); - }); - - Ok([final_table, RowMajorMatrix::empty()]) - } -} - -/// volatile with all init value as 0 -/// dynamic address as witin, relied on augment of knowledge to prove address form -#[derive(Clone, Debug)] -pub struct DynVolatileRamTableConfig { - addr: StructuralWitIn, - - final_v: Vec, - final_cycle: WitIn, - - phantom: PhantomData, - params: ProgramParams, -} - -impl DynVolatileRamTableConfigTrait - for DynVolatileRamTableConfig -{ - type Config = DynVolatileRamTableConfig; - fn construct_circuit( - cb: &mut CircuitBuilder, - params: &ProgramParams, - ) -> Result { - let max_len = DVRAM::max_len(params); - let addr = cb.create_structural_witin( - || "addr", - StructuralWitInType::EqualDistanceSequence { - max_len, - offset: DVRAM::offset_addr(params), - multi_factor: WORD_SIZE, - descending: DVRAM::DESCENDING, - }, - ); - - let final_v = (0..DVRAM::V_LIMBS) - .map(|i| cb.create_witin(|| format!("final_v_limb_{i}"))) - 
.collect::>(); - let final_cycle = cb.create_witin(|| "final_cycle"); - - let final_expr = final_v.iter().map(|v| v.expr()).collect_vec(); - let init_expr = if DVRAM::ZERO_INIT { - vec![Expression::ZERO; DVRAM::V_LIMBS] - } else { - final_expr.clone() - }; - - let init_table = [ - vec![(DVRAM::RAM_TYPE as usize).into()], - vec![addr.expr()], - init_expr, - vec![Expression::ZERO], // Initial cycle. - ] - .concat(); - - let final_table = [ - // a v t - vec![(DVRAM::RAM_TYPE as usize).into()], - vec![addr.expr()], - final_expr, - vec![final_cycle.expr()], - ] - .concat(); - - cb.w_table_record( - || "init_table", - DVRAM::RAM_TYPE, - SetTableSpec { - len: None, - structural_witins: vec![addr], - }, - init_table, - )?; - cb.r_table_record( - || "final_table", - DVRAM::RAM_TYPE, - SetTableSpec { - len: None, - structural_witins: vec![addr], - }, - final_table, - )?; - - Ok(Self { - addr, - final_v, - final_cycle, - phantom: PhantomData, - params: params.clone(), - }) - } - - /// TODO consider taking RowMajorMatrix as argument to save allocations. 
- fn assign_instances( - config: &Self::Config, - num_witin: usize, + _num_witin: usize, num_structural_witin: usize, final_mem: &[MemFinalRecord], ) -> Result<[RowMajorMatrix; 2], CircuitBuilderError> { if final_mem.is_empty() { return Ok([RowMajorMatrix::empty(), RowMajorMatrix::empty()]); } - - let num_instances_padded = next_pow2_instance_padding(final_mem.len()); - assert!(num_instances_padded <= DVRAM::max_len(&config.params)); - assert!(DVRAM::max_len(&config.params).is_power_of_two()); - - let mut witness = RowMajorMatrix::::new( - num_instances_padded, - num_witin, - InstancePaddingStrategy::Default, - ); - let mut structural_witness = RowMajorMatrix::::new( - num_instances_padded, - num_structural_witin, - InstancePaddingStrategy::Default, + assert!(num_structural_witin == 0 || num_structural_witin == 1); + let mut value = Vec::with_capacity(NVRAM::len(&self.params)); + value.par_extend( + (0..NVRAM::len(&self.params)) + .into_par_iter() + .map(|_| F::ONE), ); - - witness - .par_rows_mut() - .zip_eq(structural_witness.par_rows_mut()) - .enumerate() - .for_each(|(i, (row, structural_row))| { - if cfg!(debug_assertions) - && let Some(addr) = final_mem.get(i).map(|rec| rec.addr) - { - debug_assert_eq!( - addr, - DVRAM::addr(&config.params, i), - "rec.addr {:x} != expected {:x}", - addr, - DVRAM::addr(&config.params, i), - ); - } - - if let Some(rec) = final_mem.get(i) { - if config.final_v.len() == 1 { - // Assign value directly. - set_val!(row, config.final_v[0], rec.value as u64); - } else { - // Assign value limbs. 
- config.final_v.iter().enumerate().for_each(|(l, limb)| { - let val = (rec.value >> (l * LIMB_BITS)) & LIMB_MASK; - set_val!(row, limb, val as u64); - }); - } - set_val!(row, config.final_cycle, rec.cycle); - } - set_val!( - structural_row, - config.addr, - DVRAM::addr(&config.params, i) as u64 - ); - }); - - Ok([witness, structural_witness]) + let structural_witness = + RowMajorMatrix::::new_by_values(value, 1, InstancePaddingStrategy::Default); + Ok([RowMajorMatrix::empty(), structural_witness]) } } @@ -432,6 +263,8 @@ impl DynVolatileRamTableConfig pub struct DynVolatileRamTableInitConfig { addr: StructuralWitIn, + init_v: Option>, + phantom: PhantomData, params: ProgramParams, } @@ -445,6 +278,7 @@ impl DynVolatileRamTableConfig cb: &mut CircuitBuilder, params: &ProgramParams, ) -> Result { + cb.set_omc_init_only(); let max_len = DVRAM::max_len(params); let addr = cb.create_structural_witin( || "addr", @@ -456,9 +290,14 @@ impl DynVolatileRamTableConfig }, ); - assert!(DVRAM::ZERO_INIT); - - let init_expr = vec![Expression::ZERO; DVRAM::V_LIMBS]; + let (init_expr, init_v) = if DVRAM::ZERO_INIT { + (vec![Expression::ZERO; DVRAM::V_LIMBS], None) + } else { + let init_v = (0..DVRAM::V_LIMBS) + .map(|i| cb.create_witin(|| format!("init_v_limb_{i}"))) + .collect::>(); + (init_v.iter().map(|v| v.expr()).collect_vec(), Some(init_v)) + }; let init_table = [ vec![(DVRAM::RAM_TYPE as usize).into()], @@ -467,7 +306,6 @@ impl DynVolatileRamTableConfig vec![Expression::ZERO], // Initial cycle. ] .concat(); - cb.w_table_record( || "init_table", DVRAM::RAM_TYPE, @@ -480,6 +318,7 @@ impl DynVolatileRamTableConfig Ok(Self { addr, + init_v, phantom: PhantomData, params: params.clone(), }) @@ -488,47 +327,100 @@ impl DynVolatileRamTableConfig /// TODO consider taking RowMajorMatrix as argument to save allocations. 
fn assign_instances( config: &Self::Config, - _num_witin: usize, + num_witin: usize, num_structural_witin: usize, final_mem: &[MemFinalRecord], ) -> Result<[RowMajorMatrix; 2], CircuitBuilderError> { if final_mem.is_empty() { return Ok([RowMajorMatrix::empty(), RowMajorMatrix::empty()]); } + assert_eq!(num_structural_witin, 2); let num_instances_padded = next_pow2_instance_padding(final_mem.len()); assert!(num_instances_padded <= DVRAM::max_len(&config.params)); assert!(DVRAM::max_len(&config.params).is_power_of_two()); - let mut structural_witness = RowMajorMatrix::::new( - num_instances_padded, - num_structural_witin, - InstancePaddingStrategy::Default, - ); + // got some duplicated code segment to simplify parallel assignment flow + if let Some(init_v) = config.init_v.as_ref() { + let mut witness = RowMajorMatrix::::new( + num_instances_padded, + num_witin, + InstancePaddingStrategy::Default, + ); + let mut structural_witness = RowMajorMatrix::::new( + num_instances_padded, + num_structural_witin, + InstancePaddingStrategy::Default, + ); - structural_witness - .par_rows_mut() - .enumerate() - .for_each(|(i, structural_row)| { - if cfg!(debug_assertions) - && let Some(addr) = final_mem.get(i).map(|rec| rec.addr) - { - debug_assert_eq!( - addr, - DVRAM::addr(&config.params, i), - "rec.addr {:x} != expected {:x}", - addr, - DVRAM::addr(&config.params, i), + witness + .par_rows_mut() + .zip_eq(structural_witness.par_rows_mut()) + .enumerate() + .for_each(|(i, (row, structural_row))| { + if cfg!(debug_assertions) + && let Some(addr) = final_mem.get(i).map(|rec| rec.addr) + { + debug_assert_eq!( + addr, + DVRAM::addr(&config.params, i), + "rec.addr {:x} != expected {:x}", + addr, + DVRAM::addr(&config.params, i), + ); + } + if let Some(rec) = final_mem.get(i) { + if init_v.len() == 1 { + // Assign value directly. + set_val!(row, init_v[0], rec.init_value as u64); + } else { + // Assign value limbs. 
+ init_v.iter().enumerate().for_each(|(l, limb)| { + let val = (rec.init_value >> (l * LIMB_BITS)) & LIMB_MASK; + set_val!(row, limb, val as u64); + }); + } + } + set_val!( + structural_row, + config.addr, + DVRAM::addr(&config.params, i) as u64 ); - } - set_val!( - structural_row, - config.addr, - DVRAM::addr(&config.params, i) as u64 - ); - }); + *structural_row.last_mut().unwrap() = F::ONE; + }); - Ok([RowMajorMatrix::empty(), structural_witness]) + Ok([witness, structural_witness]) + } else { + let mut structural_witness = RowMajorMatrix::::new( + num_instances_padded, + num_structural_witin, + InstancePaddingStrategy::Default, + ); + + structural_witness + .par_rows_mut() + .enumerate() + .for_each(|(i, structural_row)| { + if cfg!(debug_assertions) + && let Some(addr) = final_mem.get(i).map(|rec| rec.addr) + { + debug_assert_eq!( + addr, + DVRAM::addr(&config.params, i), + "rec.addr {:x} != expected {:x}", + addr, + DVRAM::addr(&config.params, i), + ); + } + set_val!( + structural_row, + config.addr, + DVRAM::addr(&config.params, i) as u64 + ); + *structural_row.last_mut().unwrap() = F::ONE; + }); + Ok([RowMajorMatrix::empty(), structural_witness]) + } } } @@ -595,7 +487,6 @@ impl LocalFinalRAMTableConfig { ) -> Result<[RowMajorMatrix; 2], CircuitBuilderError> { assert!(num_structural_witin == 0 || num_structural_witin == 1); let num_structural_witin = num_structural_witin.max(1); - let selector_witin = WitIn { id: 0 }; let is_current_shard_mem_record = |record: &&MemFinalRecord| -> bool { (shard_ctx.is_first_shard() && record.cycle == 0) @@ -667,6 +558,8 @@ impl LocalFinalRAMTableConfig { structural_witness_value_rest = structural_witness_r; } + let current_shard_offset_cycle = shard_ctx.current_shard_offset_cycle(); + witness_mut_slices .par_iter_mut() .zip_eq(structural_witness_mut_slices.par_iter_mut()) @@ -692,11 +585,12 @@ impl LocalFinalRAMTableConfig { set_val!(row, limb, val as u64); }); } - set_val!(row, self.final_cycle, rec.cycle); + let 
shard_cycle = rec.cycle - current_shard_offset_cycle; + set_val!(row, self.final_cycle, shard_cycle); set_val!(row, self.ram_type, rec.ram_type as u64); set_val!(row, self.addr_subset, rec.addr as u64); - set_val!(structural_row, selector_witin, 1u64); + *structural_row.last_mut().unwrap() = F::ONE; }) .count(); @@ -723,7 +617,7 @@ impl LocalFinalRAMTableConfig { pad_func(pad_index as u64, self.addr_subset.id as u64) ); set_val!(row, self.ram_type, *ram_type as u64); - set_val!(structural_row, selector_witin, 1u64); + *structural_row.last_mut().unwrap() = F::ONE; }); } _ => unimplemented!(), @@ -1030,7 +924,7 @@ mod tests { use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, structs::ProgramParams, - tables::{DynVolatileRamTable, HintsCircuit, HintsTable, MemFinalRecord, TableCircuit}, + tables::{DynVolatileRamTable, HintsInitCircuit, HintsTable, MemFinalRecord, TableCircuit}, witness::LkMultiplicity, }; @@ -1046,7 +940,8 @@ mod tests { fn test_well_formed_address_padding() { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); - let config = HintsCircuit::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let (config, _) = + HintsInitCircuit::build_gkr_iop_circuit(&mut cb, &ProgramParams::default()).unwrap(); let def_params = ProgramParams::default(); let lkm = LkMultiplicity::default().into_finalize_result(); @@ -1059,9 +954,10 @@ mod tests { addr: HintsTable::addr(&def_params, i), cycle: 0, value: 0, + init_value: 0, }) .collect_vec(); - let [_, mut structural_witness] = HintsCircuit::::assign_instances( + let [_, mut structural_witness] = HintsInitCircuit::::assign_instances( &config, cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, @@ -1074,7 +970,7 @@ mod tests { .cs .structural_witin_namespace_map .iter() - .position(|name| name == "riscv/RAM_Memory_HintsTable/addr") + .position(|name| name == "riscv/HintsTable_Memory_RAM/addr") .unwrap(); 
structural_witness.padding_by_strategy(); diff --git a/ceno_zkvm/src/tables/range.rs b/ceno_zkvm/src/tables/range.rs index 06b56774b..3105b897a 100644 --- a/ceno_zkvm/src/tables/range.rs +++ b/ceno_zkvm/src/tables/range.rs @@ -3,9 +3,7 @@ mod range_impl; mod range_circuit; -pub use range_circuit::{ - DoubleRangeTableCircuit, DynamicRangeTableCircuit, RangeTable, RangeTableCircuit, -}; +pub use range_circuit::{DoubleRangeTableCircuit, DynamicRangeTableCircuit, RangeTable}; use crate::ROMType; diff --git a/ceno_zkvm/src/tables/range/range_circuit.rs b/ceno_zkvm/src/tables/range/range_circuit.rs index 95814b97d..67ea2da87 100644 --- a/ceno_zkvm/src/tables/range/range_circuit.rs +++ b/ceno_zkvm/src/tables/range/range_circuit.rs @@ -1,7 +1,5 @@ //! Range tables as circuits with trait TableCircuit. -use super::range_impl::RangeTableConfig; - use std::{collections::HashMap, marker::PhantomData}; use crate::{ @@ -14,7 +12,14 @@ use crate::{ }, }; use ff_ext::ExtensionField; -use gkr_iop::tables::LookupTable; +use gkr_iop::{ + chip::Chip, + gkr::{GKRCircuit, layer::Layer}, + selector::SelectorType, + tables::LookupTable, +}; +use itertools::Itertools; +use multilinear_extensions::ToExpr; use witness::{InstancePaddingStrategy, RowMajorMatrix}; /// Use this trait as parameter to RangeTableCircuit. @@ -28,54 +33,6 @@ pub trait RangeTable { } } -pub struct RangeTableCircuit(PhantomData<(E, R)>); - -impl TableCircuit for RangeTableCircuit { - type TableConfig = RangeTableConfig; - type FixedInput = (); - type WitnessInput = (); - - fn name() -> String { - format!("RANGE_{:?}", RANGE::ROM_TYPE) - } - - fn construct_circuit( - cb: &mut CircuitBuilder, - _params: &ProgramParams, - ) -> Result { - Ok(cb.namespace( - || Self::name(), - |cb| RangeTableConfig::construct_circuit(cb, RANGE::ROM_TYPE, RANGE::len()), - )?) 
- } - - fn generate_fixed_traces( - _config: &RangeTableConfig, - _num_fixed: usize, - _input: &(), - ) -> RowMajorMatrix { - RowMajorMatrix::::new(0, 0, InstancePaddingStrategy::Default) - } - - fn assign_instances( - config: &Self::TableConfig, - num_witin: usize, - num_structural_witin: usize, - multiplicity: &[HashMap], - _input: &(), - ) -> Result, ZKVMError> { - let multiplicity = &multiplicity[RANGE::ROM_TYPE as usize]; - - Ok(config.assign_instances( - num_witin, - num_structural_witin, - multiplicity, - RANGE::content(), - RANGE::len(), - )?) - } -} - pub struct DynamicRangeTableCircuit(PhantomData); impl TableCircuit @@ -146,6 +103,40 @@ impl, + param: &ProgramParams, + ) -> Result<(Self::TableConfig, Option>), ZKVMError> { + let config = Self::construct_circuit(cb, param)?; + let lk_table_len = cb.cs.lk_table_expressions.len() * 2; + + let selector = cb.create_placeholder_structural_witin(|| "selector"); + let selector_type = SelectorType::Whole(selector.expr()); + + // all shared the same selector + let (out_evals, mut chip) = ( + [ + // r_record + vec![], + // w_record + vec![], + // lk_record + (0..lk_table_len).collect_vec(), + // zero_record + vec![], + ], + Chip::new_from_cb(cb, 0), + ); + + // register selector to legacy constrain system + cb.cs.lk_selector = Some(selector_type.clone()); + + let layer = Layer::from_circuit_builder(cb, Self::name(), 0, out_evals); + chip.add_layer(layer); + + Ok((config, Some(chip.gkr_circuit()))) + } + fn generate_fixed_traces( _config: &DoubleRangeTableConfig, _num_fixed: usize, diff --git a/ceno_zkvm/src/tables/range/range_impl.rs b/ceno_zkvm/src/tables/range/range_impl.rs index 535f50000..a95664085 100644 --- a/ceno_zkvm/src/tables/range/range_impl.rs +++ b/ceno_zkvm/src/tables/range/range_impl.rs @@ -12,80 +12,6 @@ use crate::{ }; use multilinear_extensions::{StructuralWitIn, StructuralWitInType, ToExpr, WitIn}; -#[derive(Clone, Debug)] -pub struct RangeTableConfig { - range: StructuralWitIn, - mlt: WitIn, -} 
- -impl RangeTableConfig { - pub fn construct_circuit( - cb: &mut CircuitBuilder, - rom_type: ROMType, - table_len: usize, - ) -> Result { - let range = cb.create_structural_witin( - || "structural range witin", - StructuralWitInType::EqualDistanceSequence { - max_len: table_len, - offset: 0, - multi_factor: 1, - descending: false, - }, - ); - let mlt = cb.create_witin(|| "mlt"); - - let record_exprs = vec![range.expr()]; - - cb.lk_table_record( - || "record", - SetTableSpec { - len: Some(table_len), - structural_witins: vec![range], - }, - rom_type, - record_exprs, - mlt.expr(), - )?; - - Ok(Self { range, mlt }) - } - - pub fn assign_instances( - &self, - num_witin: usize, - num_structural_witin: usize, - multiplicity: &HashMap, - content: Vec, - length: usize, - ) -> Result<[RowMajorMatrix; 2], CircuitBuilderError> { - let mut witness: RowMajorMatrix = - RowMajorMatrix::::new(length, num_witin, InstancePaddingStrategy::Default); - let mut structural_witness = RowMajorMatrix::::new( - length, - num_structural_witin, - InstancePaddingStrategy::Default, - ); - - let mut mlts = vec![0; length]; - for (idx, mlt) in multiplicity { - mlts[*idx as usize] = *mlt; - } - - witness - .par_rows_mut() - .zip(structural_witness.par_rows_mut()) - .zip(mlts) - .zip(content) - .for_each(|(((row, structural_row), mlt), i)| { - set_val!(row, self.mlt, F::from_canonical_u64(mlt as u64)); - set_val!(structural_row, self.range, F::from_canonical_u64(i)); - }); - - Ok([witness, structural_witness]) - } -} - #[derive(Clone, Debug)] pub struct DynamicRangeTableConfig { range: StructuralWitIn, @@ -167,6 +93,7 @@ impl DynamicRangeTableConfig { set_val!(row, self.mlt, F::from_canonical_u64(*mlt as u64)); set_val!(structural_row, self.range, i); set_val!(structural_row, self.bits, b); + *structural_row.last_mut().unwrap() = F::ONE; }); Ok([witness, structural_witness]) @@ -257,6 +184,7 @@ impl DoubleRangeTableConfig { set_val!(row, self.mlt, F::from_canonical_u64(*mlt as u64)); 
set_val!(structural_row, self.range_a, F::from_canonical_usize(a)); set_val!(structural_row, self.range_b, F::from_canonical_usize(b)); + *structural_row.last_mut().unwrap() = F::ONE; }); Ok([witness, structural_witness]) diff --git a/ceno_zkvm/src/utils.rs b/ceno_zkvm/src/utils.rs index 96b4eb969..276622cf7 100644 --- a/ceno_zkvm/src/utils.rs +++ b/ceno_zkvm/src/utils.rs @@ -81,107 +81,6 @@ pub fn u64vec(x: u64) -> [u64; W] { ret } -/// evaluate MLE M(x0, x1, x2, ..., xn) address vector with it evaluation format -/// on r = [r0, r1, r2, ...rn] succinctly -/// where `M = descending * scaled * M' + offset` -/// offset, scaled, is constant, descending = +1/-1 -/// and M' := [0, 1, 2, 3, ....2^n-1] -/// succinctly format of M'(r) = r0 + r1 * 2 + r2 * 2^2 + .... rn * 2^n -pub fn eval_wellform_address_vec( - offset: u64, - scaled: u64, - r: &[E], - descending: bool, -) -> E { - let (offset, scaled) = (E::from_canonical_u64(offset), E::from_canonical_u64(scaled)); - let tmp = scaled - * r.iter() - .scan(E::ONE, |state, x| { - let result = *x * *state; - *state *= E::from_canonical_u64(2); // Update the state for the next power of 2 - Some(result) - }) - .sum::(); - let tmp = if descending { tmp.neg() } else { tmp }; - offset + tmp -} - -/// Evaluate MLE with the following evaluation over the hypercube: -/// [0, 0, 0, 1, 0, 1, 2, 3, 0, 1, 2, 3, 4, 5, 6, 7, ..., 0, 1, 2, ..., 2^n-1] -/// which is the concatenation of -/// [0] -/// [0, 1] -/// [0, 1, 2, 3] -/// ... -/// [0, 1, 2, ..., 2^n-1] -/// which is then prefixed by a single zero to make all the subvectors aligned to powers of two. -/// This function is used to support dynamic range check. -/// Note that this MLE has n+1 variables, so r should have length n+1. -/// -/// conceptually, we traverse evaluations in the sequence: -/// [0, 0], [0, 1], [0, 1, 2, 3], ... -/// for every `next` element is already in a well-formed incremental structure, -/// so we can reuse `eval_wellform_address_vec` to obtain its value. 
-/// -/// at each step `i`, we combine: -/// - the accumulated result so far, weighted by `(1 - r[i])` -/// - the evaluation of the current prefix `r[..i]`, weighted by `r[i]`. -/// -/// this iterative version avoids recursion for efficiency and clarity. -pub fn eval_stacked_wellform_address_vec(r: &[E]) -> E { - if r.len() < 2 { - return E::ZERO; - } - - let mut res = E::ZERO; - for i in 1..r.len() { - res = res * (E::ONE - r[i]) + eval_wellform_address_vec(0, 1, &r[..i], false) * r[i]; - } - res -} - -/// Evaluate MLE with the following evaluation over the hypercube: -/// [0, 0, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, ..., n, n, n, ..., n] -/// which is the concatenation of -/// [0] -/// [1, 1] -/// [2, 2, 2, 2] -/// ... -/// [n, n, n, ..., n] -/// which is then prefixed by a single zero to make all the subvectors aligned to powers of two. -/// This function is used to support dynamic range check. -/// Note that this MLE has n+1 variables, so r should have length n+1. -pub fn eval_stacked_constant_vec(r: &[E]) -> E { - if r.len() < 2 { - return E::ZERO; - } - - let mut res = E::ZERO; - for (i, r) in r.iter().enumerate().skip(1) { - res = res * (E::ONE - *r) + E::from_canonical_usize(i) * *r; - } - res -} - -/// evaluate MLE M(x0, x1, x2, ..., xn) address vector with it evaluation format -/// on r = [r0, r1, r2, ...rn] succinctly -/// where `M = [0 ... 0 1 ... 1 ... 2^(n-k)-1 ... 2^(n-k)-1]` -/// where each element is repeated 2^k times -/// The value is the same as M(xk, xk+1, ..., xn), i.e., just abandoning -/// the first k elements from r -pub fn eval_inner_repeated_incremental_vec(k: u64, r: &[E]) -> E { - eval_wellform_address_vec(0, 1, &r[k as usize..], false) -} - -/// evaluate MLE M(x0, x1, x2, ..., xn) address vector with it evaluation format -/// on r = [r0, r1, r2, ...rn] succinctly -/// where `M = [0 1 ... 
2^k-1] * 2^(n-k)` -/// The value is the same as M(x0, ..., xk), i.e., just taking -/// the first k elements from r -pub fn eval_outer_repeated_incremental_vec(k: u64, r: &[E]) -> E { - eval_wellform_address_vec(0, 1, &r[..k as usize], false) -} - pub fn display_hashmap(map: &HashMap) -> String { format!( "[{}]", @@ -265,104 +164,3 @@ pub fn print_allocated_bytes() { let allocated = stats::allocated::read().unwrap(); tracing::info!("jemalloc total allocated bytes: {}", allocated); } - -#[cfg(test)] -mod tests { - use std::iter; - - use ff_ext::GoldilocksExt2; - use p3::field::FieldAlgebra; - - use super::*; - - type E = GoldilocksExt2; - use multilinear_extensions::mle::MultilinearExtension; - - #[test] - fn test_eval_stacked_wellform_address_vec() { - let r = [ - E::from_canonical_usize(123), - E::from_canonical_usize(456), - E::from_canonical_usize(789), - E::from_canonical_usize(3210), - E::from_canonical_usize(9876), - ]; - for n in 0..r.len() { - let v = iter::once(E::ZERO) - .chain((0..=n).flat_map(|i| (0..(1 << i)).map(E::from_canonical_usize))) - .collect::>(); - let poly = MultilinearExtension::from_evaluations_ext_vec(n + 1, v); - assert_eq!( - eval_stacked_wellform_address_vec(&r[0..=n]), - poly.evaluate(&r[0..=n]) - ) - } - } - - #[test] - fn test_eval_stacked_constant_vec() { - let r = [ - E::from_canonical_usize(123), - E::from_canonical_usize(456), - E::from_canonical_usize(789), - E::from_canonical_usize(3210), - E::from_canonical_usize(9876), - ]; - for n in 0..r.len() { - let v = iter::once(E::ZERO) - .chain((0..=n).flat_map(|i| iter::repeat_n(i, 1 << i).map(E::from_canonical_usize))) - .collect::>(); - let poly = MultilinearExtension::from_evaluations_ext_vec(n + 1, v); - assert_eq!( - eval_stacked_constant_vec(&r[0..=n]), - poly.evaluate(&r[0..=n]) - ) - } - } - - #[test] - fn test_eval_inner_repeating_incremental_vec() { - let r = [ - E::from_canonical_usize(123), - E::from_canonical_usize(456), - E::from_canonical_usize(789), - 
E::from_canonical_usize(3210), - E::from_canonical_usize(9876), - ]; - for n in 1..=r.len() { - for k in 0..=n { - let v = (0..(1 << (n - k))) - .flat_map(|i| iter::repeat_n(E::from_canonical_usize(i), 1 << k)) - .collect::>(); - let poly = MultilinearExtension::from_evaluations_ext_vec(n, v); - assert_eq!( - eval_inner_repeated_incremental_vec(k as u64, &r[0..n]), - poly.evaluate(&r[0..n]) - ) - } - } - } - - #[test] - fn test_eval_outer_repeating_incremental_vec() { - let r = [ - E::from_canonical_usize(123), - E::from_canonical_usize(456), - E::from_canonical_usize(789), - E::from_canonical_usize(3210), - E::from_canonical_usize(9876), - ]; - for n in 1..=r.len() { - for k in 0..=n { - let v = iter::repeat_n(0, 1 << (n - k)) - .flat_map(|_| (0..(1 << k)).map(E::from_canonical_usize)) - .collect::>(); - let poly = MultilinearExtension::from_evaluations_ext_vec(n, v); - assert_eq!( - eval_outer_repeated_incremental_vec(k as u64, &r[0..n]), - poly.evaluate(&r[0..n]) - ) - } - } - } -} diff --git a/gkr_iop/src/chip.rs b/gkr_iop/src/chip.rs index 10048418e..2b72f251e 100644 --- a/gkr_iop/src/chip.rs +++ b/gkr_iop/src/chip.rs @@ -44,7 +44,8 @@ impl Chip { + cb.cs.r_table_expressions.len() + cb.cs.lk_table_expressions.len() * 2 + cb.cs.num_fixed - + cb.cs.num_witin as usize, + + cb.cs.num_witin as usize + + cb.cs.instance_openings.len(), final_out_evals: (0..cb.cs.w_expressions.len() + cb.cs.r_expressions.len() + cb.cs.lk_expressions.len() diff --git a/gkr_iop/src/circuit_builder.rs b/gkr_iop/src/circuit_builder.rs index 70de7f171..088f1f42c 100644 --- a/gkr_iop/src/circuit_builder.rs +++ b/gkr_iop/src/circuit_builder.rs @@ -96,12 +96,13 @@ pub struct ConstraintSystem { pub witin_namespace_map: Vec, pub num_structural_witin: WitnessId, + pub structural_witins: Vec, pub structural_witin_namespace_map: Vec, pub num_fixed: usize, pub fixed_namespace_map: Vec, - pub instance_name_map: HashMap, + pub instance_openings: Vec, pub ec_point_exprs: Vec>, pub ec_slope_exprs: 
Vec>, @@ -126,6 +127,9 @@ pub struct ConstraintSystem { pub r_table_expressions_namespace_map: Vec, pub w_table_expressions: Vec>, pub w_table_expressions_namespace_map: Vec, + // specify whether constrains system cover only init_w + // as it imply w/r set and final_w might happen ACROSS shards + pub with_omc_init_only: bool, pub lk_selector: Option>, /// lookup expression @@ -166,11 +170,12 @@ impl ConstraintSystem { // platform, witin_namespace_map: vec![], num_structural_witin: 0, + structural_witins: vec![], structural_witin_namespace_map: vec![], num_fixed: 0, fixed_namespace_map: vec![], ns: NameSpace::new(root_name_fn), - instance_name_map: HashMap::new(), + instance_openings: vec![], ec_final_sum: vec![], ec_slope_exprs: vec![], ec_point_exprs: vec![], @@ -186,6 +191,7 @@ impl ConstraintSystem { r_table_expressions_namespace_map: vec![], w_table_expressions: vec![], w_table_expressions_namespace_map: vec![], + with_omc_init_only: false, lk_selector: None, lk_expressions: vec![], lk_table_expressions: vec![], @@ -227,6 +233,7 @@ impl ConstraintSystem { id: self.num_structural_witin, witin_type, }; + self.structural_witins.push(wit_in); self.num_structural_witin = self.num_structural_witin.strict_add(1); let path = self.ns.compute_path(n().into()); @@ -235,6 +242,13 @@ impl ConstraintSystem { wit_in } + pub fn create_placeholder_structural_witin, N: FnOnce() -> NR>( + &mut self, + n: N, + ) -> StructuralWitIn { + self.create_structural_witin(n, StructuralWitInType::Empty) + } + pub fn create_fixed, N: FnOnce() -> NR>(&mut self, n: N) -> Fixed { let f = Fixed(self.num_fixed); self.num_fixed += 1; @@ -245,17 +259,25 @@ impl ConstraintSystem { f } - pub fn query_instance, N: FnOnce() -> NR>( + pub fn query_instance(&self, idx: usize) -> Result { + let i = Instance(idx); + Ok(i) + } + + pub fn query_instance_for_openings( &mut self, - n: N, idx: usize, ) -> Result { let i = Instance(idx); - let name = n().into(); - self.instance_name_map.insert(i, name); + 
assert!( + !self.instance_openings.contains(&i), + "query same pubio idx {idx} mle more than once", + ); + self.instance_openings.push(i); - Ok(i) + // return instance only count + Ok(Instance(self.instance_openings.len() - 1)) } pub fn rlc_chip_record(&self, items: Vec>) -> Expression { @@ -524,6 +546,10 @@ impl ConstraintSystem { self.ns.pop_namespace(); t } + + pub fn set_omc_init_only(&mut self) { + self.with_omc_init_only = true; + } } impl ConstraintSystem { @@ -606,6 +632,14 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { self.cs.create_structural_witin(name_fn, witin_type) } + pub fn create_placeholder_structural_witin(&mut self, name_fn: N) -> StructuralWitIn + where + NR: Into, + N: FnOnce() -> NR, + { + self.cs.create_placeholder_structural_witin(name_fn) + } + pub fn create_fixed(&mut self, name_fn: N) -> Fixed where NR: Into, @@ -1279,6 +1313,10 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { pub fn rotate_and_assert_eq(&mut self, a: Expression, b: Expression) { self.cs.rotations.push((a, b)); } + + pub fn set_omc_init_only(&mut self) { + self.cs.set_omc_init_only(); + } } /// take items from an iterator until the accumulated "weight" (measured by `f`) diff --git a/gkr_iop/src/cpu/mod.rs b/gkr_iop/src/cpu/mod.rs index 9b2864150..65d736020 100644 --- a/gkr_iop/src/cpu/mod.rs +++ b/gkr_iop/src/cpu/mod.rs @@ -3,12 +3,13 @@ use crate::{ gkr::layer::Layer, hal::{MultilinearPolynomial, ProtocolWitnessGeneratorProver, ProverBackend, ProverDevice}, }; +use either::Either; use ff_ext::ExtensionField; use itertools::izip; use mpcs::{PolynomialCommitmentScheme, SecurityLevel, SecurityLevel::Conjecture100bits}; use multilinear_extensions::{ macros::{entered_span, exit_span}, - mle::{ArcMultilinearExtension, MultilinearExtension, Point}, + mle::{MultilinearExtension, Point}, wit_infer_by_monomial_expr, }; use p3::field::TwoAdicField; @@ -109,7 +110,7 @@ impl> fn layer_witness<'a>( layer: &Layer, layer_wits: &[Arc< as 
ProverBackend>::MultilinearPoly<'a>>], - pub_io_evals: &[Arc< as ProverBackend>::MultilinearPoly<'a>>], + pub_io_evals: &[Either], challenges: &[E], ) -> Vec as ProverBackend>::MultilinearPoly<'a>>> { let span = entered_span!("witness_infer", profiling_2 = true); @@ -148,49 +149,3 @@ impl> res } } - -#[tracing::instrument(skip_all, name = "layer_witness", fields(profiling_2), level = "trace")] -pub fn layer_witness<'a, E>( - layer: &Layer, - layer_wits: &[ArcMultilinearExtension<'a, E>], - pub_io_evals: &[ArcMultilinearExtension<'a, E>], - challenges: &[E], -) -> Vec> -where - E: ExtensionField, -{ - let span = entered_span!("witness_infer", profiling_2 = true); - let out_evals: Vec<_> = layer - .out_sel_and_eval_exprs - .iter() - .flat_map(|(sel_type, out_eval)| izip!(iter::repeat(sel_type), out_eval.iter())) - .collect(); - let res = layer - .exprs_with_selector_out_eval_monomial_form - .par_iter() - .zip_eq(layer.expr_names.par_iter()) - .zip_eq(out_evals.par_iter()) - .map(|((expr, expr_name), (_, out_eval))| { - if cfg!(debug_assertions) - && let EvalExpression::Zero = out_eval - { - assert!( - wit_infer_by_monomial_expr(expr, layer_wits, pub_io_evals, challenges) - .evaluations() - .is_zero(), - "layer name: {}, expr name: \"{expr_name}\" got non_zero mle", - layer.name - ); - }; - match out_eval { - EvalExpression::Linear(_, _, _) | EvalExpression::Single(_) => { - wit_infer_by_monomial_expr(expr, layer_wits, pub_io_evals, challenges) - } - EvalExpression::Zero => MultilinearExtension::default().into(), - EvalExpression::Partition(_, _) => unimplemented!(), - } - }) - .collect::>(); - exit_span!(span); - res -} diff --git a/gkr_iop/src/gkr.rs b/gkr_iop/src/gkr.rs index b06e8fe71..b025aa1e4 100644 --- a/gkr_iop/src/gkr.rs +++ b/gkr_iop/src/gkr.rs @@ -45,6 +45,7 @@ pub struct GKRCircuitOutput<'a, PB: ProverBackend>(pub LayerWitness<'a, PB>); pub struct GKRProverOutput { pub gkr_proof: GKRProof, pub opening_evaluations: Vec, + pub rt: Vec>, } #[derive(Clone, 
Serialize, Deserialize)] @@ -85,7 +86,7 @@ impl GKRCircuit { running_evals.resize(self.n_evaluations, PointAndEval::default()); let mut challenges = challenges.to_vec(); let span = entered_span!("layer_proof", profiling_2 = true); - let sumcheck_proofs = izip!(&self.layers, circuit_wit.layers) + let (sumcheck_proofs, rt): (Vec<_>, Vec<_>) = izip!(&self.layers, circuit_wit.layers) .enumerate() .map(|(i, (layer, layer_wit))| { tracing::debug!("prove layer {i} layer with layer name {}", layer.name); @@ -103,7 +104,7 @@ impl GKRCircuit { exit_span!(span); res }) - .collect_vec(); + .unzip(); exit_span!(span); let opening_evaluations = self.opening_evaluations(&running_evals); @@ -111,6 +112,7 @@ impl GKRCircuit { Ok(GKRProverOutput { gkr_proof: GKRProof(sumcheck_proofs), opening_evaluations, + rt, }) } @@ -121,10 +123,11 @@ impl GKRCircuit { gkr_proof: GKRProof, out_evals: &[PointAndEval], pub_io_evals: &[E], + raw_pi: &[Vec], challenges: &[E], transcript: &mut impl Transcript, selector_ctxs: &[SelectorContext], - ) -> Result>, BackendError> + ) -> Result<(GKRClaims>, Point), BackendError> where E: ExtensionField, { @@ -133,20 +136,24 @@ impl GKRCircuit { let mut challenges = challenges.to_vec(); let mut evaluations = out_evals.to_vec(); evaluations.resize(self.n_evaluations, PointAndEval::default()); - for (i, (layer, layer_proof)) in izip!(&self.layers, sumcheck_proofs).enumerate() { - tracing::debug!("verifier layer {i} layer with layer name {}", layer.name); - layer.verify( - max_num_variables, - layer_proof, - &mut evaluations, - pub_io_evals, - &mut challenges, - transcript, - selector_ctxs, - )?; - } - - Ok(GKRClaims(self.opening_evaluations(&evaluations))) + let rt = izip!(&self.layers, sumcheck_proofs).enumerate().try_fold( + vec![], + |_, (i, (layer, layer_proof))| { + tracing::debug!("verifier layer {i} layer with layer name {}", layer.name); + let rt = layer.verify( + max_num_variables, + layer_proof, + &mut evaluations, + pub_io_evals, + raw_pi, + &mut 
challenges, + transcript, + selector_ctxs, + )?; + Ok(rt) + }, + )?; + Ok((GKRClaims(self.opening_evaluations(&evaluations)), rt)) } /// Output opening evaluations. First witin and then fixed. diff --git a/gkr_iop/src/gkr/layer.rs b/gkr_iop/src/gkr/layer.rs index 22312497d..64eb747be 100644 --- a/gkr_iop/src/gkr/layer.rs +++ b/gkr_iop/src/gkr/layer.rs @@ -3,7 +3,7 @@ use ff_ext::ExtensionField; use itertools::{Itertools, chain, izip}; use linear_layer::{LayerClaims, LinearLayer}; use multilinear_extensions::{ - Expression, ToExpr, + Expression, Instance, StructuralWitIn, ToExpr, mle::{Point, PointAndEval}, monomial::Term, }; @@ -59,7 +59,12 @@ pub struct Layer { pub n_witin: usize, pub n_structural_witin: usize, pub n_fixed: usize, + pub n_instance: usize, pub max_expr_degree: usize, + /// keep all structural witin which could be evaluated succinctly without PCS + pub structural_witins: Vec, + /// instance openings + pub instance_openings: Vec, /// num challenges dedicated to this layer. pub n_challenges: usize, /// Expressions to prove in this layer. For zerocheck and linear layers, @@ -124,6 +129,7 @@ impl Layer { n_witin: usize, n_structural_witin: usize, n_fixed: usize, + n_instance: usize, // exprs concat zero/non-zero expression. 
exprs: Vec>, n_challenges: usize, @@ -136,6 +142,8 @@ impl Layer { usize, ), expr_names: Vec, + structural_witins: Vec, + instance_openings: Vec, ) -> Self { assert_eq!(expr_names.len(), exprs.len(), "there are expr without name"); let max_expr_degree = exprs @@ -152,7 +160,10 @@ impl Layer { n_witin, n_structural_witin, n_fixed, + n_instance, max_expr_degree, + structural_witins, + instance_openings, n_challenges, exprs, exprs_with_selector_out_eval_monomial_form: vec![], @@ -185,7 +196,7 @@ impl Layer { challenges: &mut Vec, transcript: &mut T, selector_ctxs: &[SelectorContext], - ) -> LayerProof { + ) -> (LayerProof, Point) { self.update_challenges(challenges, transcript); let mut eval_and_dedup_points = self.extract_claim_and_point(claims, challenges); @@ -220,7 +231,7 @@ impl Layer { self.update_claims(claims, &sumcheck_layer_proof.main.evals, &point); - sumcheck_layer_proof + (sumcheck_layer_proof, point) } #[allow(clippy::too_many_arguments)] @@ -230,10 +241,11 @@ impl Layer { proof: LayerProof, claims: &mut [PointAndEval], pub_io_evals: &[E], + raw_pi: &[Vec], challenges: &mut Vec, transcript: &mut Trans, selector_ctxs: &[SelectorContext], - ) -> Result<(), BackendError> { + ) -> Result, BackendError> { self.update_challenges(challenges, transcript); let mut eval_and_dedup_points = self.extract_claim_and_point(claims, challenges); @@ -244,6 +256,7 @@ impl Layer { proof, eval_and_dedup_points, pub_io_evals, + raw_pi, challenges, transcript, selector_ctxs, @@ -264,7 +277,7 @@ impl Layer { self.update_claims(claims, &evals, &in_point); - Ok(()) + Ok(in_point) } // extract claim and dudup point @@ -465,7 +478,7 @@ impl Layer { } = &cb.cs; let in_eval_expr = (non_zero_expr_len..) 
- .take(cb.cs.num_witin as usize + cb.cs.num_fixed) + .take(cb.cs.num_witin as usize + cb.cs.num_fixed + cb.cs.instance_openings.len()) .collect_vec(); if rotations.is_empty() { Layer::new( @@ -474,12 +487,15 @@ impl Layer { cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, cb.cs.num_fixed, + cb.cs.instance_openings.len(), expressions, n_challenges, in_eval_expr, expr_evals, ((None, vec![]), 0, 0), expr_names, + cb.cs.structural_witins.clone(), + cb.cs.instance_openings.clone(), ) } else { let Some(RotationParams { @@ -496,6 +512,7 @@ impl Layer { cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, cb.cs.num_fixed, + cb.cs.instance_openings.len(), expressions, n_challenges, in_eval_expr, @@ -506,6 +523,8 @@ impl Layer { *rotation_cyclic_subgroup_size, ), expr_names, + cb.cs.structural_witins.clone(), + cb.cs.instance_openings.clone(), ) } } diff --git a/gkr_iop/src/gkr/layer/cpu/mod.rs b/gkr_iop/src/gkr/layer/cpu/mod.rs index 882e224ee..00e472bf0 100644 --- a/gkr_iop/src/gkr/layer/cpu/mod.rs +++ b/gkr_iop/src/gkr/layer/cpu/mod.rs @@ -209,28 +209,40 @@ impl> ZerocheckLayerProver .collect::>(); exit_span!(span); - // `wit` := witin ++ fixed + // `wit` := witin ++ fixed ++ pubio // we concat eq in between `wit` := witin ++ eqs ++ fixed let all_witins = wit .iter() - .take(layer.n_witin) + .take(layer.n_witin + layer.n_fixed + layer.n_instance) .map(|mle| Either::Left(mle.as_ref())) - .chain(eqs.iter_mut().map(Either::Right)) .chain( - // fixed, start after `n_witin` + // some non-selector structural witin wit.iter() - .skip(layer.n_witin + layer.n_structural_witin) + .skip(layer.n_witin + layer.n_fixed + layer.n_instance) + .take( + layer.n_structural_witin + - layer.out_sel_and_eval_exprs.len() + - layer + .rotation_exprs + .0 + .as_ref() + .map(|_| ROTATION_OPENING_COUNT) + .unwrap_or(0), + ) .map(|mle| Either::Left(mle.as_ref())), ) + .chain(eqs.iter_mut().map(Either::Right)) .collect_vec(); + assert_eq!( all_witins.len(), - layer.n_witin + 
layer.n_structural_witin + layer.n_fixed, - "all_witins.len() {} != layer.n_witin {} + layer.n_structural_witin {} + layer.n_fixed {}", + layer.n_witin + layer.n_structural_witin + layer.n_fixed + layer.n_instance, + "all_witins.len() {} != layer.n_witin {} + layer.n_structural_witin {} + layer.n_fixed {} + layer.n_instance {}", all_witins.len(), layer.n_witin, layer.n_structural_witin, layer.n_fixed, + layer.n_instance, ); let builder = diff --git a/gkr_iop/src/gkr/layer/gpu/mod.rs b/gkr_iop/src/gkr/layer/gpu/mod.rs index d9380c511..67ad38d71 100644 --- a/gkr_iop/src/gkr/layer/gpu/mod.rs +++ b/gkr_iop/src/gkr/layer/gpu/mod.rs @@ -197,18 +197,39 @@ impl> ZerocheckLayerProver .map(|rotation_point| build_eq_x_r_gpu(&cuda_hal, rotation_point)), ) .collect::>(); + // `wit` := witin ++ fixed ++ pubio let all_witins_gpu = wit .iter() - .take(layer.n_witin) + .take(layer.n_witin + layer.n_fixed + layer.n_instance) .map(|mle| mle.as_ref()) - .chain(eqs_gpu.iter()) .chain( - // fixed, start after `n_witin` + // some non-selector structural witin wit.iter() - .skip(layer.n_witin + layer.n_structural_witin) + .skip(layer.n_witin + layer.n_fixed + layer.n_instance) + .take( + layer.n_structural_witin + - layer.out_sel_and_eval_exprs.len() + - layer + .rotation_exprs + .0 + .as_ref() + .map(|_| ROTATION_OPENING_COUNT) + .unwrap_or(0), + ) .map(|mle| mle.as_ref()), ) + .chain(eqs_gpu.iter()) .collect_vec(); + assert_eq!( + all_witins_gpu.len(), + layer.n_witin + layer.n_structural_witin + layer.n_fixed + layer.n_instance, + "all_witins.len() {} != layer.n_witin {} + layer.n_structural_witin {} + layer.n_fixed {} + layer.n_instance {}", + all_witins_gpu.len(), + layer.n_witin, + layer.n_structural_witin, + layer.n_fixed, + layer.n_instance, + ); // Calculate max_num_var and max_degree from the extracted relationships let (term_coefficients, mle_indices_per_term, mle_size_info) = extract_mle_relationships_from_monomial_terms( diff --git a/gkr_iop/src/gkr/layer/zerocheck_layer.rs 
b/gkr_iop/src/gkr/layer/zerocheck_layer.rs index d9f13a2a9..95069b5e9 100644 --- a/gkr_iop/src/gkr/layer/zerocheck_layer.rs +++ b/gkr_iop/src/gkr/layer/zerocheck_layer.rs @@ -1,11 +1,11 @@ use ff_ext::ExtensionField; use itertools::{Itertools, chain, izip}; use multilinear_extensions::{ - ChallengeId, Expression, ToExpr, WitnessId, + ChallengeId, Expression, StructuralWitIn, StructuralWitInType, ToExpr, WitnessId, macros::{entered_span, exit_span}, - mle::Point, + mle::{IntoMLE, Point}, monomialize_expr_to_wit_terms, - utils::{eval_by_expr, eval_by_expr_with_instance}, + utils::{eval_by_expr, eval_by_expr_with_instance, expr_convert_to_witins}, virtual_poly::VPAuxInfo, }; use p3::field::{FieldAlgebra, dot_product}; @@ -28,7 +28,11 @@ use crate::{ }, hal::{ProverBackend, ProverDevice}, selector::{SelectorContext, SelectorType}, - utils::rotation_selector_eval, + utils::{ + eval_inner_repeated_incremental_vec, eval_outer_repeated_incremental_vec, + eval_stacked_constant_vec, eval_stacked_wellform_address_vec, eval_wellform_address_vec, + rotation_selector_eval, + }, }; pub(crate) struct RotationPoints { @@ -68,6 +72,7 @@ pub trait ZerocheckLayer { proof: LayerProof, eval_and_dedup_points: Vec<(Vec, Option>)>, pub_io_evals: &[E], + raw_pi: &[Vec], challenges: &[E], transcript: &mut impl Transcript, selector_ctxs: &[SelectorContext], @@ -128,8 +133,8 @@ impl ZerocheckLayer for Layer { monomialize_expr_to_wit_terms( &expr, self.n_witin as WitnessId, - self.n_structural_witin as WitnessId, self.n_fixed as WitnessId, + self.n_instance, ) }) .collect::>(); @@ -139,10 +144,9 @@ impl ZerocheckLayer for Layer { .take(self.exprs.len() + num_rotations * ROTATION_OPENING_COUNT) .map(|id| Expression::Challenge(id as ChallengeId, 1, E::ONE, E::ZERO)) .collect_vec(); - let zero_expr = - extend_exprs_with_rotation(self, &alpha_pows_expr, self.n_witin as WitnessId) - .into_iter() - .sum::>(); + let mut zero_expr = extend_exprs_with_rotation(self, &alpha_pows_expr) + .into_iter() + 
.sum::>(); self.rotation_sumcheck_expression = rotation_expr.clone(); self.rotation_sumcheck_expression_monomial_terms = @@ -150,21 +154,22 @@ impl ZerocheckLayer for Layer { monomialize_expr_to_wit_terms( expr, self.n_witin as WitnessId, - self.n_structural_witin as WitnessId, self.n_fixed as WitnessId, + self.n_instance, ) }); + expr_convert_to_witins( + &mut zero_expr, + self.n_witin as WitnessId, + self.n_fixed as WitnessId, + self.n_instance, + ); self.main_sumcheck_expression = Some(zero_expr); - self.main_sumcheck_expression_monomial_terms = - self.main_sumcheck_expression.as_ref().map(|expr| { - monomialize_expr_to_wit_terms( - expr, - self.n_witin as WitnessId, - self.n_structural_witin as WitnessId, - self.n_fixed as WitnessId, - ) - }); + self.main_sumcheck_expression_monomial_terms = self + .main_sumcheck_expression + .as_ref() + .map(|expr| expr.get_monomial_terms()); exit_span!(span); } @@ -198,6 +203,7 @@ impl ZerocheckLayer for Layer { proof: LayerProof, mut eval_and_dedup_points: Vec<(Vec, Option>)>, pub_io_evals: &[E], + raw_pi: &[Vec], challenges: &[E], transcript: &mut impl Transcript, selector_ctxs: &[SelectorContext], @@ -213,11 +219,17 @@ impl ZerocheckLayer for Layer { main: SumcheckLayerProof { proof: IOPProof { proofs }, - evals: mut main_evals, + evals: main_evals, }, rotation: rotation_proof, } = proof; + assert_eq!( + main_evals.len(), + self.n_witin + self.n_fixed + self.n_instance + self.n_structural_witin, + "invalid main_evals length", + ); + if let Some(rotation_proof) = rotation_proof { // verify rotation proof let rt = eval_and_dedup_points @@ -283,22 +295,81 @@ impl ZerocheckLayer for Layer { ); let in_point = in_point.into_iter().map(|c| c.elements).collect_vec(); - // eval eq and set to respective witin + let structural_witin_offset = self.n_witin + self.n_fixed + self.n_instance; + // eval selector and set to respective witin izip!( &self.out_sel_and_eval_exprs, &eval_and_dedup_points, selector_ctxs.iter() ) 
.for_each(|((sel_type, _), (_, out_point), selector_ctx)| { - sel_type.evaluate( - &mut main_evals, - out_point.as_ref().unwrap(), - &in_point, - selector_ctx, - self.n_witin, - ); + if let Some((expected_eval, wit_id)) = + sel_type.evaluate(out_point.as_ref().unwrap(), &in_point, selector_ctx) + { + let wit_id = wit_id as usize + structural_witin_offset; + assert_eq!(main_evals[wit_id], expected_eval); + } }); + // check structural witin + for StructuralWitIn { id, witin_type } in &self.structural_witins { + let wit_id = *id as usize + structural_witin_offset; + let expected_eval = match witin_type { + StructuralWitInType::EqualDistanceSequence { + offset, + multi_factor, + descending, + .. + } => eval_wellform_address_vec( + *offset as u64, + *multi_factor as u64, + &in_point, + *descending, + ), + StructuralWitInType::StackedIncrementalSequence { .. } => { + eval_stacked_wellform_address_vec(&in_point) + } + + StructuralWitInType::StackedConstantSequence { .. } => { + eval_stacked_constant_vec(&in_point) + } + StructuralWitInType::InnerRepeatingIncrementalSequence { k, .. } => { + eval_inner_repeated_incremental_vec(*k as u64, &in_point) + } + StructuralWitInType::OuterRepeatingIncrementalSequence { k, .. 
} => { + eval_outer_repeated_incremental_vec(*k as u64, &in_point) + } + StructuralWitInType::Empty => continue, + }; + if expected_eval != main_evals[wit_id] { + return Err(BackendError::LayerVerificationFailed( + format!("layer {} structural witin mismatch", self.name.clone()).into(), + VerifierError::ClaimNotMatch( + format!("{}", expected_eval).into(), + format!("{}", main_evals[wit_id]).into(), + ), + )); + } + } + + // check pub-io + // assume public io is tiny vector, so we evaluate it directly without PCS + let pubio_offset = self.n_witin + self.n_fixed; + for (index, instance) in self.instance_openings.iter().enumerate() { + let index = pubio_offset + index; + let poly = raw_pi[instance.0].to_vec().into_mle(); + let expected_eval = poly.evaluate(&in_point[..poly.num_vars()]); + if expected_eval != main_evals[index] { + return Err(BackendError::LayerVerificationFailed( + format!("layer {} pi mismatch", self.name.clone()).into(), + VerifierError::ClaimNotMatch( + format!("{}", expected_eval).into(), + format!("{}", main_evals[index]).into(), + ), + )); + } + } + let got_claim = eval_by_expr_with_instance( &[], &main_evals, @@ -430,14 +501,14 @@ fn verify_rotation( pub fn extend_exprs_with_rotation( layer: &Layer, alpha_pows: &[Expression], - offset_eq_id: WitnessId, ) -> Vec> { + let offset_structural_witid = (layer.n_witin + layer.n_fixed + layer.n_instance) as WitnessId; let mut alpha_pows_iter = alpha_pows.iter(); let mut expr_iter = layer.exprs.iter(); let mut zero_check_exprs = Vec::with_capacity(layer.out_sel_and_eval_exprs.len()); let match_expr = |sel_expr: &Expression| match sel_expr { - Expression::StructuralWitIn(id, ..) => Expression::WitIn(offset_eq_id + *id), + Expression::StructuralWitIn(id, ..) 
=> Expression::WitIn(offset_structural_witid + *id), invalid => panic!("invalid eq format {:?}", invalid), }; @@ -515,9 +586,9 @@ pub fn extend_exprs_with_rotation( Expression::StructuralWitIn(right_eq_id, ..), Expression::StructuralWitIn(eq_id, ..), ) => ( - Expression::WitIn(offset_eq_id + *left_eq_id), - Expression::WitIn(offset_eq_id + *right_eq_id), - Expression::WitIn(offset_eq_id + *eq_id), + Expression::WitIn(offset_structural_witid + *left_eq_id), + Expression::WitIn(offset_structural_witid + *right_eq_id), + Expression::WitIn(offset_structural_witid + *eq_id), ), invalid => panic!("invalid eq format {:?}", invalid), }; diff --git a/gkr_iop/src/gkr/layer_constraint_system.rs b/gkr_iop/src/gkr/layer_constraint_system.rs index 750f626c4..6bbb1cdc7 100644 --- a/gkr_iop/src/gkr/layer_constraint_system.rs +++ b/gkr_iop/src/gkr/layer_constraint_system.rs @@ -417,12 +417,15 @@ impl LayerConstraintSystem { self.num_witin, 0, self.num_fixed, + 0, expressions, n_challenges, in_eval_expr, expr_evals, ((None, vec![]), 0, 0), expr_names, + vec![], + vec![], ) } else { let Some(RotationParams { @@ -439,6 +442,7 @@ impl LayerConstraintSystem { self.num_witin, 0, self.num_fixed, + 0, expressions, n_challenges, in_eval_expr, @@ -449,6 +453,8 @@ impl LayerConstraintSystem { rotation_cyclic_subgroup_size, ), expr_names, + vec![], + vec![], ) } } diff --git a/gkr_iop/src/gkr/mock.rs b/gkr_iop/src/gkr/mock.rs index 4dfdbe2f6..11dc3ce4c 100644 --- a/gkr_iop/src/gkr/mock.rs +++ b/gkr_iop/src/gkr/mock.rs @@ -73,12 +73,13 @@ impl MockProver { wit_infer_by_expr( &(sel.selector_expr() * expr), layer.n_witin as WitnessId, - layer.n_structural_witin as WitnessId, layer.n_fixed as WitnessId, + layer.n_instance, &[], &wits, &structural_wits, &[], + &[], &challenges, ) }) @@ -91,8 +92,8 @@ impl MockProver { out.iter().map(|out| { out.mock_evaluate( layer.n_witin as WitnessId, - layer.n_structural_witin as WitnessId, layer.n_fixed as WitnessId, + layer.n_instance, &evaluations, 
&challenges, num_vars, @@ -146,8 +147,8 @@ impl EvalExpression { pub fn mock_evaluate<'a>( &self, n_witin: WitnessId, - n_structural_witin: WitnessId, n_fixed: WitnessId, + n_instance: usize, evals: &[ArcMultilinearExtension<'a, E>], challenges: &[E], num_vars: usize, @@ -160,12 +161,13 @@ impl EvalExpression { EvalExpression::Linear(i, c0, c1) => wit_infer_by_expr( &(Expression::WitIn(*i as WitnessId) * *c0.clone() + *c1.clone()), n_witin, - n_structural_witin, n_fixed, + n_instance, &[], evals, &[], &[], + &[], challenges, ), EvalExpression::Partition(parts, indices) => { @@ -174,12 +176,7 @@ impl EvalExpression { .iter() .map(|part| { part.mock_evaluate( - n_witin, - n_structural_witin, - n_fixed, - evals, - challenges, - num_vars, + n_witin, n_fixed, n_instance, evals, challenges, num_vars, ) }) .collect::, _>>()?; diff --git a/gkr_iop/src/gpu/mod.rs b/gkr_iop/src/gpu/mod.rs index 14a371a7b..8c00e6a20 100644 --- a/gkr_iop/src/gpu/mod.rs +++ b/gkr_iop/src/gpu/mod.rs @@ -14,6 +14,7 @@ use witness::RowMajorMatrix; use crate::cpu::default_backend_config; +use either::Either; use itertools::{Itertools, izip}; use std::marker::PhantomData; @@ -338,7 +339,7 @@ impl> fn layer_witness<'a>( layer: &Layer, layer_wits: &[Arc< as ProverBackend>::MultilinearPoly<'a>>], - pub_io_evals: &[Arc< as ProverBackend>::MultilinearPoly<'a>>], + pub_io_evals: &[Either], challenges: &[E], ) -> Vec as ProverBackend>::MultilinearPoly<'a>>> { let span = entered_span!("preprocess", profiling_2 = true); @@ -352,17 +353,6 @@ impl> .flat_map(|(sel_type, out_eval)| izip!(std::iter::repeat(sel_type), out_eval.iter())) .collect(); - // take public input from gpu to cpu for scalar evaluation - // assume public io is quite small, thus the cost is negligible - // evaluate all scalar terms first - // when instance was access in scalar, we only take its first item - // this operation is sound - let pub_io_evals = pub_io_evals - .iter() - .map(|mle| mle.inner_to_mle()) - .map(|instance| 
instance.evaluations.index(0)) - .collect_vec(); - // pre-process and flatten indices into friendly GPU format let (num_non_zero_expr, term_coefficients, mle_indices_per_term, mle_size_info) = layer .exprs_with_selector_out_eval_monomial_form diff --git a/gkr_iop/src/hal.rs b/gkr_iop/src/hal.rs index 8864f8c95..efb72e811 100644 --- a/gkr_iop/src/hal.rs +++ b/gkr_iop/src/hal.rs @@ -2,6 +2,7 @@ use crate::gkr::layer::{ Layer, hal::{LinearLayerProver, SumcheckLayerProver, ZerocheckLayerProver}, }; +use either::Either; use ff_ext::ExtensionField; use mpcs::PolynomialCommitmentScheme; use multilinear_extensions::mle::Point; @@ -50,7 +51,7 @@ pub trait ProtocolWitnessGeneratorProver { fn layer_witness<'a>( layer: &Layer, layer_wits: &[Arc>], - pub_io_evals: &[Arc>], + pub_io_evals: &[Either<::BaseField, PB::E>], challenges: &[PB::E], ) -> Vec>>; } diff --git a/gkr_iop/src/selector.rs b/gkr_iop/src/selector.rs index 9f10d2249..ebeb7f526 100644 --- a/gkr_iop/src/selector.rs +++ b/gkr_iop/src/selector.rs @@ -4,7 +4,7 @@ use rayon::iter::IndexedParallelIterator; use ff_ext::ExtensionField; use multilinear_extensions::{ - Expression, + Expression, WitnessId, mle::{IntoMLE, MultilinearExtension, Point}, util::ceil_log2, virtual_poly::{build_eq_x_r_vec, eq_eval}, @@ -239,17 +239,15 @@ impl SelectorType { pub fn evaluate( &self, - evals: &mut Vec, out_point: &Point, in_point: &Point, ctx: &SelectorContext, - offset_eq_id: usize, - ) { + ) -> Option<(E, WitnessId)> { assert_eq!(in_point.len(), ctx.num_vars); assert_eq!(out_point.len(), ctx.num_vars); let (expr, eval) = match self { - SelectorType::None => return, + SelectorType::None => return None, SelectorType::Whole(expr) => { debug_assert_eq!(out_point.len(), in_point.len()); (expr, eq_eval(out_point, in_point)) @@ -353,11 +351,7 @@ impl SelectorType { let Expression::StructuralWitIn(wit_id, _) = expr else { panic!("Wrong selector expression format"); }; - let wit_id = *wit_id as usize + offset_eq_id; - if wit_id >= 
evals.len() { - evals.resize(wit_id + 1, E::ZERO); - } - evals[wit_id] = eval; + Some((eval, *wit_id)) } /// return ordered indices of OrderedSparse32 @@ -426,9 +420,9 @@ mod tests { assert_eq!(vec[7], E::ZERO); let in_rt = E::random_vec(n_vars, &mut rng); - let mut evals = vec![]; - // TODO: avoid the param evals when we evaluate a selector - selector.evaluate(&mut evals, &out_rt, &in_rt, &ctx, 0); - assert_eq!(sel_mle.evaluate(&in_rt), evals[0]); + let Some((eval, _)) = selector.evaluate(&out_rt, &in_rt, &ctx) else { + unreachable!() + }; + assert_eq!(sel_mle.evaluate(&in_rt), eval); } } diff --git a/gkr_iop/src/utils.rs b/gkr_iop/src/utils.rs index 0d8c919e9..5d1970297 100644 --- a/gkr_iop/src/utils.rs +++ b/gkr_iop/src/utils.rs @@ -207,12 +207,116 @@ pub fn eq_eval_less_or_equal_than(max_idx: usize, a: &[E], b: ans } +/// Succinctly evaluate the MLE M(x0, x1, x2, ..., xn) of an address vector given in evaluation form +/// at the point r = [r0, r1, r2, ..., rn], +/// where `M = sign * scaled * M' + offset`; +/// `offset` and `scaled` are constants, and `sign` is minus one when `descending` is true and one otherwise, +/// and M' := [0, 1, 2, 3, ..., 2^(n+1)-1] (one entry per point of the hypercube over n+1 variables). +/// The succinct form is M'(r) = r0 + r1 * 2 + r2 * 2^2 + ... + rn * 2^n. +pub fn eval_wellform_address_vec( + offset: u64, + scaled: u64, + r: &[E], + descending: bool, +) -> E { + let (offset, scaled) = (E::from_canonical_u64(offset), E::from_canonical_u64(scaled)); + let tmp = scaled + * r.iter() + .scan(E::ONE, |state, x| { + let result = *x * *state; + *state *= E::from_canonical_u64(2); // Update the state for the next power of 2 + Some(result) + }) + .sum::(); + let tmp = if descending { tmp.neg() } else { tmp }; + offset + tmp +} + +/// Evaluate the MLE with the following evaluations over the hypercube: +/// [0, 0, 0, 1, 0, 1, 2, 3, 0, 1, 2, 3, 4, 5, 6, 7, ..., 0, 1, 2, ..., 2^n-1] +/// which is the concatenation of +/// [0] +/// [0, 1] +/// [0, 1, 2, 3] +/// ... 
+/// [0, 1, 2, ..., 2^n-1] +/// which is then prefixed by a single zero to make all the subvectors aligned to powers of two. +/// This function is used to support dynamic range checks. +/// Note that this MLE has n+1 variables, so r should have length n+1; when r has fewer than 2 entries the result is zero. +/// +/// Conceptually, we traverse the evaluations in the sequence: +/// [0, 0], [0, 1], [0, 1, 2, 3], ... +/// where every `next` element already has a well-formed incremental structure, +/// so we can reuse `eval_wellform_address_vec` to obtain its value. +/// +/// At each step `i`, we combine: +/// - the accumulated result so far, weighted by `(1 - r[i])` +/// - the evaluation of the current prefix `r[..i]`, weighted by `r[i]`. +/// +/// This iterative version avoids recursion for efficiency and clarity. +pub fn eval_stacked_wellform_address_vec(r: &[E]) -> E { + if r.len() < 2 { + return E::ZERO; + } + + let mut res = E::ZERO; + for i in 1..r.len() { + res = res * (E::ONE - r[i]) + eval_wellform_address_vec(0, 1, &r[..i], false) * r[i]; + } + res +} + +/// Evaluate the MLE with the following evaluations over the hypercube: +/// [0, 0, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, ..., n, n, n, ..., n] +/// which is the concatenation of +/// [0] +/// [1, 1] +/// [2, 2, 2, 2] +/// ... +/// [n, n, n, ..., n] +/// which is then prefixed by a single zero to make all the subvectors aligned to powers of two. +/// This function is used to support dynamic range checks. +/// Note that this MLE has n+1 variables, so r should have length n+1; when r has fewer than 2 entries the result is zero. 
+pub fn eval_stacked_constant_vec(r: &[E]) -> E { + if r.len() < 2 { + return E::ZERO; + } + + let mut res = E::ZERO; + for (i, r) in r.iter().enumerate().skip(1) { + res = res * (E::ONE - *r) + E::from_canonical_usize(i) * *r; + } + res +} + +/// Succinctly evaluate the MLE M(x0, x1, ..., x(n-1)) of an address vector given in evaluation form +/// at the point r = [r0, r1, ..., r(n-1)], +/// where `M = [0 ... 0 1 ... 1 ... 2^(n-k)-1 ... 
2^(n-k)-1]` +/// where each element is repeated 2^k times. +/// The value is the same as that of the incremental vector [0, 1, 2, ...] evaluated at `r[k..]`, i.e., just dropping +/// the first k elements from r. Panics if k exceeds `r.len()`. +pub fn eval_inner_repeated_incremental_vec(k: u64, r: &[E]) -> E { + eval_wellform_address_vec(0, 1, &r[k as usize..], false) +} + +/// Succinctly evaluate the MLE M(x0, x1, ..., x(n-1)) of an address vector given in evaluation form +/// at the point r = [r0, r1, ..., r(n-1)], +/// where `M = [0 1 ... 2^k-1] * 2^(n-k)`, i.e., the block [0 1 ... 2^k-1] repeated 2^(n-k) times. +/// The value is the same as that of the incremental vector [0, 1, ..., 2^k-1] evaluated at `r[..k]`, i.e., just taking +/// the first k elements from r. Panics if k exceeds `r.len()`. +pub fn eval_outer_repeated_incremental_vec(k: u64, r: &[E]) -> E { + eval_wellform_address_vec(0, 1, &r[..k as usize], false) +} + #[cfg(test)] mod tests { - use std::sync::Arc; - use ff_ext::{FromUniformBytes, GoldilocksExt2}; - use p3::goldilocks::Goldilocks; + use p3::{field::FieldAlgebra, goldilocks::Goldilocks}; + use std::{iter, sync::Arc}; + + type E = GoldilocksExt2; + + use multilinear_extensions::mle::MultilinearExtension; use super::*; @@ -251,4 +355,92 @@ mod tests { bh.get_rotation_right_eval_from_left(rotated_eval, left_eval, &point) ); } + + #[test] + fn test_eval_stacked_wellform_address_vec() { + let r = [ + E::from_canonical_usize(123), + E::from_canonical_usize(456), + E::from_canonical_usize(789), + E::from_canonical_usize(3210), + E::from_canonical_usize(9876), + ]; + for n in 0..r.len() { + let v = iter::once(E::ZERO) + .chain((0..=n).flat_map(|i| (0..(1 << i)).map(E::from_canonical_usize))) + .collect::>(); + let poly = MultilinearExtension::from_evaluations_ext_vec(n + 1, v); + assert_eq!( + eval_stacked_wellform_address_vec(&r[0..=n]), + poly.evaluate(&r[0..=n]) + ) + } + } + + #[test] + fn test_eval_stacked_constant_vec() { + let r = [ + E::from_canonical_usize(123), + E::from_canonical_usize(456), + E::from_canonical_usize(789), + E::from_canonical_usize(3210), + E::from_canonical_usize(9876), + ]; + for n in 0..r.len() { + let v = iter::once(E::ZERO) + .chain((0..=n).flat_map(|i| 
iter::repeat_n(i, 1 << i).map(E::from_canonical_usize))) + .collect::>(); + let poly = MultilinearExtension::from_evaluations_ext_vec(n + 1, v); + assert_eq!( + eval_stacked_constant_vec(&r[0..=n]), + poly.evaluate(&r[0..=n]) + ) + } + } + + #[test] + fn test_eval_inner_repeating_incremental_vec() { + let r = [ + E::from_canonical_usize(123), + E::from_canonical_usize(456), + E::from_canonical_usize(789), + E::from_canonical_usize(3210), + E::from_canonical_usize(9876), + ]; + for n in 1..=r.len() { + for k in 0..=n { + let v = (0..(1 << (n - k))) + .flat_map(|i| iter::repeat_n(E::from_canonical_usize(i), 1 << k)) + .collect::>(); + let poly = MultilinearExtension::from_evaluations_ext_vec(n, v); + assert_eq!( + eval_inner_repeated_incremental_vec(k as u64, &r[0..n]), + poly.evaluate(&r[0..n]) + ) + } + } + } + + #[test] + fn test_eval_outer_repeating_incremental_vec() { + let r = [ + E::from_canonical_usize(123), + E::from_canonical_usize(456), + E::from_canonical_usize(789), + E::from_canonical_usize(3210), + E::from_canonical_usize(9876), + ]; + for n in 1..=r.len() { + for k in 0..=n { + let v = iter::repeat_n(0, 1 << (n - k)) + .flat_map(|_| (0..(1 << k)).map(E::from_canonical_usize)) + .collect::>(); + let poly = MultilinearExtension::from_evaluations_ext_vec(n, v); + assert_eq!( + eval_outer_repeated_incremental_vec(k as u64, &r[0..n]), + poly.evaluate(&r[0..n]) + ) + } + } + } }