Skip to content

Commit f94004c

Browse files
committed
refactor and use new selector type
1 parent 0371dd6 commit f94004c

File tree

4 files changed

+56
-45
lines changed

4 files changed

+56
-45
lines changed

ceno_zkvm/src/scheme/cpu/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ impl CpuEccProver {
311311
// TODO: prove the validity of s[0,rt], x[rt,0], x[rt,1], y[rt,0], y[rt,1], x[1,rt], y[1,rt]
312312
EccQuarkProof {
313313
zerocheck_proof,
314-
num_vars: n,
314+
num_instances,
315315
evals,
316316
sum,
317317
}

ceno_zkvm/src/scheme/verifier.rs

Lines changed: 49 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ use gkr_iop::{gkr::GKRClaims, selector::SelectorType, utils::eq_eval_less_or_equ
2424
use itertools::{Itertools, chain, interleave, izip};
2525
use mpcs::{Point, PolynomialCommitmentScheme};
2626
use multilinear_extensions::{
27-
Instance, StructuralWitIn, StructuralWitInType,
27+
Expression, Instance, StructuralWitIn, StructuralWitInType,
28+
StructuralWitInType::StackedConstantSequence,
2829
mle::IntoMLE,
2930
util::ceil_log2,
3031
utils::eval_by_expr_with_instance,
@@ -811,7 +812,7 @@ impl TowerVerify {
811812

812813
let max_num_variables = num_variables.iter().max().unwrap();
813814

814-
let (next_rt, _) = (0..(max_num_variables-1)).try_fold(
815+
let (next_rt, _) = (0..(max_num_variables - 1)).try_fold(
815816
(
816817
PointAndEval {
817818
point: initial_rt,
@@ -844,31 +845,31 @@ impl TowerVerify {
844845
// prod'[b] = prod[0,b] * prod[1,b]
845846
// prod'[out_rt] = \sum_b eq(out_rt,b) * prod'[b] = \sum_b eq(out_rt,b) * prod[0,b] * prod[1,b]
846847
eq * *alpha
847-
* if round < *max_round-1 {tower_proofs.prod_specs_eval[spec_index][round].iter().copied().product()} else {
848-
E::ZERO
849-
}
848+
* if round < *max_round - 1 { tower_proofs.prod_specs_eval[spec_index][round].iter().copied().product() } else {
849+
E::ZERO
850+
}
850851
})
851852
.sum::<E>()
852853
+ (0..num_logup_spec)
853-
.zip_eq(alpha_pows[num_prod_spec..].chunks(2))
854-
.zip_eq(num_variables[num_prod_spec..].iter())
855-
.map(|((spec_index, alpha), max_round)| {
856-
// logup_q'[b] = logup_q[0,b] * logup_q[1,b]
857-
// logup_p'[b] = logup_p[0,b] * logup_q[1,b] + logup_p[1,b] * logup_q[0,b]
858-
// logup_p'[out_rt] = \sum_b eq(out_rt,b) * (logup_p[0,b] * logup_q[1,b] + logup_p[1,b] * logup_q[0,b])
859-
// logup_q'[out_rt] = \sum_b eq(out_rt,b) * logup_q[0,b] * logup_q[1,b]
860-
let (alpha_numerator, alpha_denominator) = (&alpha[0], &alpha[1]);
861-
eq * if round < *max_round-1 {
862-
let evals = &tower_proofs.logup_specs_eval[spec_index][round];
863-
let (p1, p2, q1, q2) =
864-
(evals[0], evals[1], evals[2], evals[3]);
865-
*alpha_numerator * (p1 * q2 + p2 * q1)
866-
+ *alpha_denominator * (q1 * q2)
867-
} else {
868-
E::ZERO
869-
}
870-
})
871-
.sum::<E>();
854+
.zip_eq(alpha_pows[num_prod_spec..].chunks(2))
855+
.zip_eq(num_variables[num_prod_spec..].iter())
856+
.map(|((spec_index, alpha), max_round)| {
857+
// logup_q'[b] = logup_q[0,b] * logup_q[1,b]
858+
// logup_p'[b] = logup_p[0,b] * logup_q[1,b] + logup_p[1,b] * logup_q[0,b]
859+
// logup_p'[out_rt] = \sum_b eq(out_rt,b) * (logup_p[0,b] * logup_q[1,b] + logup_p[1,b] * logup_q[0,b])
860+
// logup_q'[out_rt] = \sum_b eq(out_rt,b) * logup_q[0,b] * logup_q[1,b]
861+
let (alpha_numerator, alpha_denominator) = (&alpha[0], &alpha[1]);
862+
eq * if round < *max_round - 1 {
863+
let evals = &tower_proofs.logup_specs_eval[spec_index][round];
864+
let (p1, p2, q1, q2) =
865+
(evals[0], evals[1], evals[2], evals[3]);
866+
*alpha_numerator * (p1 * q2 + p2 * q1)
867+
+ *alpha_denominator * (q1 * q2)
868+
} else {
869+
E::ZERO
870+
}
871+
})
872+
.sum::<E>();
872873

873874
if expected_evaluation != sumcheck_claim.expected_evaluation {
874875
return Err(ZKVMError::VerifyError("mismatch tower evaluation".into()));
@@ -877,7 +878,7 @@ impl TowerVerify {
877878
// derive single eval
878879
// rt' = r_merge || rt
879880
// r_merge.len() == ceil_log2(num_product_fanin)
880-
let r_merge =transcript.sample_and_append_vec(b"merge", log2_num_fanin);
881+
let r_merge = transcript.sample_and_append_vec(b"merge", log2_num_fanin);
881882
let coeffs = build_eq_x_r_vec_sequential(&r_merge);
882883
assert_eq!(coeffs.len(), num_fanin);
883884
let rt_prime = [rt, r_merge].concat();
@@ -893,17 +894,17 @@ impl TowerVerify {
893894
.zip(num_variables.iter())
894895
.map(|((spec_index, alpha), max_round)| {
895896
// prod'[rt,r_merge] = \sum_b eq(r_merge, b) * prod'[b,rt]
896-
if round < max_round -1 {
897+
if round < max_round - 1 {
897898
// merged evaluation
898899
let evals = izip!(
899900
tower_proofs.prod_specs_eval[spec_index][round].iter(),
900901
coeffs.iter()
901902
)
902-
.map(|(a, b)| *a * *b)
903-
.sum::<E>();
903+
.map(|(a, b)| *a * *b)
904+
.sum::<E>();
904905
// this will keep update until round > evaluation
905906
prod_spec_point_n_eval[spec_index] = PointAndEval::new(rt_prime.clone(), evals);
906-
if next_round < max_round -1 {
907+
if next_round < max_round - 1 {
907908
*alpha * evals
908909
} else {
909910
E::ZERO
@@ -917,28 +918,28 @@ impl TowerVerify {
917918
.zip_eq(next_alpha_pows[num_prod_spec..].chunks(2))
918919
.zip_eq(num_variables[num_prod_spec..].iter())
919920
.map(|((spec_index, alpha), max_round)| {
920-
if round < max_round -1 {
921+
if round < max_round - 1 {
921922
let (alpha_numerator, alpha_denominator) = (&alpha[0], &alpha[1]);
922923
// merged evaluation
923924
let p_evals = izip!(
924925
tower_proofs.logup_specs_eval[spec_index][round][0..2].iter(),
925926
coeffs.iter()
926927
)
927-
.map(|(a, b)| *a * *b)
928-
.sum::<E>();
928+
.map(|(a, b)| *a * *b)
929+
.sum::<E>();
929930

930931
let q_evals = izip!(
931932
tower_proofs.logup_specs_eval[spec_index][round][2..4].iter(),
932933
coeffs.iter()
933934
)
934-
.map(|(a, b)| *a * *b)
935-
.sum::<E>();
935+
.map(|(a, b)| *a * *b)
936+
.sum::<E>();
936937

937938
// this will keep update until round > evaluation
938939
logup_spec_p_point_n_eval[spec_index] = PointAndEval::new(rt_prime.clone(), p_evals);
939940
logup_spec_q_point_n_eval[spec_index] = PointAndEval::new(rt_prime.clone(), q_evals);
940941

941-
if next_round < max_round -1 {
942+
if next_round < max_round - 1 {
942943
*alpha_numerator * p_evals + *alpha_denominator * q_evals
943944
} else {
944945
E::ZERO
@@ -980,7 +981,8 @@ impl EccVerifier {
980981
proof: &EccQuarkProof<E>,
981982
transcript: &mut impl Transcript<E>,
982983
) -> Result<(), ZKVMError> {
983-
let out_rt = transcript.sample_and_append_vec(b"ecc", proof.num_vars);
984+
let num_vars = next_pow2_instance_padding(proof.num_instances).ilog2() as usize;
985+
let out_rt = transcript.sample_and_append_vec(b"ecc", num_vars);
984986
let alpha_pows =
985987
transcript.sample_and_append_challenge_pows(SEPTIC_EXTENSION_DEGREE * 3, b"ecc_alpha");
986988

@@ -989,7 +991,7 @@ impl EccVerifier {
989991
&proof.zerocheck_proof,
990992
&VPAuxInfo {
991993
max_degree: 3,
992-
max_num_variables: proof.num_vars,
994+
max_num_variables: num_vars,
993995
phantom: PhantomData,
994996
},
995997
transcript,
@@ -1023,7 +1025,6 @@ impl EccVerifier {
10231025
.try_into()
10241026
.unwrap();
10251027

1026-
let num_instances = (1 << proof.num_vars) - 1;
10271028
let rt = sumcheck_claim
10281029
.point
10291030
.iter()
@@ -1052,15 +1053,20 @@ impl EccVerifier {
10521053
})
10531054
.sum();
10541055

1055-
let sel = eq_eval_less_or_equal_than(num_instances - 1, &out_rt, &rt);
1056-
// let SelectorType::QuarkBinaryTreeLessThan(1.into());
1057-
// let sel = eq_quark_form(num_instances - 1, &out_rt, &rt);
1058-
if sumcheck_claim.expected_evaluation != v * sel {
1056+
let sel_add_expr = SelectorType::<E>::QuarkBinaryTreeLessThan(Expression::StructuralWitIn(
1057+
0,
1058+
// this value doesn't matter, as we only need structural id
1059+
StackedConstantSequence { max_value: 0 },
1060+
));
1061+
let mut sel_evals = vec![E::ZERO];
1062+
sel_add_expr.evaluate(&mut sel_evals, &out_rt, &rt, proof.num_instances, 0);
1063+
let sel_add = sel_evals[0];
1064+
if sumcheck_claim.expected_evaluation != v * sel_add {
10591065
return Err(ZKVMError::VerifyError(
10601066
(format!(
10611067
"ecc zerocheck failed: mismatched evaluation, expected {}, got {}",
10621068
sumcheck_claim.expected_evaluation,
1063-
v * sel
1069+
v * sel_add
10641070
))
10651071
.into(),
10661072
));

ceno_zkvm/src/structs.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ use witness::RowMajorMatrix;
3131
))]
3232
pub struct EccQuarkProof<E: ExtensionField> {
3333
pub zerocheck_proof: IOPProof<E>,
34-
pub num_vars: usize,
34+
pub num_instances: usize,
3535
pub evals: Vec<E>, // x[rt,0], x[rt,1], y[rt,0], y[rt,1], x[0,rt], y[0,rt], s[0,rt]
3636
pub sum: SepticPoint<E::BaseField>,
3737
}

gkr_iop/src/selector.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,11 @@ impl<E: ExtensionField> SelectorType<E> {
176176
// where nodes size is 2^(N) / 2
177177
// out_point.len() is also log(2^(N)) - 1
178178
// so num_instances and 1 << out_point.len() are on same scaling
179+
println!(
180+
"num_instances {} out_point.len() {}",
181+
num_instances,
182+
out_point.len()
183+
);
179184
assert!(num_instances > 0);
180185
assert!(num_instances <= (1 << out_point.len()));
181186
if out_point.is_empty() {

0 commit comments

Comments
 (0)