@@ -27,7 +27,10 @@ use p3::{
2727 symmetric:: Permutation ,
2828} ;
2929use rayon:: {
30- iter:: { IndexedParallelIterator , IntoParallelIterator , ParallelExtend , ParallelIterator } ,
30+ iter:: {
31+ IndexedParallelIterator , IntoParallelIterator , IntoParallelRefMutIterator , ParallelExtend ,
32+ ParallelIterator ,
33+ } ,
3134 prelude:: ParallelSliceMut ,
3235 slice:: ParallelSlice ,
3336} ;
@@ -538,50 +541,46 @@ impl<E: ExtensionField, P: Permutation<[E::BaseField; POSEIDON2_BABYBEAR_WIDTH]>
538541 )
539542 . collect :: < Result < ( ) , ZKVMError > > ( ) ?;
540543
541- // assign internal nodes in the binary tree for ec point summation
542- let mut cur_layer_points = steps
543- . iter ( )
544- . map ( |step| step. ec_point . as_ref ( ) . map ( |p| p. point . clone ( ) ) . unwrap ( ) )
545- . enumerate ( )
546- . collect_vec ( ) ;
544+ // allocate num_rows_padded size, fill points on first half
545+ let mut cur_layer_points_buffer: Vec < _ > = ( 0 ..num_rows_padded)
546+ . into_par_iter ( )
547+ . map ( |i| {
548+ steps
549+ . get ( i)
550+ . and_then ( |step| step. ec_point . as_ref ( ) )
551+ . map ( |point| point. point . clone ( ) )
552+ . unwrap_or_else ( SepticPoint :: default)
553+ } )
554+ . collect ( ) ;
555+ // raw_witin offset starts from n.
556+ // left node is at b, right node is at b + 1
557+ // op(left node, right node) = offset + b / 2
558+ let mut offset = num_rows_padded / 2 ;
559+ let mut current_layer_len = cur_layer_points_buffer. len ( ) / 2 ;
547560
548561 // slope[1,b] = (input[b,0].y - input[b,1].y) / (input[b,0].x - input[b,1].x)
549562 loop {
550- if cur_layer_points . len ( ) <= 1 {
563+ if current_layer_len <= 1 {
551564 break ;
552565 }
553- // 2b -> b + 2^log_n
554- let next_layer_offset = cur_layer_points . first ( ) . map ( | ( i , _ ) | * i / 2 + n ) . unwrap ( ) ;
555- cur_layer_points = cur_layer_points
566+ let ( current_layer , next_layer ) =
567+ cur_layer_points_buffer . split_at_mut ( current_layer_len ) ;
568+ current_layer
556569 . par_chunks ( 2 )
557- . zip ( raw_witin. values [ next_layer_offset * num_witin..] . par_chunks_mut ( num_witin) )
558- . with_min_len ( 64 )
559- . map ( |( pair, instance) | {
560- // input[1,b] = affine_add(input[b,0], input[b,1])
561- // the left node is at index 2b, right node is at index 2b+1
562- // the parent node is at index b + 2^n
563- let ( o, slope, q) = match pair. len ( ) {
564- 2 => {
565- // l = 2b, r = 2b+1
566- let ( l, p1) = & pair[ 0 ] ;
567- let ( r, p2) = & pair[ 1 ] ;
568- assert_eq ! ( * r - * l, 1 ) ;
569-
570- // parent node idx = b + 2^log2_n
571- let o = n + l / 2 ;
572- let slope = ( & p1. y - & p2. y ) * ( & p1. x - & p2. x ) . inverse ( ) . unwrap ( ) ;
573- let q = p1. clone ( ) + p2. clone ( ) ;
574-
575- ( o, slope, q)
576- }
577- 1 => {
578- let ( l, p) = & pair[ 0 ] ;
579- let o = n + l / 2 ;
580- ( o, SepticExtension :: zero ( ) , p. clone ( ) )
581- }
582- _ => unreachable ! ( ) ,
570+ . zip_eq ( next_layer[ ..current_layer_len / 2 ] . par_iter_mut ( ) )
571+ . zip ( raw_witin. values [ offset * num_witin..] . par_chunks_mut ( num_witin) )
572+ . for_each ( |( ( pair, parent) , instance) | {
573+ let p1 = & pair[ 0 ] ;
574+ let p2 = & pair[ 1 ] ;
575+ let ( slope, q) = if p2. is_infinity {
576+ // input[1,b] = bypass_left(input[b,0], input[b,1])
577+ ( SepticExtension :: zero ( ) , p1. clone ( ) )
578+ } else {
579+ // input[1,b] = affine_add(input[b,0], input[b,1])
580+ let slope = ( & p1. y - & p2. y ) * ( & p1. x - & p2. x ) . inverse ( ) . unwrap ( ) ;
581+ let q = p1. clone ( ) + p2. clone ( ) ;
582+ ( slope, q)
583583 } ;
584-
585584 config
586585 . x
587586 . iter ( )
@@ -595,10 +594,11 @@ impl<E: ExtensionField, P: Permutation<[E::BaseField; POSEIDON2_BABYBEAR_WIDTH]>
595594 . for_each ( |( witin, fe) | {
596595 set_val ! ( instance, * witin, * fe) ;
597596 } ) ;
598-
599- ( o, q)
600- } )
601- . collect :: < Vec < _ > > ( ) ;
597+ * parent = q. clone ( ) ;
598+ } ) ;
599+ cur_layer_points_buffer = cur_layer_points_buffer. split_off ( current_layer_len) ;
600+ current_layer_len /= 2 ;
601+ offset += current_layer_len;
602602 }
603603
604604 let raw_witin = witness:: RowMajorMatrix :: new_by_inner_matrix (
0 commit comments