Skip to content

Commit 5af8392

Browse files
authored
refactor ec points witness assignments (#1100)
- simplify index computation of affine_add(left, right)
- allocate the buffer once and initialize it in parallel
1 parent 3860d14 commit 5af8392

File tree

1 file changed

+41
-42
lines changed

1 file changed

+41
-42
lines changed

ceno_zkvm/src/instructions/global.rs

Lines changed: 41 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,10 @@ use p3::{
3131
symmetric::Permutation,
3232
};
3333
use rayon::{
34-
iter::{IndexedParallelIterator, IntoParallelIterator, ParallelExtend, ParallelIterator},
34+
iter::{
35+
IndexedParallelIterator, IntoParallelIterator, IntoParallelRefMutIterator, ParallelExtend,
36+
ParallelIterator,
37+
},
3538
prelude::ParallelSliceMut,
3639
slice::ParallelSlice,
3740
};
@@ -554,50 +557,45 @@ impl<E: ExtensionField> TableCircuit<E> for GlobalChip<E> {
554557
})
555558
.collect::<Result<(), ZKVMError>>()?;
556559

557-
// assign internal nodes in the binary tree for ec point summation
558-
let mut cur_layer_points = steps
559-
.iter()
560-
.map(|step| step.ec_point.point.clone())
561-
.enumerate()
562-
.collect_vec();
560+
// allocate a buffer of num_rows_padded points; the input points fill the first half, slots without a step get the default point
561+
let mut cur_layer_points_buffer: Vec<_> = (0..num_rows_padded)
562+
.into_par_iter()
563+
.map(|i| {
564+
steps
565+
.get(i)
566+
.map(|step| step.ec_point.point.clone())
567+
.unwrap_or_else(SepticPoint::default)
568+
})
569+
.collect();
570+
// the raw_witin row offset starts from n.
571+
// left node is at b, right node is at b + 1
572+
// op(left node, right node) is written to row offset + b / 2
573+
let mut offset = num_rows_padded / 2;
574+
let mut current_layer_len = cur_layer_points_buffer.len() / 2;
563575

564576
// slope[1,b] = (input[b,0].y - input[b,1].y) / (input[b,0].x - input[b,1].x)
565577
loop {
566-
if cur_layer_points.len() <= 1 {
578+
if current_layer_len <= 1 {
567579
break;
568580
}
569-
// 2b -> b + 2^log_n
570-
let next_layer_offset = cur_layer_points.first().map(|(i, _)| *i / 2 + n).unwrap();
571-
cur_layer_points = cur_layer_points
581+
let (current_layer, next_layer) =
582+
cur_layer_points_buffer.split_at_mut(current_layer_len);
583+
current_layer
572584
.par_chunks(2)
573-
.zip(raw_witin.values[next_layer_offset * num_witin..].par_chunks_mut(num_witin))
574-
.with_min_len(64)
575-
.map(|(pair, instance)| {
576-
// input[1,b] = affine_add(input[b,0], input[b,1])
577-
// the left node is at index 2b, right node is at index 2b+1
578-
// the parent node is at index b + 2^n
579-
let (o, slope, q) = match pair.len() {
580-
2 => {
581-
// l = 2b, r = 2b+1
582-
let (l, p1) = &pair[0];
583-
let (r, p2) = &pair[1];
584-
assert_eq!(*r - *l, 1);
585-
586-
// parent node idx = b + 2^log2_n
587-
let o = n + l / 2;
588-
let slope = (&p1.y - &p2.y) * (&p1.x - &p2.x).inverse().unwrap();
589-
let q = p1.clone() + p2.clone();
590-
591-
(o, slope, q)
592-
}
593-
1 => {
594-
let (l, p) = &pair[0];
595-
let o = n + l / 2;
596-
(o, SepticExtension::zero(), p.clone())
597-
}
598-
_ => unreachable!(),
585+
.zip_eq(next_layer[..current_layer_len / 2].par_iter_mut())
586+
.zip(raw_witin.values[offset * num_witin..].par_chunks_mut(num_witin))
587+
.for_each(|((pair, parent), instance)| {
588+
let p1 = &pair[0];
589+
let p2 = &pair[1];
590+
let (slope, q) = if p2.is_infinity {
591+
// input[1,b] = bypass_left(input[b,0], input[b,1])
592+
(SepticExtension::zero(), p1.clone())
593+
} else {
594+
// input[1,b] = affine_add(input[b,0], input[b,1])
595+
let slope = (&p1.y - &p2.y) * (&p1.x - &p2.x).inverse().unwrap();
596+
let q = p1.clone() + p2.clone();
597+
(slope, q)
599598
};
600-
601599
config
602600
.x
603601
.iter()
@@ -611,10 +609,11 @@ impl<E: ExtensionField> TableCircuit<E> for GlobalChip<E> {
611609
.for_each(|(witin, fe)| {
612610
set_val!(instance, *witin, *fe);
613611
});
614-
615-
(o, q)
616-
})
617-
.collect::<Vec<_>>();
612+
*parent = q.clone();
613+
});
614+
cur_layer_points_buffer = cur_layer_points_buffer.split_off(current_layer_len);
615+
current_layer_len /= 2;
616+
offset += current_layer_len;
618617
}
619618

620619
let raw_witin = witness::RowMajorMatrix::new_by_inner_matrix(

0 commit comments

Comments
 (0)