Skip to content

Commit 4af044d

Browse files
committed
refactor ec points witness assignments
1 parent 9676da8 commit 4af044d

File tree

1 file changed

+42
-42
lines changed

1 file changed

+42
-42
lines changed

ceno_zkvm/src/instructions/global.rs

Lines changed: 42 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,10 @@ use p3::{
2727
symmetric::Permutation,
2828
};
2929
use rayon::{
30-
iter::{IndexedParallelIterator, IntoParallelIterator, ParallelExtend, ParallelIterator},
30+
iter::{
31+
IndexedParallelIterator, IntoParallelIterator, IntoParallelRefMutIterator, ParallelExtend,
32+
ParallelIterator,
33+
},
3134
prelude::ParallelSliceMut,
3235
slice::ParallelSlice,
3336
};
@@ -538,50 +541,46 @@ impl<E: ExtensionField, P: Permutation<[E::BaseField; POSEIDON2_BABYBEAR_WIDTH]>
538541
)
539542
.collect::<Result<(), ZKVMError>>()?;
540543

541-
// assign internal nodes in the binary tree for ec point summation
542-
let mut cur_layer_points = steps
543-
.iter()
544-
.map(|step| step.ec_point.as_ref().map(|p| p.point.clone()).unwrap())
545-
.enumerate()
546-
.collect_vec();
544+
// allocate num_rows_padded size, fill points on first half
545+
let mut cur_layer_points_buffer: Vec<_> = (0..num_rows_padded)
546+
.into_par_iter()
547+
.map(|i| {
548+
steps
549+
.get(i)
550+
.and_then(|step| step.ec_point.as_ref())
551+
.map(|point| point.point.clone())
552+
.unwrap_or_else(SepticPoint::default)
553+
})
554+
.collect();
555+
// raw_witin offset start from n.
556+
// left node is at b, right node is at b + 1
557+
// op(left node, right node) = offset + b / 2
558+
let mut offset = num_rows_padded / 2;
559+
let mut current_layer_len = cur_layer_points_buffer.len() / 2;
547560

548561
// slope[1,b] = (input[b,0].y - input[b,1].y) / (input[b,0].x - input[b,1].x)
549562
loop {
550-
if cur_layer_points.len() <= 1 {
563+
if current_layer_len <= 1 {
551564
break;
552565
}
553-
// 2b -> b + 2^log_n
554-
let next_layer_offset = cur_layer_points.first().map(|(i, _)| *i / 2 + n).unwrap();
555-
cur_layer_points = cur_layer_points
566+
let (current_layer, next_layer) =
567+
cur_layer_points_buffer.split_at_mut(current_layer_len);
568+
current_layer
556569
.par_chunks(2)
557-
.zip(raw_witin.values[next_layer_offset * num_witin..].par_chunks_mut(num_witin))
558-
.with_min_len(64)
559-
.map(|(pair, instance)| {
560-
// input[1,b] = affine_add(input[b,0], input[b,1])
561-
// the left node is at index 2b, right node is at index 2b+1
562-
// the parent node is at index b + 2^n
563-
let (o, slope, q) = match pair.len() {
564-
2 => {
565-
// l = 2b, r = 2b+1
566-
let (l, p1) = &pair[0];
567-
let (r, p2) = &pair[1];
568-
assert_eq!(*r - *l, 1);
569-
570-
// parent node idx = b + 2^log2_n
571-
let o = n + l / 2;
572-
let slope = (&p1.y - &p2.y) * (&p1.x - &p2.x).inverse().unwrap();
573-
let q = p1.clone() + p2.clone();
574-
575-
(o, slope, q)
576-
}
577-
1 => {
578-
let (l, p) = &pair[0];
579-
let o = n + l / 2;
580-
(o, SepticExtension::zero(), p.clone())
581-
}
582-
_ => unreachable!(),
570+
.zip_eq(next_layer[..current_layer_len / 2].par_iter_mut())
571+
.zip(raw_witin.values[offset * num_witin..].par_chunks_mut(num_witin))
572+
.for_each(|((pair, parent), instance)| {
573+
let p1 = &pair[0];
574+
let p2 = &pair[1];
575+
let (slope, q) = if p2.is_infinity {
576+
// input[1,b] = bypass_left(input[b,0], input[b,1])
577+
(SepticExtension::zero(), p1.clone())
578+
} else {
579+
// input[1,b] = affine_add(input[b,0], input[b,1])
580+
let slope = (&p1.y - &p2.y) * (&p1.x - &p2.x).inverse().unwrap();
581+
let q = p1.clone() + p2.clone();
582+
(slope, q)
583583
};
584-
585584
config
586585
.x
587586
.iter()
@@ -595,10 +594,11 @@ impl<E: ExtensionField, P: Permutation<[E::BaseField; POSEIDON2_BABYBEAR_WIDTH]>
595594
.for_each(|(witin, fe)| {
596595
set_val!(instance, *witin, *fe);
597596
});
598-
599-
(o, q)
600-
})
601-
.collect::<Vec<_>>();
597+
*parent = q.clone();
598+
});
599+
cur_layer_points_buffer = cur_layer_points_buffer.split_off(current_layer_len);
600+
current_layer_len /= 2;
601+
offset += current_layer_len;
602602
}
603603

604604
let raw_witin = witness::RowMajorMatrix::new_by_inner_matrix(

0 commit comments

Comments
 (0)