Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ opt-level = 3
[profile.release]
lto = "thin"

#[patch."ssh://[email protected]/scroll-tech/ceno-gpu.git"]
#ceno_gpu = { path = "../ceno-gpu/cuda_hal", package = "cuda_hal" }
# [patch."ssh://[email protected]/scroll-tech/ceno-gpu.git"]
# ceno_gpu = { path = "../ceno-gpu/cuda_hal", package = "cuda_hal" }

#[patch."https://github.com/scroll-tech/gkr-backend"]
#ff_ext = { path = "../gkr-backend/crates/ff_ext", package = "ff_ext" }
Expand Down
2 changes: 1 addition & 1 deletion build-scripts/conditional-patch.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ WORKSPACE_CARGO="Cargo.toml"

# Workspace dependency declarations
LOCAL_DEP='ceno_gpu = { path = "utils/cuda_hal", package = "cuda_hal" }'
REMOTE_DEP='ceno_gpu = { git = "ssh://[email protected]/scroll-tech/ceno-gpu.git", package = "cuda_hal", branch = "dev/integrate-into-ceno-as-dep" }'
REMOTE_DEP='ceno_gpu = { git = "ssh://[email protected]/scroll-tech/ceno-gpu.git", package = "cuda_hal", branch = "main", default-features = false, features = \["bb31"\] }'

if [ "$1" = "enable-gpu" ]; then
echo "Switching to GPU mode (using remote implementation)..."
Expand Down
10 changes: 5 additions & 5 deletions ceno_zkvm/src/scheme/cpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -543,15 +543,15 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> MainSumcheckProver<C
for CpuProver<CpuBackend<E, PCS>>
{
#[allow(clippy::type_complexity)]
#[tracing::instrument(skip_all, name = "table_witness", fields(profiling_3), level = "trace")]
#[tracing::instrument(skip_all, name = "table_witness", fields(profiling_2), level = "trace")]
fn table_witness<'a>(
&self,
input: &ProofInput<'a, CpuBackend<<CpuBackend<E, PCS> as ProverBackend>::E, PCS>>,
cs: &ConstraintSystem<<CpuBackend<E, PCS> as ProverBackend>::E>,
challenges: &[<CpuBackend<E, PCS> as ProverBackend>::E],
) -> Vec<Arc<<CpuBackend<E, PCS> as ProverBackend>::MultilinearPoly<'a>>> {
// main constraint: lookup denominator and numerator record witness inference
let record_span = entered_span!("record");
let span = entered_span!("witness_infer", profiling_2 = true);
let records: Vec<ArcMultilinearExtension<'_, E>> = cs
.r_table_expressions
.par_iter()
Expand Down Expand Up @@ -581,7 +581,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> MainSumcheckProver<C
)
})
.collect();
exit_span!(record_span);
exit_span!(span);
records
}

Expand Down Expand Up @@ -774,9 +774,9 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> DeviceTransporter<Cp

fn transport_mles<'a>(
&self,
mles: Vec<MultilinearExtension<'a, E>>,
mles: &[MultilinearExtension<'a, E>],
) -> Vec<ArcMultilinearExtension<'a, E>> {
mles.into_iter().map(|mle| mle.into()).collect_vec()
mles.iter().map(|mle| mle.clone().into()).collect_vec()
}
}

Expand Down
115 changes: 68 additions & 47 deletions ceno_zkvm/src/scheme/gpu/mod.rs

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion ceno_zkvm/src/scheme/hal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ pub trait DeviceTransporter<PB: ProverBackend> {

fn transport_mles<'a>(
&self,
mles: Vec<MultilinearExtension<'a, PB::E>>,
mles: &[MultilinearExtension<'a, PB::E>],
) -> Vec<Arc<PB::MultilinearPoly<'a>>>;
}

Expand Down
29 changes: 21 additions & 8 deletions ceno_zkvm/src/scheme/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,10 @@ impl<
];
tracing::debug!("global challenges in prover: {:?}", challenges);

let public_input_span = entered_span!("public_input", profiling_1 = true);
let public_input = self.device.transport_mles(&pi);
exit_span!(public_input_span);

let main_proofs_span = entered_span!("main_proofs", profiling_1 = true);
let (points, evaluations) = self.pk.circuit_pks.iter().enumerate().try_fold(
(vec![], vec![]),
Expand All @@ -216,24 +220,29 @@ impl<
return Ok::<(Vec<_>, Vec<Vec<_>>), ZKVMError>((points, evaluations));
}
transcript.append_field_element(&E::BaseField::from_canonical_u64(index as u64));

// TODO: add an enum for circuit type either in constraint_system or vk
let witness_mle = witness_mles
.drain(..cs.num_witin())
.map(|mle| mle.into())
.collect_vec();
let structural_witness = self.device.transport_mles(
structural_wits
.remove(circuit_name)
.map(|(sw, _)| sw)
.unwrap_or(vec![]),
);

let structural_witness_span =
entered_span!("structural_witness", profiling_2 = true);
let structural_mles = structural_wits
.remove(circuit_name)
.map(|(sw, _)| sw)
.unwrap_or(vec![]);
let structural_witness = self.device.transport_mles(&structural_mles);
exit_span!(structural_witness_span);

let fixed = fixed_mles.drain(..cs.num_fixed()).collect_vec();
let public_input = self.device.transport_mles(pi.clone());

let mut input = ProofInput {
witness: witness_mle,
fixed,
structural_witness,
public_input,
public_input: public_input.clone(),
num_instances,
};

Expand Down Expand Up @@ -327,6 +336,8 @@ impl<
let log2_num_instances = input.log2_num_instances();
let num_var_with_rotation = log2_num_instances + cs.rotation_vars().unwrap_or(0);

// println!("create_chip_proof: {}", name);

// build main witness
let (records, is_padded) =
build_main_witness::<E, PCS, PB, PD>(&self.device, cs, &input, challenges);
Expand All @@ -346,13 +357,15 @@ impl<

// 1. prove the main constraints among witness polynomials
// 2. prove the relation between last layer in the tower and read/write/logup records
let span = entered_span!("prove_main_constraints", profiling_2 = true);
let (input_opening_point, evals, main_sumcheck_proofs, gkr_iop_proof) = self
.device
.prove_main_constraints(rt_tower, &input, cs, challenges, transcript)?;
let MainSumcheckEvals {
wits_in_evals,
fixed_in_evals,
} = evals;
exit_span!(span);

// evaluate pi if there is instance query
let mut pi_in_evals: HashMap<usize, E> = HashMap::new();
Expand Down
9 changes: 6 additions & 3 deletions ceno_zkvm/src/scheme/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ use itertools::Itertools;
use mpcs::PolynomialCommitmentScheme;
pub use multilinear_extensions::wit_infer_by_expr;
use multilinear_extensions::{
macros::{entered_span, exit_span},
mle::{ArcMultilinearExtension, FieldType, IntoMLE, MultilinearExtension},
util::ceil_log2,
};
Expand Down Expand Up @@ -297,6 +296,12 @@ pub(crate) fn infer_tower_product_witness<E: ExtensionField>(
wit_layers
}

#[tracing::instrument(
skip_all,
name = "build_main_witness",
fields(profiling_2),
level = "trace"
)]
pub fn build_main_witness<
'a,
E: ExtensionField,
Expand Down Expand Up @@ -439,7 +444,6 @@ pub fn gkr_witness<
// generate all layer witness from input to output
for (i, layer) in circuit.layers.iter().rev().enumerate() {
tracing::debug!("generating input {i} layer with layer name {}", layer.name);
let span = entered_span!("per_layer_gen_witness", profiling_2 = true);
// process in_evals to prepare layer witness
// This should assume the input of the first layer is the phase1 witness of the circuit.
let current_layer_wits = layer
Expand Down Expand Up @@ -486,7 +490,6 @@ pub fn gkr_witness<
}
other => unimplemented!("{:?}", other),
});
exit_span!(span);
}
layer_wits.reverse();

Expand Down
16 changes: 12 additions & 4 deletions gkr_iop/src/cpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use ff_ext::ExtensionField;
use itertools::izip;
use mpcs::{PolynomialCommitmentScheme, SecurityLevel, SecurityLevel::Conjecture100bits};
use multilinear_extensions::{
macros::{entered_span, exit_span},
mle::{ArcMultilinearExtension, MultilinearExtension, Point},
wit_infer_by_monomial_expr,
};
Expand Down Expand Up @@ -111,12 +112,13 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>>
pub_io_evals: &[Arc<<CpuBackend<E, PCS> as ProverBackend>::MultilinearPoly<'a>>],
challenges: &[E],
) -> Vec<Arc<<CpuBackend<E, PCS> as ProverBackend>::MultilinearPoly<'a>>> {
let span = entered_span!("witness_infer", profiling_2 = true);
let out_evals: Vec<_> = layer
.out_sel_and_eval_exprs
.iter()
.flat_map(|(sel_type, out_eval)| izip!(iter::repeat(sel_type), out_eval.iter()))
.collect();
layer
let res = layer
.exprs_with_selector_out_eval_monomial_form
.par_iter()
.zip_eq(layer.expr_names.par_iter())
Expand All @@ -141,10 +143,13 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>>
EvalExpression::Partition(_, _) => unimplemented!(),
}
})
.collect::<Vec<_>>()
.collect::<Vec<_>>();
exit_span!(span);
res
}
}

#[tracing::instrument(skip_all, name = "layer_witness", fields(profiling_2), level = "trace")]
pub fn layer_witness<'a, E>(
layer: &Layer<E>,
layer_wits: &[ArcMultilinearExtension<'a, E>],
Expand All @@ -154,12 +159,13 @@ pub fn layer_witness<'a, E>(
where
E: ExtensionField,
{
let span = entered_span!("witness_infer", profiling_2 = true);
let out_evals: Vec<_> = layer
.out_sel_and_eval_exprs
.iter()
.flat_map(|(sel_type, out_eval)| izip!(iter::repeat(sel_type), out_eval.iter()))
.collect();
layer
let res = layer
.exprs_with_selector_out_eval_monomial_form
.par_iter()
.zip_eq(layer.expr_names.par_iter())
Expand All @@ -184,5 +190,7 @@ where
EvalExpression::Partition(_, _) => unimplemented!(),
}
})
.collect::<Vec<_>>()
.collect::<Vec<_>>();
exit_span!(span);
res
}
51 changes: 28 additions & 23 deletions gkr_iop/src/gkr/layer/gpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,18 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> LinearLayerProver<Gp
out_point: &multilinear_extensions::mle::Point<E>,
transcript: &mut impl transcript::Transcript<E>,
) -> crate::gkr::layer::sumcheck_layer::LayerProof<E> {
let span = entered_span!("LinearLayerProver", profiling_2 = true);
let cpu_wits: Vec<Arc<MultilinearExtension<'_, E>>> = wit
.0
.into_iter()
.map(|gpu_mle| Arc::new(gpu_mle.inner_to_mle()))
.collect();
let cpu_wit = LayerWitness::<CpuBackend<E, PCS>>(cpu_wits);
<CpuProver<CpuBackend<E, PCS>> as LinearLayerProver<CpuBackend<E, PCS>>>::prove(
let res = <CpuProver<CpuBackend<E, PCS>> as LinearLayerProver<CpuBackend<E, PCS>>>::prove(
layer, cpu_wit, out_point, transcript,
)
);
exit_span!(span);
res
}
}

Expand All @@ -77,20 +80,23 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> SumcheckLayerProver<
challenges: &[<GpuBackend<E, PCS> as ProverBackend>::E],
transcript: &mut impl Transcript<<GpuBackend<E, PCS> as ProverBackend>::E>,
) -> LayerProof<<GpuBackend<E, PCS> as ProverBackend>::E> {
let span = entered_span!("SumcheckLayerProver", profiling_2 = true);
let cpu_wits: Vec<Arc<MultilinearExtension<'_, E>>> = wit
.0
.into_iter()
.map(|gpu_mle| Arc::new(gpu_mle.inner_to_mle()))
.collect();
let cpu_wit = LayerWitness::<CpuBackend<E, PCS>>(cpu_wits);
<CpuProver<CpuBackend<E, PCS>> as SumcheckLayerProver<CpuBackend<E, PCS>>>::prove(
let res = <CpuProver<CpuBackend<E, PCS>> as SumcheckLayerProver<CpuBackend<E, PCS>>>::prove(
layer,
num_threads,
max_num_variables,
cpu_wit,
challenges,
transcript,
)
);
exit_span!(span);
res
}
}

Expand All @@ -111,6 +117,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZerocheckLayerProver
LayerProof<<GpuBackend<E, PCS> as ProverBackend>::E>,
Point<<GpuBackend<E, PCS> as ProverBackend>::E>,
) {
let span = entered_span!("ZerocheckLayerProver", profiling_2 = true);
let num_threads = 1; // VP builder for GPU: do not use _num_threads

assert_eq!(challenges.len(), 2);
Expand Down Expand Up @@ -163,7 +170,6 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZerocheckLayerProver
)
.collect_vec();

let span = entered_span!("IOPProverState::prove", profiling_4 = true);
let cuda_hal = get_cuda_hal().unwrap();
let eqs_gpu = layer
.out_sel_and_eval_exprs
Expand Down Expand Up @@ -222,11 +228,11 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZerocheckLayerProver
.unwrap_or(0);

// Convert types for GPU function Call
let basic_tr: &mut BasicTranscript<GL64Ext> =
unsafe { &mut *(transcript as *mut _ as *mut BasicTranscript<GL64Ext>) };
let term_coefficients_gl64: Vec<GL64Ext> =
let basic_tr: &mut BasicTranscript<BB31Ext> =
unsafe { &mut *(transcript as *mut _ as *mut BasicTranscript<BB31Ext>) };
let term_coefficients_gl64: Vec<BB31Ext> =
unsafe { std::mem::transmute(term_coefficients) };
let all_witins_gpu_gl64: Vec<&MultilinearExtensionGpu<GL64Ext>> =
let all_witins_gpu_gl64: Vec<&MultilinearExtensionGpu<BB31Ext>> =
unsafe { std::mem::transmute(all_witins_gpu) };
let all_witins_gpu_type_gl64 = all_witins_gpu_gl64.iter().map(|mle| &mle.mle).collect_vec();
let (proof_gpu, evals_gpu, challenges_gpu) = cuda_hal
Expand All @@ -247,13 +253,12 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZerocheckLayerProver

// convert back to E: ExtensionField
let proof_gpu_e =
unsafe { std::mem::transmute::<IOPProof<GL64Ext>, IOPProof<E>>(proof_gpu) };
let evals_gpu_e = unsafe { std::mem::transmute::<Vec<GL64Ext>, Vec<E>>(evals_gpu) };
unsafe { std::mem::transmute::<IOPProof<BB31Ext>, IOPProof<E>>(proof_gpu) };
let evals_gpu_e = unsafe { std::mem::transmute::<Vec<BB31Ext>, Vec<E>>(evals_gpu) };
let row_challenges_e =
unsafe { std::mem::transmute::<Vec<GL64Ext>, Vec<E>>(row_challenges) };
unsafe { std::mem::transmute::<Vec<BB31Ext>, Vec<E>>(row_challenges) };

exit_span!(span);

(
LayerProof {
main: SumcheckLayerProof {
Expand Down Expand Up @@ -292,7 +297,7 @@ pub(crate) fn prove_rotation_gpu<E: ExtensionField, PCS: PolynomialCommitmentSch

// rotated_mles is non-deterministic input, rotated from existing witness polynomial
// we will reduce it to zero check, and finally reduce to committed polynomial opening
let span = entered_span!("rotate_witin_selector", profiling_4 = true);
let span = entered_span!("rotate_witin_selector", profiling_3 = true);
let rotated_mles_gpu = build_rotation_mles_gpu(
&cuda_hal,
raw_rotation_exprs,
Expand All @@ -315,7 +320,7 @@ pub(crate) fn prove_rotation_gpu<E: ExtensionField, PCS: PolynomialCommitmentSch
.collect_vec();
exit_span!(span);

let span = entered_span!("rotation IOPProverState::prove", profiling_4 = true);
let span = entered_span!("rotation IOPProverState::prove", profiling_3 = true);
// gpu mles
let mle_gpu_ref: Vec<&MultilinearExtensionGpu<E>> = rotated_mles_gpu
.iter()
Expand Down Expand Up @@ -344,10 +349,10 @@ pub(crate) fn prove_rotation_gpu<E: ExtensionField, PCS: PolynomialCommitmentSch
.unwrap_or(0);

// Convert types for GPU function call
let basic_tr: &mut BasicTranscript<GL64Ext> =
unsafe { &mut *(transcript as *mut _ as *mut BasicTranscript<GL64Ext>) };
let term_coefficients_gl64: Vec<GL64Ext> = unsafe { std::mem::transmute(term_coefficients) };
let all_witins_gpu_gl64: Vec<&MultilinearExtensionGpu<GL64Ext>> =
let basic_tr: &mut BasicTranscript<BB31Ext> =
unsafe { &mut *(transcript as *mut _ as *mut BasicTranscript<BB31Ext>) };
let term_coefficients_gl64: Vec<BB31Ext> = unsafe { std::mem::transmute(term_coefficients) };
let all_witins_gpu_gl64: Vec<&MultilinearExtensionGpu<BB31Ext>> =
unsafe { std::mem::transmute(mle_gpu_ref) };
let all_witins_gpu_type_gl64 = all_witins_gpu_gl64.iter().map(|mle| &mle.mle).collect_vec();
// gpu prover
Expand All @@ -367,14 +372,14 @@ pub(crate) fn prove_rotation_gpu<E: ExtensionField, PCS: PolynomialCommitmentSch
let evals_gpu = evals_gpu.into_iter().flatten().collect_vec();
let row_challenges = challenges_gpu.iter().map(|c| c.elements).collect_vec();

let proof_gpu_e = unsafe { std::mem::transmute::<IOPProof<GL64Ext>, IOPProof<E>>(proof_gpu) };
let mut evals_gpu_e = unsafe { std::mem::transmute::<Vec<GL64Ext>, Vec<E>>(evals_gpu) };
let row_challenges_e = unsafe { std::mem::transmute::<Vec<GL64Ext>, Vec<E>>(row_challenges) };
let proof_gpu_e = unsafe { std::mem::transmute::<IOPProof<BB31Ext>, IOPProof<E>>(proof_gpu) };
let mut evals_gpu_e = unsafe { std::mem::transmute::<Vec<BB31Ext>, Vec<E>>(evals_gpu) };
let row_challenges_e = unsafe { std::mem::transmute::<Vec<BB31Ext>, Vec<E>>(row_challenges) };
// skip selector/eq as verifier can derive itself
evals_gpu_e.truncate(raw_rotation_exprs.len() * 2);
exit_span!(span);

let span = entered_span!("rotation derived left/right eval", profiling_4 = true);
let span = entered_span!("rotation derived left/right eval", profiling_3 = true);
let bh = BooleanHypercube::new(rotation_cyclic_group_log2);
let (left_point, right_point) = bh.get_rotation_points(&row_challenges_e);
let evals = evals_gpu_e
Expand Down
Loading
Loading