Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 41 additions & 42 deletions ceno_zkvm/src/instructions/global.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@ use p3::{
symmetric::Permutation,
};
use rayon::{
iter::{IndexedParallelIterator, IntoParallelIterator, ParallelExtend, ParallelIterator},
iter::{
IndexedParallelIterator, IntoParallelIterator, IntoParallelRefMutIterator, ParallelExtend,
ParallelIterator,
},
prelude::ParallelSliceMut,
slice::ParallelSlice,
};
Expand Down Expand Up @@ -554,50 +557,45 @@ impl<E: ExtensionField> TableCircuit<E> for GlobalChip<E> {
})
.collect::<Result<(), ZKVMError>>()?;

// assign internal nodes in the binary tree for ec point summation
let mut cur_layer_points = steps
.iter()
.map(|step| step.ec_point.point.clone())
.enumerate()
.collect_vec();
// allocate num_rows_padded size, fill points on first half
let mut cur_layer_points_buffer: Vec<_> = (0..num_rows_padded)
.into_par_iter()
.map(|i| {
steps
.get(i)
.map(|step| step.ec_point.point.clone())
.unwrap_or_else(SepticPoint::default)
})
.collect();
// raw_witin offset start from n.
// left node is at b, right node is at b + 1
// op(left node, right node) = offset + b / 2
let mut offset = num_rows_padded / 2;
let mut current_layer_len = cur_layer_points_buffer.len() / 2;

// slope[1,b] = (input[b,0].y - input[b,1].y) / (input[b,0].x - input[b,1].x)
loop {
if cur_layer_points.len() <= 1 {
if current_layer_len <= 1 {
break;
}
// 2b -> b + 2^log_n
let next_layer_offset = cur_layer_points.first().map(|(i, _)| *i / 2 + n).unwrap();
cur_layer_points = cur_layer_points
let (current_layer, next_layer) =
cur_layer_points_buffer.split_at_mut(current_layer_len);
current_layer
.par_chunks(2)
.zip(raw_witin.values[next_layer_offset * num_witin..].par_chunks_mut(num_witin))
.with_min_len(64)
.map(|(pair, instance)| {
// input[1,b] = affine_add(input[b,0], input[b,1])
// the left node is at index 2b, right node is at index 2b+1
// the parent node is at index b + 2^n
let (o, slope, q) = match pair.len() {
2 => {
// l = 2b, r = 2b+1
let (l, p1) = &pair[0];
let (r, p2) = &pair[1];
assert_eq!(*r - *l, 1);

// parent node idx = b + 2^log2_n
let o = n + l / 2;
let slope = (&p1.y - &p2.y) * (&p1.x - &p2.x).inverse().unwrap();
let q = p1.clone() + p2.clone();

(o, slope, q)
}
1 => {
let (l, p) = &pair[0];
let o = n + l / 2;
(o, SepticExtension::zero(), p.clone())
}
_ => unreachable!(),
.zip_eq(next_layer[..current_layer_len / 2].par_iter_mut())
.zip(raw_witin.values[offset * num_witin..].par_chunks_mut(num_witin))
.for_each(|((pair, parent), instance)| {
let p1 = &pair[0];
let p2 = &pair[1];
let (slope, q) = if p2.is_infinity {
// input[1,b] = bypass_left(input[b,0], input[b,1])
(SepticExtension::zero(), p1.clone())
} else {
// input[1,b] = affine_add(input[b,0], input[b,1])
let slope = (&p1.y - &p2.y) * (&p1.x - &p2.x).inverse().unwrap();
let q = p1.clone() + p2.clone();
(slope, q)
};

config
.x
.iter()
Expand All @@ -611,10 +609,11 @@ impl<E: ExtensionField> TableCircuit<E> for GlobalChip<E> {
.for_each(|(witin, fe)| {
set_val!(instance, *witin, *fe);
});

(o, q)
})
.collect::<Vec<_>>();
*parent = q.clone();
});
cur_layer_points_buffer = cur_layer_points_buffer.split_off(current_layer_len);
current_layer_len /= 2;
offset += current_layer_len;
}

let raw_witin = witness::RowMajorMatrix::new_by_inner_matrix(
Expand Down
Loading