Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 106 additions & 50 deletions ceno_zkvm/src/scheme/cpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@ use multilinear_extensions::{
virtual_poly::build_eq_x_r_vec,
virtual_polys::VirtualPolynomialsBuilder,
};
use p3::field::FieldAlgebra;
use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator};
use rayon::iter::{
IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator,
IntoParallelRefMutIterator, ParallelIterator,
};
use std::{collections::BTreeMap, sync::Arc};
use sumcheck::{
macros::{entered_span, exit_span},
Expand Down Expand Up @@ -65,6 +67,7 @@ impl CpuEccProver {

pub fn create_ecc_proof<'a, E: ExtensionField>(
&self,
num_instances: usize,
mut xs: Vec<MultilinearExtension<'a, E>>,
mut ys: Vec<MultilinearExtension<'a, E>>,
invs: Vec<MultilinearExtension<'a, E>>,
Expand All @@ -78,22 +81,44 @@ impl CpuEccProver {
let out_rt = transcript.sample_and_append_vec(b"ecc", n);
let num_threads = optimal_sumcheck_threads(out_rt.len());

let alpha_pows =
transcript.sample_and_append_challenge_pows(SEPTIC_EXTENSION_DEGREE * 3, b"ecc_alpha");
// expression with add (3 zero constrains) and bypass (2 zero constrains)
let alpha_pows = transcript.sample_and_append_challenge_pows(
SEPTIC_EXTENSION_DEGREE * 3 + SEPTIC_EXTENSION_DEGREE * 2,
b"ecc_alpha",
);
let mut alpha_pows_iter = alpha_pows.iter();

let mut expr_builder = VirtualPolynomialsBuilder::new(num_threads, out_rt.len());

let sel = SelectorType::Prefix(0.into());
let num_instances = (1 << n) - 1;
let sel_ctx = SelectorContext {
let sel_add = SelectorType::QuarkBinaryTreeLessThan(0.into());
let sel_add_ctx = SelectorContext {
offset: 0,
num_instances,
num_vars: n,
};
let mut sel_mle: MultilinearExtension<'_, E> = sel.compute(&out_rt, &sel_ctx).unwrap();
let sel_expr = expr_builder.lift(sel_mle.to_either());
let mut sel_add_mle: MultilinearExtension<'_, E> =
sel_add.compute(&out_rt, &sel_add_ctx).unwrap();
// we construct sel_bypass witness here
// verifier can derive it via `sel_bypass = eq - sel_add - sel_last_onehot`
let mut sel_bypass_mle: Vec<E> = build_eq_x_r_vec(&out_rt);
match sel_add_mle.evaluations() {
FieldType::Ext(sel_add_mle) => sel_add_mle
.par_iter()
.zip_eq(sel_bypass_mle.par_iter_mut())
.for_each(|(sel_add, sel_bypass)| {
if *sel_add != E::ZERO {
*sel_bypass = E::ZERO;
}
}),
_ => unreachable!(),
}
*sel_bypass_mle.last_mut().unwrap() = E::ZERO;
let mut sel_bypass_mle = sel_bypass_mle.into_mle();
let sel_add_expr = expr_builder.lift(sel_add_mle.to_either());
let sel_bypass_expr = expr_builder.lift(sel_bypass_mle.to_either());

let mut exprs = vec![];
let mut exprs_add = vec![];
let mut exprs_bypass = vec![];

let filter_bj = |v: &[MultilinearExtension<'_, E>], j: usize| {
v.iter()
Expand Down Expand Up @@ -162,43 +187,58 @@ impl CpuEccProver {
);
// affine addition
// zerocheck: 0 = s[0,b] * (x[b,0] - x[b,1]) - (y[b,0] - y[b,1]) with b != (1,...,1)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some notes
[p1, p2, p3, p4, p5, o, o, o | p12, p34, p5, o]

sel_add: [1, 1, 1, ... ,0 ,0, 1, 1, 1, 0, 0 ,0 ,.. ..... 0, 0]
sel_bypass: [0, 0, 0, 0, 1, 1, ,... 0] => \sum_{s} eq(r ,s) - product_{i}(ri*si) - sel_add_eval(r, s) -
sel_last_onehot: [0, 0, 0, 0, ..., 1]

out = [r0, r1, r2, r3, ....]

round 0: (1-r0) * eq_less_than + r0 * [1, 0]
...
round i: (1-ri) * eq_less_than + ri * eval_so_far
...
round last: (1-r_last) * eq_less_than + r_last * eval_so_far

exprs.extend(
exprs_add.extend(
(s.clone() * (&x0 - &x1) - (&y0 - &y1))
.to_exprs()
.into_iter()
.zip(alpha_pows.iter().take(SEPTIC_EXTENSION_DEGREE))
.zip_eq(alpha_pows_iter.by_ref().take(SEPTIC_EXTENSION_DEGREE))
.map(|(e, alpha)| e * Expression::Constant(Either::Right(*alpha))),
);

// zerocheck: 0 = s[0,b]^2 - x[b,0] - x[b,1] - x[1,b] with b != (1,...,1)
exprs.extend(
exprs_add.extend(
((&s * &s) - &x0 - &x1 - &x3)
.to_exprs()
.into_iter()
.zip(
alpha_pows[SEPTIC_EXTENSION_DEGREE..]
.iter()
.take(SEPTIC_EXTENSION_DEGREE),
)
.zip_eq(alpha_pows_iter.by_ref().take(SEPTIC_EXTENSION_DEGREE))
.map(|(e, alpha)| e * Expression::Constant(Either::Right(*alpha))),
);

// zerocheck: 0 = s[0,b] * (x[b,0] - x[1,b]) - (y[b,0] + y[1,b]) with b != (1,...,1)
exprs.extend(
exprs_add.extend(
(s.clone() * (&x0 - &x3) - (&y0 + &y3))
.to_exprs()
.into_iter()
.zip(
alpha_pows[SEPTIC_EXTENSION_DEGREE * 2..]
.iter()
.take(SEPTIC_EXTENSION_DEGREE),
)
.zip_eq(alpha_pows_iter.by_ref().take(SEPTIC_EXTENSION_DEGREE))
.map(|(e, alpha)| e * Expression::Constant(Either::Right(*alpha))),
);

let exprs_add = exprs_add.into_iter().sum::<Expression<E>>() * sel_add_expr;

// deal with bypass
// 0 = (x[1,b] - x[b,0])
exprs_bypass.extend(
(&x3 - &x0)
.to_exprs()
.into_iter()
.zip_eq(alpha_pows_iter.by_ref().take(SEPTIC_EXTENSION_DEGREE))
.map(|(e, alpha)| e * Expression::Constant(Either::Right(*alpha))),
);

// 0 = (y[1,b] - y[b,0])
exprs_bypass.extend(
(&y3 - &y0)
.to_exprs()
.into_iter()
.zip_eq(alpha_pows_iter.by_ref().take(SEPTIC_EXTENSION_DEGREE))
.map(|(e, alpha)| e * Expression::Constant(Either::Right(*alpha))),
);
assert!(alpha_pows_iter.next().is_none());

let exprs_bypass = exprs_bypass.into_iter().sum::<Expression<E>>() * sel_bypass_expr;

let (zerocheck_proof, state) = IOPProverState::prove(
expr_builder
.to_virtual_polys(&[exprs.into_iter().sum::<Expression<E>>() * sel_expr], &[]),
expr_builder.to_virtual_polys(&[exprs_add + exprs_bypass], &[]),
transcript,
);

Expand All @@ -207,10 +247,11 @@ impl CpuEccProver {

assert_eq!(zerocheck_proof.extract_sum(), E::ZERO);
// 7 for x[rt,0], x[rt,1], y[rt,0], y[rt,1], x[1,rt], y[1,rt], s[0,rt]
assert_eq!(evals.len(), 1 + SEPTIC_EXTENSION_DEGREE * 7);
assert_eq!(evals.len(), 2 + SEPTIC_EXTENSION_DEGREE * 7);

#[cfg(feature = "sanity-check")]
{
let last_evaluation_index = (1 << n) - 1;
let s = invs.iter().map(|x| x.as_view_slice(2, 0)).collect_vec();
let x0 = filter_bj(&xs, 0);
let y0 = filter_bj(&ys, 0);
Expand All @@ -219,19 +260,19 @@ impl CpuEccProver {
let x3 = xs.iter().map(|x| x.as_view_slice(2, 1)).collect_vec();
let y3 = ys.iter().map(|y| y.as_view_slice(2, 1)).collect_vec();
let final_sum_x: SepticExtension<E::BaseField> = (x3.iter())
.map(|x| x.get_base_field_vec()[num_instances - 1]) // x[1,...,1,0]
.map(|x| x.get_base_field_vec()[last_evaluation_index - 1]) // x[1,...,1,0]
.collect_vec()
.into();
let final_sum_y: SepticExtension<E::BaseField> = (y3.iter())
.map(|y| y.get_base_field_vec()[num_instances - 1]) // x[1,...,1,0]
.map(|y| y.get_base_field_vec()[last_evaluation_index - 1]) // x[1,...,1,0]
.collect_vec()
.into();
let final_sum = SepticPoint::from_affine(final_sum_x, final_sum_y);

assert_eq!(final_sum, sum);
// check evaluations
assert_eq!(
eq_eval_less_or_equal_than(num_instances - 1, &out_rt, &rt),
eq_eval_less_or_equal_than(last_evaluation_index - 1, &out_rt, &rt),
evals[0]
);
for i in 0..SEPTIC_EXTENSION_DEGREE {
Expand Down Expand Up @@ -263,7 +304,7 @@ impl CpuEccProver {
// TODO: prove the validity of s[0,rt], x[rt,0], x[rt,1], y[rt,0], y[rt,1], x[1,rt], y[1,rt]
EccQuarkProof {
zerocheck_proof,
num_vars: n,
num_instances,
evals,
sum,
}
Expand Down Expand Up @@ -1090,31 +1131,35 @@ where

#[cfg(test)]
mod tests {
use std::iter::repeat;

use crate::scheme::{
constants::SEPTIC_EXTENSION_DEGREE,
cpu::CpuEccProver,
septic_curve::{SepticExtension, SepticPoint},
verifier::EccVerifier,
};
use ff_ext::BabyBearExt4;
use itertools::Itertools;
use multilinear_extensions::{
mle::{IntoMLE, MultilinearExtension},
util::transpose,
};
use p3::babybear::BabyBear;
use std::iter::repeat_n;
use transcript::BasicTranscript;

use crate::scheme::{
constants::SEPTIC_EXTENSION_DEGREE,
cpu::CpuEccProver,
septic_curve::{SepticExtension, SepticPoint},
verifier::EccVerifier,
};
use witness::next_pow2_instance_padding;

#[test]
fn test_ecc_quark_prover() {
for n_points in 1..2 ^ 10 {
test_ecc_quark_prover_inner(n_points)
}
}

fn test_ecc_quark_prover_inner(n_points: usize) {
type E = BabyBearExt4;
type F = BabyBear;

let log2_n = 6;
let n_points = 1 << log2_n;
let log2_n = next_pow2_instance_padding(n_points).ilog2();
let mut rng = rand::thread_rng();

let final_sum;
Expand All @@ -1124,7 +1169,11 @@ mod tests {
let mut points = (0..n_points)
.map(|_| SepticPoint::<F>::random(&mut rng))
.collect_vec();
let mut s = Vec::with_capacity(n_points);
points.extend(repeat_n(
SepticPoint::point_at_infinity(),
(1 << log2_n) - points.len(),
));
let mut s = Vec::with_capacity(1 << (log2_n + 1));

for layer in (1..=log2_n).rev() {
let num_inputs = 1 << layer;
Expand All @@ -1133,17 +1182,19 @@ mod tests {
s.extend(inputs.chunks_exact(2).map(|chunk| {
let p = &chunk[0];
let q = &chunk[1];

(&p.y - &q.y) * (&p.x - &q.x).inverse().unwrap()
if q.is_infinity {
SepticExtension::zero()
} else {
(&p.y - &q.y) * (&p.x - &q.x).inverse().unwrap()
}
}));

points.extend(
points[points.len() - num_inputs..]
inputs
.chunks_exact(2)
.map(|chunk| {
let p = chunk[0].clone();
let q = chunk[1].clone();

p + q
})
.collect_vec(),
Expand All @@ -1152,11 +1203,14 @@ mod tests {
final_sum = points.last().cloned().unwrap();

// padding to 2*N
s.extend(repeat(SepticExtension::zero()).take(n_points + 1));
s.extend(repeat_n(
SepticExtension::zero(),
(1 << (log2_n + 1)) - s.len(),
));
points.push(SepticPoint::point_at_infinity());

assert_eq!(s.len(), 2 * n_points);
assert_eq!(points.len(), 2 * n_points);
assert_eq!(s.len(), 1 << (log2_n + 1));
assert_eq!(points.len(), 1 << (log2_n + 1));

// transform points to row major matrix
let trace = points
Expand All @@ -1183,6 +1237,7 @@ mod tests {
let mut transcript = BasicTranscript::new(b"test");
let prover = CpuEccProver::new();
let quark_proof = prover.create_ecc_proof(
n_points,
xs.to_vec(),
ys.to_vec(),
s.to_vec(),
Expand All @@ -1195,6 +1250,7 @@ mod tests {
assert!(
verifier
.verify_ecc_proof(&quark_proof, &mut transcript)
.inspect_err(|err| println!("err {:?}", err))
.is_ok()
);
}
Expand Down
22 changes: 22 additions & 0 deletions ceno_zkvm/src/scheme/septic_curve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,28 @@ impl<F: Field> MulAssign<Self> for SepticExtension<F> {
#[derive(Clone, Debug)]
pub struct SymbolicSepticExtension<E: ExtensionField>(pub Vec<Expression<E>>);

impl<E: ExtensionField> SymbolicSepticExtension<E> {
pub fn mul_scalar(&self, scalar: Either<E::BaseField, E>) -> Self {
let res = self
.0
.iter()
.map(|a| a.clone() * Expression::Constant(scalar))
.collect();

SymbolicSepticExtension(res)
}

pub fn add_scalar(&self, scalar: Either<E::BaseField, E>) -> Self {
let res = self
.0
.iter()
.map(|a| a.clone() + Expression::Constant(scalar))
.collect();

SymbolicSepticExtension(res)
}
}

impl<E: ExtensionField> Add<Self> for &SymbolicSepticExtension<E> {
type Output = SymbolicSepticExtension<E>;

Expand Down
Loading
Loading