diff --git a/ceno_zkvm/src/instructions/global.rs b/ceno_zkvm/src/instructions/global.rs index 8cef4d869..30be012fd 100644 --- a/ceno_zkvm/src/instructions/global.rs +++ b/ceno_zkvm/src/instructions/global.rs @@ -31,7 +31,10 @@ use p3::{ symmetric::Permutation, }; use rayon::{ - iter::{IndexedParallelIterator, IntoParallelIterator, ParallelExtend, ParallelIterator}, + iter::{ + IndexedParallelIterator, IntoParallelIterator, IntoParallelRefMutIterator, ParallelExtend, + ParallelIterator, + }, prelude::ParallelSliceMut, slice::ParallelSlice, }; @@ -554,50 +557,45 @@ impl TableCircuit for GlobalChip { }) .collect::>()?; - // assign internal nodes in the binary tree for ec point summation - let mut cur_layer_points = steps - .iter() - .map(|step| step.ec_point.point.clone()) - .enumerate() - .collect_vec(); + // allocate num_rows_padded size, fill points on first half + let mut cur_layer_points_buffer: Vec<_> = (0..num_rows_padded) + .into_par_iter() + .map(|i| { + steps + .get(i) + .map(|step| step.ec_point.point.clone()) + .unwrap_or_else(SepticPoint::default) + }) + .collect(); + // raw_witin offset start from n. + // left node is at b, right node is at b + 1 + // op(left node, right node) = offset + b / 2 + let mut offset = num_rows_padded / 2; + let mut current_layer_len = cur_layer_points_buffer.len() / 2; // slope[1,b] = (input[b,0].y - input[b,1].y) / (input[b,0].x - input[b,1].x) loop { - if cur_layer_points.len() <= 1 { + if current_layer_len <= 1 { break; } - // 2b -> b + 2^log_n - let next_layer_offset = cur_layer_points.first().map(|(i, _)| *i / 2 + n).unwrap(); - cur_layer_points = cur_layer_points + let (current_layer, next_layer) = + cur_layer_points_buffer.split_at_mut(current_layer_len); + current_layer .par_chunks(2) - .zip(raw_witin.values[next_layer_offset * num_witin..].par_chunks_mut(num_witin)) - .with_min_len(64) - .map(|(pair, instance)| { - // input[1,b] = affine_add(input[b,0], input[b,1]) - // the left node is at index 2b, right node is at index 2b+1 - // the parent node is at index b + 2^n - let (o, slope, q) = match pair.len() { - 2 => { - // l = 2b, r = 2b+1 - let (l, p1) = &pair[0]; - let (r, p2) = &pair[1]; - assert_eq!(*r - *l, 1); - - // parent node idx = b + 2^log2_n - let o = n + l / 2; - let slope = (&p1.y - &p2.y) * (&p1.x - &p2.x).inverse().unwrap(); - let q = p1.clone() + p2.clone(); - - (o, slope, q) - } - 1 => { - let (l, p) = &pair[0]; - let o = n + l / 2; - (o, SepticExtension::zero(), p.clone()) - } - _ => unreachable!(), + .zip_eq(next_layer[..current_layer_len / 2].par_iter_mut()) + .zip(raw_witin.values[offset * num_witin..].par_chunks_mut(num_witin)) + .for_each(|((pair, parent), instance)| { + let p1 = &pair[0]; + let p2 = &pair[1]; + let (slope, q) = if p2.is_infinity { + // input[1,b] = bypass_left(input[b,0], input[b,1]) + (SepticExtension::zero(), p1.clone()) + } else { + // input[1,b] = affine_add(input[b,0], input[b,1]) + let slope = (&p1.y - &p2.y) * (&p1.x - &p2.x).inverse().unwrap(); + let q = p1.clone() + p2.clone(); + (slope, q) }; - config .x .iter() @@ -611,10 +609,11 @@ impl TableCircuit for GlobalChip { .for_each(|(witin, fe)| { set_val!(instance, *witin, *fe); }); - - (o, q) - }) - .collect::>(); + *parent = q.clone(); + }); + cur_layer_points_buffer = cur_layer_points_buffer.split_off(current_layer_len); + current_layer_len /= 2; + offset += current_layer_len; } let raw_witin = witness::RowMajorMatrix::new_by_inner_matrix(