11#![ deny( clippy:: cargo) ]  
22use  ff_ext:: ExtensionField ; 
3- use  itertools:: { interleave ,   Either ,  Itertools } ; 
3+ use  itertools:: { Either ,  Itertools } ; 
44use  multilinear_extensions:: { mle:: { DenseMultilinearExtension ,  FieldType ,  MultilinearExtension } ,  virtual_poly:: { build_eq_x_r,  eq_eval,  VPAuxInfo } } ; 
55use  serde:: { Serialize ,  de:: DeserializeOwned } ; 
66use  std:: fmt:: Debug ; 
@@ -10,6 +10,8 @@ use p3_field::PrimeCharacteristicRing;
1010use  multilinear_extensions:: virtual_poly:: VirtualPolynomial ; 
1111use  sumcheck:: structs:: { IOPProof ,  IOPProverState ,  IOPVerifierState } ; 
1212use  witness:: RowMajorMatrix ; 
13+ #[ cfg( feature = "parallel" ) ]  
14+ use  rayon:: prelude:: * ; 
1315
1416pub  mod  sum_check; 
1517pub  mod  util; 
@@ -172,7 +174,8 @@ fn interleave_polys<E: ExtensionField>(
172174// Interleave the polys give their position on the binary tree 
173175// Assume the polys are sorted by decreasing size 
174176// Denote: N - size of the interleaved poly; M - num of polys 
175- // This function performs interleave in O(M) + O(N) time and is *potentially* parallelizable (maybe? idk) 
177+ // This function performs interleave in O(M) + O(N) time 
178+ #[ cfg( not( feature = "parallel" ) ) ]  
176179fn  interleave_polys < E :  ExtensionField > ( 
177180    polys :  Vec < & DenseMultilinearExtension < E > > , 
178181    comps :  & Vec < Vec < bool > > , 
@@ -223,6 +226,84 @@ fn interleave_polys<E: ExtensionField>(
223226    DenseMultilinearExtension  {  num_vars :  interleaved_num_vars,  evaluations :  interleaved_evaluations } 
224227} 
225228
229+ // Parallel version: divide interleaved_evaluation into chunks 
230+ #[ cfg( feature = "parallel" ) ]  
231+ fn  interleave_polys < E :  ExtensionField > ( 
232+     polys :  Vec < & DenseMultilinearExtension < E > > , 
233+     comps :  & Vec < Vec < bool > > , 
234+ )  -> DenseMultilinearExtension < E >  { 
235+     use  std:: cmp:: min; 
236+ 
237+     assert ! ( polys. len( )  > 0 ) ; 
238+     let  sizes:  Vec < usize >  = polys. iter ( ) . map ( |p| p. evaluations . len ( ) ) . collect ( ) ; 
239+     let  interleaved_size = sizes. iter ( ) . sum :: < usize > ( ) . next_power_of_two ( ) ; 
240+     let  interleaved_num_vars = interleaved_size. ilog2 ( )  as  usize ; 
241+ 
242+     // Compute Start and Gap for each poly 
243+     // * Start: where's its first entry in the interleaved poly? 
244+     // * Gap: how many entires are between its consecutive entries in the interleaved poly? 
245+     let  start_list:  Vec < usize >  = comps. iter ( ) . map ( |comp| { 
246+         let  mut  start = 0 ; 
247+         let  mut  pow_2 = 1 ; 
248+         for  b in  comp { 
249+             start += if  * b {  pow_2 }  else  {  0  } ; 
250+             pow_2 *= 2 ; 
251+         } 
252+         start
253+     } ) . collect ( ) ; 
254+     let  gap_list:  Vec < usize >  = polys. iter ( ) . map ( |poly|
255+         1  << ( interleaved_num_vars - poly. num_vars ) 
256+     ) . collect ( ) ; 
257+     // Minimally each chunk needs one entry from the smallest poly 
258+     let  num_chunks = min ( rayon:: current_num_threads ( ) . next_power_of_two ( ) ,  sizes[ sizes. len ( )  - 1 ] ) ; 
259+     let  interleaved_chunk_size = interleaved_size / num_chunks; 
260+     // Length of the poly each thread processes 
261+     let  poly_chunk_size:  Vec < usize >  = sizes. iter ( ) . map ( |s| s / num_chunks) . collect ( ) ; 
262+ 
263+     // Initialize the interleaved poly 
264+     // Is there a better way to deal with field types? 
265+     let  interleaved_evaluations = match  polys[ 0 ] . evaluations  { 
266+         FieldType :: Base ( _)  => { 
267+             let  mut  interleaved_eval = vec ! [ E :: BaseField :: ZERO ;  interleaved_size] ; 
268+             interleaved_eval. par_chunks_exact_mut ( interleaved_chunk_size) . enumerate ( ) . for_each ( |( i,  chunk) | { 
269+                 for  ( p,  poly)  in  polys. iter ( ) . enumerate ( )  { 
270+                     match  & poly. evaluations  { 
271+                         FieldType :: Base ( pe)  => { 
272+                             // Each thread processes a chunk of pe 
273+                             for  ( j,  e)  in  pe[ i *  poly_chunk_size[ p] ..( i+1 )  *  poly_chunk_size[ p] ] . iter ( ) . enumerate ( )  { 
274+                                 chunk[ start_list[ p]  + gap_list[ p]  *  j]  = * e; 
275+                             } 
276+                         } 
277+                         b => panic ! ( "do not support merge BASE field type with b: {:?}" ,  b) 
278+                     } 
279+                 } 
280+             } ) ; 
281+             FieldType :: Base ( interleaved_eval) 
282+         } 
283+         FieldType :: Ext ( _)  => { 
284+             let  mut  interleaved_eval = vec ! [ E :: ZERO ;  interleaved_size] ; 
285+             interleaved_eval. par_chunks_exact_mut ( num_chunks) . enumerate ( ) . for_each ( |( i,  chunk) | { 
286+                 for  ( p,  poly)  in  polys. iter ( ) . enumerate ( )  { 
287+                     match  & poly. evaluations  { 
288+                         FieldType :: Ext ( pe)  => { 
289+                             // Each thread processes a chunk of pe 
290+                             for  ( j,  e)  in  pe[ i *  poly_chunk_size[ p] ..( i+1 )  *  poly_chunk_size[ p] ] . iter ( ) . enumerate ( )  { 
291+                                 chunk[ start_list[ p]  + gap_list[ p]  *  j]  = * e; 
292+                             } 
293+                         } 
294+                         b => panic ! ( "do not support merge EXT field type with b: {:?}" ,  b) 
295+                     } 
296+                 } 
297+             } ) ; 
298+             FieldType :: Ext ( interleaved_eval) 
299+         } 
300+         _ => unreachable ! ( ) 
301+     } ; 
302+ 
303+     DenseMultilinearExtension  {  num_vars :  interleaved_num_vars,  evaluations :  interleaved_evaluations } 
304+ } 
305+ 
306+ 
226307// Pack polynomials of different sizes into the same, returns 
227308// 0: A list of packed polys 
228309// 1: The final packed poly, if of different size 
0 commit comments