diff --git a/bento/Cargo.toml b/bento/Cargo.toml
index f5534b912..46c196a8d 100644
--- a/bento/Cargo.toml
+++ b/bento/Cargo.toml
@@ -18,7 +18,7 @@ repository = "https://github.com/risc0/bento/"
 
 [workspace.dependencies]
 anyhow = "1.0.98"
-aws-sdk-s3 = "1.34"  # used for minio for max compatibility
+aws-sdk-s3 = "1.34" # used for minio for max compatibility
 bento-client = { path = "crates/bento-client" }
 bincode = "1.3"
 bonsai-sdk = { version = "1.4.0", features = ["non_blocking"] }
@@ -26,9 +26,15 @@ bytemuck = "1.16"
 clap = { version = "4.5", features = ["derive", "env"] }
 deadpool-redis = "0.15"
 hex = { version = "0.4", default-features = false, features = ["alloc"] }
-redis = { version = "0.25", features = ["tokio-comp"] }
+redis = { version = "0.25", features = [
+    "tokio-comp",
+    "cluster",
+    "connection-manager",
+] }
 risc0-build = { version = "2.3.1" }
-risc0-zkvm = { version = "2.3.1", features = ["unstable"], default-features = false }
+risc0-zkvm = { version = "2.3.1", features = [
+    "unstable",
+], default-features = false }
 sample-guest-common = { path = "crates/sample-guest/common" }
 sample-guest-methods = { path = "crates/sample-guest/methods" }
 serde = { version = "1.0", features = ["derive"] }
diff --git a/bento/crates/workflow/Cargo.toml b/bento/crates/workflow/Cargo.toml
index ea7a784c3..b5b97bdd7 100644
--- a/bento/crates/workflow/Cargo.toml
+++ b/bento/crates/workflow/Cargo.toml
@@ -18,8 +18,15 @@ clap = { workspace = true, features = ["env", "derive"] }
 deadpool-redis = { workspace = true }
 hex = { workspace = true }
 nix = { version = "0.29", features = ["fs"] }
-redis = { workspace = true, features = ["tokio-rustls-comp", "tokio-comp"] }
-risc0-zkvm = { workspace = true, default-features = false, features = ["prove"] }
+redis = { workspace = true, features = [
+    "tokio-rustls-comp",
+    "tokio-comp",
+    "cluster",
+    "connection-manager",
+] }
+risc0-zkvm = { workspace = true, default-features = false, features = [
+    "prove",
+] }
 serde = { workspace = true }
 serde_json = { workspace = true }
 signal-hook = "0.3"
diff --git a/bento/crates/workflow/src/redis.rs b/bento/crates/workflow/src/redis.rs
index 6f831ee8e..8d36448bf 100644
--- a/bento/crates/workflow/src/redis.rs
+++ b/bento/crates/workflow/src/redis.rs
@@ -30,6 +30,35 @@ where
     }
 }
 
+/// Batch set multiple keys with expiry using Redis pipelining
+pub async fn batch_set_keys_with_expiry<T>(
+    conn: &mut deadpool_redis::Connection,
+    key_value_pairs: Vec<(String, T)>,
+    ttl: Option<u64>,
+) -> RedisResult<()>
+where
+    T: ToRedisArgs + Send + Sync + 'static,
+{
+    if key_value_pairs.is_empty() {
+        return Ok(());
+    }
+
+    // Use Redis pipelining for bulk operations
+    let mut pipe = redis::pipe();
+
+    for (key, value) in key_value_pairs {
+        if let Some(expiry) = ttl {
+            pipe.set_ex(key, value, expiry);
+        } else {
+            pipe.set(key, value);
+        }
+    }
+
+    // Execute the pipeline
+    pipe.query_async::<_, ()>(conn).await?;
+    Ok(())
+}
+
 /// Scan and delete all keys at a given prefix
 pub async fn scan_and_delete(conn: &mut Connection, prefix: &str) -> RedisResult<()> {
     // Initialize the cursor for SCAN
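
Note: a sketch of the call shape the new helper enables (the names segments_prefix, serialize_obj, and agent come from later in this series; the closure body here is illustrative, not part of the diff):

    // One pipelined round trip for N keys instead of N awaited SETEX calls.
    let pairs: Vec<(String, Vec<u8>)> = segments
        .iter()
        .map(|s| {
            let key = format!("{segments_prefix}:{}", s.index);
            let value = serialize_obj(s).expect("serialize segment");
            (key, value)
        })
        .collect();
    redis::batch_set_keys_with_expiry(&mut conn, pairs, Some(agent.args.redis_ttl)).await?;

One caveat: redis::pipe() sends all commands over one connection in a single round trip but is not transactional; if all-or-nothing semantics are ever needed, chaining .atomic() on the pipeline wraps it in MULTI/EXEC.
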
diff --git a/bento/crates/workflow/src/tasks/executor.rs b/bento/crates/workflow/src/tasks/executor.rs
index a1855a27a..a28cbb935 100644
--- a/bento/crates/workflow/src/tasks/executor.rs
+++ b/bento/crates/workflow/src/tasks/executor.rs
@@ -30,7 +30,7 @@ use workflow_common::{
     SnarkReq, UnionReq, AUX_WORK_TYPE, COPROC_WORK_TYPE, JOIN_WORK_TYPE, PROVE_WORK_TYPE,
 };
 // use tempfile::NamedTempFile;
-use tokio::task::{JoinHandle, JoinSet};
+use tokio::task::JoinHandle;
 use uuid::Uuid;
 const V2_ELF_MAGIC: &[u8] = b"R0BF";
 // const V1_ELF_MAGIC: [u8; 4] = [0x7f, 0x45, 0x4c, 0x46];
@@ -247,18 +247,18 @@ struct SessionData {
 }
 
 struct Coprocessor {
-    tx: tokio::sync::mpsc::Sender<SenderType>,
+    tx: tokio::sync::mpsc::Sender<ProveKeccakRequest>,
 }
 
 impl Coprocessor {
-    fn new(tx: tokio::sync::mpsc::Sender<SenderType>) -> Self {
+    fn new(tx: tokio::sync::mpsc::Sender<ProveKeccakRequest>) -> Self {
         Self { tx }
     }
 }
 
 impl CoprocessorCallback for Coprocessor {
     fn prove_keccak(&mut self, request: ProveKeccakRequest) -> Result<()> {
-        self.tx.blocking_send(SenderType::Keccak(request))?;
+        self.tx.blocking_send(request)?;
         Ok(())
     }
 
@@ -268,16 +268,10 @@ impl CoprocessorCallback for Coprocessor {
     }
 }
 
-enum SenderType {
-    Segment(u32),
-    Keccak(ProveKeccakRequest),
-    Fault,
-}
-
 /// Run the executor emitting the segments and session to hot storage
 ///
-/// Writes out all segments async using tokio tasks then waits for all
-/// tasks to complete before exiting.
+/// Collects all segments in memory, then writes them to Redis in batch,
+/// and finally updates the database with all tasks.
 pub async fn executor(agent: &Agent, job_id: &Uuid, request: &ExecutorReq) -> Result<ExecutorResp> {
     let mut conn = agent.redis_pool.get().await?;
     let job_prefix = format!("job:{job_id}");
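
A note on the callback change: prove_keccak runs on the zkVM executor's OS thread (the run happens inside a blocking task, per the hunks below), so tokio's blocking_send is the right primitive — it parks the thread when the channel is full and panics if called from an async context. A minimal sketch of the pattern, with Request standing in for ProveKeccakRequest:

    use tokio::sync::mpsc;

    struct Request; // stand-in for ProveKeccakRequest

    // Called from a plain OS thread, never from a runtime worker.
    fn sync_producer(tx: &mpsc::Sender<Request>) {
        tx.blocking_send(Request).expect("receiver dropped");
    }

    // Runs on the runtime; the loop ends once every sender is dropped.
    async fn collect(mut rx: mpsc::Receiver<Request>) -> Vec<Request> {
        let mut out = Vec::new();
        while let Some(req) = rx.recv().await {
            out.push(req);
        }
        out
    }
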
@@ -367,41 +361,9 @@ pub async fn executor(agent: &Agent, job_id: &Uuid, request: &ExecutorReq) -> Re
     // set the segment prefix
     let segments_prefix = format!("{job_prefix}:{SEGMENTS_PATH}");
 
-    // queue segments into a spmc queue
+    // Collect segments and keccak requests in memory
     let (segment_tx, mut segment_rx) = tokio::sync::mpsc::channel::<Segment>(CONCURRENT_SEGMENTS);
-    let (task_tx, mut task_rx) = tokio::sync::mpsc::channel::<SenderType>(TASK_QUEUE_SIZE);
-    let task_tx_clone = task_tx.clone();
-
-    let mut writer_conn = agent.redis_pool.get().await?;
-    let segments_prefix_clone = segments_prefix.clone();
-    let redis_ttl = agent.args.redis_ttl;
-
-    let mut writer_tasks = JoinSet::new();
-    writer_tasks.spawn(async move {
-        while let Some(segment) = segment_rx.recv().await {
-            let index = segment.index;
-            tracing::debug!("Starting write of index: {index}");
-            let segment_key = format!("{segments_prefix_clone}:{index}");
-            let segment_vec = serialize_obj(&segment).expect("Failed to serialize the segment");
-            redis::set_key_with_expiry(
-                &mut writer_conn,
-                &segment_key,
-                segment_vec,
-                Some(redis_ttl),
-            )
-            .await
-            .expect("Failed to set key with expiry");
-            tracing::debug!("Completed write of {index}");
-
-            task_tx
-                .send(SenderType::Segment(index))
-                .await
-                .expect("failed to push task into task_tx");
-        }
-        // Once the segments wrap up, close the task channel to signal completion to the
-        // follow-up task
-        drop(task_tx);
-    });
+    let (keccak_tx, mut keccak_rx) = tokio::sync::mpsc::channel::<ProveKeccakRequest>(100);
 
     let aux_stream = taskdb::get_stream(&agent.db_pool, &request.user_id, AUX_WORK_TYPE)
         .await
@@ -448,115 +410,29 @@ pub async fn executor(agent: &Agent, job_id: &Uuid, request: &ExecutorReq) -> Re
     let compress_type = request.compress;
     let exec_only = request.execute_only;
 
-    // Write keccak data to redis + schedule proving
-    let coproc = Coprocessor::new(task_tx_clone.clone());
-    let mut coproc_redis = agent.redis_pool.get().await?;
-    let coproc_prefix = format!("{job_prefix}:{COPROC_CB_PATH}");
+    // Collect segments and keccak requests in memory
+    let mut segments = Vec::new();
+    let mut keccak_requests = Vec::new();
 
     let mut guest_fault = false;
 
-    // Generate tasks
-    writer_tasks.spawn(async move {
-        let mut planner = Planner::default();
-        while let Some(task_type) = task_rx.recv().await {
-            if exec_only {
-                continue;
-            }
-
-            match task_type {
-                SenderType::Segment(segment_index) => {
-                    planner.enqueue_segment().expect("Failed to enqueue segment");
-                    while let Some(tree_task) = planner.next_task() {
-                        process_task(
-                            &args_copy,
-                            &pool_copy,
-                            &prove_stream,
-                            &join_stream,
-                            &union_stream,
-                            &aux_stream,
-                            &job_id_copy,
-                            tree_task,
-                            Some(segment_index),
-                            &assumptions,
-                            compress_type,
-                            None,
-                        )
-                        .await
-                        .expect("Failed to process task and insert into taskdb");
-                    }
-                }
-                SenderType::Keccak(mut keccak_req) => {
-                    let redis_key = format!("{coproc_prefix}:{}", keccak_req.claim_digest);
-                    redis::set_key_with_expiry(
-                        &mut coproc_redis,
-                        &redis_key,
-                        bytemuck::cast_slice::<_, u8>(&keccak_req.input).to_vec(),
-                        Some(redis_ttl),
-                    )
-                    .await
-                    .expect("Failed to set key with expiry");
-                    keccak_req.input.clear();
-                    tracing::debug!("Wrote keccak input to redis");
-
-                    planner.enqueue_keccak().expect("Failed to enqueue keccak");
-                    while let Some(tree_task) = planner.next_task() {
-                        let req = KeccakReq {
-                            claim_digest: keccak_req.claim_digest,
-                            control_root: keccak_req.control_root,
-                            po2: keccak_req.po2,
-                        };
-
-                        process_task(
-                            &args_copy,
-                            &pool_copy,
-                            &coproc_stream,
-                            &join_stream,
-                            &union_stream,
-                            &aux_stream,
-                            &job_id_copy,
-                            tree_task,
-                            None,
-                            &assumptions,
-                            compress_type,
-                            Some(req),
-                        )
-                        .await
-                        .expect("Failed to process task and insert into taskdb");
-                    }
-                }
-                SenderType::Fault => {
-                    guest_fault = true;
-                    break;
-                }
-            }
-        }
-        if !exec_only && !guest_fault {
-            planner.finish().expect("Planner failed to finish()");
-            while let Some(tree_task) = planner.next_task() {
-                process_task(
-                    &args_copy,
-                    &pool_copy,
-                    &prove_stream,
-                    &join_stream,
-                    &union_stream,
-                    &aux_stream,
-                    &job_id_copy,
-                    tree_task,
-                    None,
-                    &assumptions,
-                    compress_type,
-                    None,
-                )
-                .await
-                .expect("Failed to process task and insert into taskdb");
-            }
-        }
-    });
+    // Spawn a task to drain both channels into memory
+    let collect_task = tokio::spawn(async move {
+        while let Some(segment) = segment_rx.recv().await {
+            segments.push(segment);
+        }
+        while let Some(keccak_req) = keccak_rx.recv().await {
+            keccak_requests.push(keccak_req);
+        }
+
+        (segments, keccak_requests)
+    });
+
+    // The coprocessor callback feeds keccak requests into keccak_tx,
+    // which collect_task drains above
+    let coproc = Coprocessor::new(keccak_tx.clone());
 
     tracing::info!("Starting execution of job: {}", job_id);
 
-    // let file_stderr = NamedTempFile::new()?;
     let log_file = Arc::new(NamedTempFile::new()?);
     let log_file_copy = log_file.clone();
     let guest_log_path = log_file.path().to_path_buf();
@@ -571,7 +447,6 @@ pub async fn executor(agent: &Agent, job_id: &Uuid, request: &ExecutorReq) -> Re
 
         let env = env
             .stdout(log_file_copy.as_file())
-            // .stderr(file_stderr)
            .write_slice(&input_data)
             .session_limit(Some(exec_limit))
             .coprocessor_callback(coproc)
@@ -583,7 +458,7 @@ pub async fn executor(agent: &Agent, job_id: &Uuid, request: &ExecutorReq) -> Re
         let mut segments = 0;
         let res = match exec.run_with_callback(|segment| {
             segments += 1;
-            // Send segments to write queue, blocking if the queue is full.
+            // Send segments to the collect queue, blocking if the queue is full.
             if !exec_only {
                 segment_tx.blocking_send(segment).unwrap();
             }
@@ -597,15 +472,13 @@ pub async fn executor(agent: &Agent, job_id: &Uuid, request: &ExecutorReq) -> Re
             }),
             Err(err) => {
                 tracing::error!("Failed to run executor");
-                task_tx_clone
-                    .blocking_send(SenderType::Fault)
-                    .context("Failed to send fault to planner")?;
                 Err(err)
             }
         };
 
         // close the segment queue to trigger the workers to wrap up and exit
         drop(segment_tx);
+        drop(keccak_tx);
 
         res
     });
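
One hazard worth flagging in collect_task: it drains segment_rx to completion before touching keccak_rx, and the keccak channel is bounded at 100. If the guest emits more than 100 keccak requests while segments are still streaming, the executor thread blocks in blocking_send on the keccak channel, stops producing segments, and the two sides deadlock. A deadlock-free shape drains both queues in one loop (a sketch, not the diff's code):

    use tokio::sync::mpsc::Receiver;

    async fn drain_both<S, K>(mut seg_rx: Receiver<S>, mut kec_rx: Receiver<K>) -> (Vec<S>, Vec<K>) {
        let (mut segs, mut kecs) = (Vec::new(), Vec::new());
        loop {
            tokio::select! {
                Some(s) = seg_rx.recv() => segs.push(s),
                Some(k) = kec_rx.recv() => kecs.push(k),
                else => break, // both channels closed and empty
            }
        }
        (segs, kecs)
    }
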
@@ -658,20 +531,166 @@ pub async fn executor(agent: &Agent, job_id: &Uuid, request: &ExecutorReq) -> Re
         tracing::warn!("No journal to update.");
     }
 
-    // First join all tasks and collect results
-    while let Some(res) = writer_tasks.join_next().await {
-        match res {
-            Ok(()) => {
-                if guest_fault {
-                    bail!("Ran into fault");
+    // Wait for collection to complete
+    let (segments, keccak_requests) = collect_task
+        .await
+        .context("Failed to collect segments and keccak requests")?;
+
+    tracing::info!("Collected {} segments and {} keccak requests", segments.len(), keccak_requests.len());
+
+    // Batch write all segments to Redis
+    if !exec_only && !segments.is_empty() {
+        tracing::info!("Batch writing {} segments to Redis", segments.len());
+
+        // Prepare batch data
+        let mut segment_batch = Vec::new();
+        for segment in &segments {
+            let segment_key = format!("{segments_prefix}:{}", segment.index);
+            let segment_vec = serialize_obj(segment).expect("Failed to serialize the segment");
+            segment_batch.push((segment_key, segment_vec));
+        }
+
+        // Execute true batch write
+        redis::batch_set_keys_with_expiry(
+            &mut conn,
+            segment_batch,
+            Some(agent.args.redis_ttl),
+        )
+        .await
+        .context("Failed to batch write segments to Redis")?;
+
+        tracing::info!("Successfully batch wrote {} segments to Redis", segments.len());
+    }
+
+    // Batch write keccak requests to Redis
+    if !exec_only && !keccak_requests.is_empty() {
+        tracing::info!("Batch writing {} keccak requests to Redis", keccak_requests.len());
+        let coproc_prefix = format!("{job_prefix}:{COPROC_CB_PATH}");
+
+        // Prepare batch data
+        let mut keccak_batch = Vec::new();
+        for keccak_req in &keccak_requests {
+            let redis_key = format!("{coproc_prefix}:{}", keccak_req.claim_digest);
+            let input_data = bytemuck::cast_slice::<_, u8>(&keccak_req.input).to_vec();
+            keccak_batch.push((redis_key, input_data));
+        }
+
+        // Execute true batch write
+        redis::batch_set_keys_with_expiry(
+            &mut conn,
+            keccak_batch,
+            Some(agent.args.redis_ttl),
+        )
+        .await
+        .context("Failed to batch write keccak requests to Redis")?;
+
+        tracing::info!("Successfully batch wrote {} keccak requests to Redis", keccak_requests.len());
+    }
+
+    // Now create all tasks in the database
+    if !exec_only {
+        tracing::info!("Creating tasks in database");
+
+        // Create a single planner for all operations
+        let mut planner = Planner::default();
+
+        // Add all segments and keccaks to the planner
+        for _ in &segments {
+            planner.enqueue_segment().expect("Failed to enqueue segment");
+        }
+        for _ in &keccak_requests {
+            planner.enqueue_keccak().expect("Failed to enqueue keccak");
+        }
+
+        // Finish planning to get all tasks
+        planner.finish().expect("Planner failed to finish()");
+
+        // Collect all planned tasks
+        let mut all_tasks = Vec::new();
+        while let Some(tree_task) = planner.next_task() {
+            all_tasks.push(tree_task.clone());
+        }
+
+        tracing::info!("Planned {} total tasks, creating in database", all_tasks.len());
+
+        // Create all tasks in parallel batches
+        let mut task_futures = Vec::new();
+        let batch_size = 50; // Process in batches of 50
+
+        for chunk in all_tasks.chunks(batch_size) {
+            let chunk_tasks = chunk.to_vec();
+            let args_clone = args_copy.clone();
+            let pool_clone = pool_copy.clone();
+            let prove_stream_clone = prove_stream.clone();
+            let join_stream_clone = join_stream.clone();
+            let union_stream_clone = union_stream.clone();
+            let aux_stream_clone = aux_stream.clone();
+            let job_id_clone = job_id_copy;
+            let assumptions_clone = assumptions.clone();
+            let compress_type_clone = compress_type.clone();
+            let coproc_stream_clone = coproc_stream.clone();
+            let segments_clone = segments.clone();
+            let keccak_requests_clone = keccak_requests.clone();
+
+            let future = tokio::spawn(async move {
+                for tree_task in chunk_tasks {
+                    // Determine if this is a segment task
+                    let segment_index = if tree_task.command == TaskCmd::Segment {
+                        // Find the corresponding segment index
+                        segments_clone
+                            .iter()
+                            .find(|s| s.index as u32 == tree_task.task_number)
+                            .map(|s| s.index as u32)
+                    } else {
+                        None
+                    };
+
+                    // Determine if this is a keccak task
+                    let keccak_req = if tree_task.command == TaskCmd::Keccak {
+                        // Find the corresponding keccak request
+                        keccak_requests_clone
+                            .iter()
+                            .find(|k| k.claim_digest.to_string() == tree_task.task_number.to_string())
+                            .map(|k| KeccakReq {
+                                claim_digest: k.claim_digest,
+                                control_root: k.control_root,
+                                po2: k.po2,
+                            })
+                    } else {
+                        None
+                    };
+
+                    // Keccak tasks are scheduled on the coproc stream, everything else on prove
+                    let task_stream =
+                        if keccak_req.is_some() { &coproc_stream_clone } else { &prove_stream_clone };
+
+                    if let Err(err) = process_task(
+                        &args_clone,
+                        &pool_clone,
+                        task_stream,
+                        &join_stream_clone,
+                        &union_stream_clone,
+                        &aux_stream_clone,
+                        &job_id_clone,
+                        &tree_task,
+                        segment_index,
+                        &assumptions_clone,
+                        compress_type_clone,
+                        keccak_req,
+                    ).await {
+                        tracing::error!("Failed to process task: {:?}", err);
+                        return Err(err);
+                    }
                 }
-                continue;
-            }
-            Err(err) => {
-                tracing::error!("queue monitor sub task failed: {err:?}");
-                bail!(err);
-            }
+                Ok::<(), anyhow::Error>(())
+            });
+
+            task_futures.push(future);
+        }
+
+        // Wait for all batches to complete
+        for future in task_futures {
+            future.await
+                .context("Task batch failed to complete")?
+                .context("Task batch processing failed")?;
         }
+
+        tracing::info!("Successfully created all {} tasks in database", all_tasks.len());
     }
 
     tracing::debug!("Done with all IO tasks");
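
The chunk-and-spawn fan-out above clones the full segments and keccak_requests vectors into every chunk task. A lighter shape for the same bounded-concurrency idea, for comparison (a sketch with placeholder types: Ctx bundles the pool and streams, and create_task stands in for process_task):

    use std::sync::Arc;
    use futures::stream::{self, TryStreamExt};

    async fn create_all(tasks: Vec<Task>, ctx: Arc<Ctx>) -> anyhow::Result<()> {
        stream::iter(tasks.into_iter().map(Ok))
            // At most 50 inserts in flight, mirroring batch_size above,
            // without spawning or cloning the lookup vectors per chunk.
            .try_for_each_concurrent(Some(50), |task| {
                let ctx = Arc::clone(&ctx);
                async move { create_task(&ctx, task).await }
            })
            .await
    }

Unlike tokio::spawn, this keeps everything on one task, so the futures need not be 'static. (Separately, note that the keccak lookup above compares a claim digest against a task number textually, which can essentially never match; pairing keccak requests with planner tasks by enqueue order may be what is intended.)
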
diff --git a/compose.yml b/compose.yml
index 66108ec23..dc2bc6fbe 100644
--- a/compose.yml
+++ b/compose.yml
@@ -2,7 +2,8 @@ name: bento
 # Anchors:
 x-base-environment: &base-environment
   DATABASE_URL: postgresql://${POSTGRES_USER:-worker}:${POSTGRES_PASSWORD:-password}@${POSTGRES_HOST:-postgres}:${POSTGRES_PORT:-5432}/${POSTGRES_DB:-taskdb}
+  # KeyDB speaks the Redis protocol, so the redis:// scheme stays as-is
   REDIS_URL: redis://${REDIS_HOST:-redis}:6379
   S3_URL: http://${MINIO_HOST:-minio}:9000
   S3_BUCKET: ${MINIO_BUCKET:-workflow}
   S3_ACCESS_KEY: ${MINIO_ROOT_USER:-admin}
@@ -81,15 +82,15 @@ x-broker-common: &broker-common
 services:
   redis:
     hostname: ${REDIS_HOST:-redis}
-    image: ${REDIS_IMG:-redis:7.2.5-alpine3.19}
+    image: ${REDIS_IMG:-eqalpha/keydb:alpine-x86_64_v6.3.4}
     restart: always
     ports:
       - 6379:6379
     volumes:
       - redis-data:/data
-    command: redis-server --maxmemory-policy allkeys-lru --save 900 1 --appendonly yes
+    command: keydb-server --maxmemory-policy allkeys-lru --save 900 1 --appendonly yes --server-threads 4 --multi-threading yes
     healthcheck:
-      test: ["CMD", "redis-cli", "ping"]
+      test: ["CMD", "keydb-cli", "ping"]
       interval: 5s
       timeout: 3s
       retries: 5
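
Since KeyDB is wire-compatible with Redis, the existing redis-rs clients need no code changes — which is also why REDIS_URL keeps the redis:// scheme above: redis-rs only parses redis://, rediss://, and unix-socket URLs, so a keydb:// scheme would fail at Client::open. A standalone smoke test against the compose service (the URL mirrors the compose default; assumes the tokio-comp feature is enabled):

    use redis::AsyncCommands;

    #[tokio::main]
    async fn main() -> redis::RedisResult<()> {
        let client = redis::Client::open("redis://127.0.0.1:6379")?;
        let mut conn = client.get_multiplexed_async_connection().await?;

        // A PING round trip proves protocol compatibility with KeyDB.
        let pong: String = redis::cmd("PING").query_async(&mut conn).await?;
        assert_eq!(pong, "PONG");

        // Exercise the expiry path the workflow relies on.
        let _: () = conn.set_ex("smoke:key", "value", 30).await?;
        Ok(())
    }
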