diff --git a/packages/common/chirp-workflow/core/src/db/crdb_nats/debug.rs b/packages/common/chirp-workflow/core/src/db/crdb_nats/debug.rs index 8e63ef2982..4d54c36a1e 100644 --- a/packages/common/chirp-workflow/core/src/db/crdb_nats/debug.rs +++ b/packages/common/chirp-workflow/core/src/db/crdb_nats/debug.rs @@ -157,7 +157,7 @@ impl DatabaseDebug for DatabaseCrdbNats { SET silence_ts = $2 WHERE workflow_id = ANY($1) ", - workflow_ids, + &workflow_ids, rivet_util::timestamp::now(), ) .await?; @@ -174,7 +174,7 @@ impl DatabaseDebug for DatabaseCrdbNats { SET wake_immediate = TRUE WHERE workflow_id = ANY($1) ", - workflow_ids, + &workflow_ids, ) .await?; @@ -599,7 +599,7 @@ impl DatabaseDebug for DatabaseCrdbNats { ) SELECT 1 ", - signal_ids, + &signal_ids, rivet_util::timestamp::now(), ) .await?; diff --git a/packages/common/chirp-workflow/core/src/db/crdb_nats/mod.rs b/packages/common/chirp-workflow/core/src/db/crdb_nats/mod.rs index ebc62e8680..9d84da6847 100644 --- a/packages/common/chirp-workflow/core/src/db/crdb_nats/mod.rs +++ b/packages/common/chirp-workflow/core/src/db/crdb_nats/mod.rs @@ -1,10 +1,6 @@ //! Implementation of a workflow database driver with PostgreSQL (CockroachDB) and NATS. -use std::{ - collections::HashSet, - sync::Arc, - time::{Duration, Instant}, -}; +use std::{collections::HashSet, sync::Arc, time::Instant}; use futures_util::{stream::BoxStream, StreamExt}; use indoc::indoc; @@ -32,11 +28,9 @@ type GlobalError = WorkflowError; /// Max amount of workflows pulled from the database with each call to `pull_workflows`. const MAX_PULLED_WORKFLOWS: i64 = 50; -// Base retry for query retry backoff -const QUERY_RETRY_MS: usize = 500; -// Time in between transaction retries -const TXN_RETRY: Duration = Duration::from_millis(100); -/// Maximum times a query ran by this database adapter is retried. +/// Base retry for txn retry backoff. +const QUERY_RETRY_MS: usize = 100; +/// Maximum times a txn query is retried. const MAX_QUERY_RETRIES: usize = 16; /// How long before considering the leases of a given worker instance "expired". const WORKER_INSTANCE_EXPIRED_THRESHOLD_MS: i64 = rivet_util::duration::seconds(30); @@ -95,9 +89,9 @@ impl DatabaseCrdbNats { } } - /// Executes queries and explicitly handles retry errors. + /// Executes queries while explicitly handling txn retry errors. #[tracing::instrument(skip_all)] - async fn query<'a, F, Fut, T>(&self, mut cb: F) -> WorkflowResult + async fn txn<'a, F, Fut, T>(&self, mut cb: F) -> WorkflowResult where F: FnMut() -> Fut, Fut: std::future::Future> + 'a, @@ -111,24 +105,18 @@ impl DatabaseCrdbNats { Err(WorkflowError::Sqlx(err)) => { i += 1; if i > MAX_QUERY_RETRIES { - return Err(WorkflowError::MaxSqlRetries(err)); + return Err(WorkflowError::Sqlx(sqlx::Error::Io(std::io::Error::new( + std::io::ErrorKind::Other, + rivet_pools::utils::sql_query_macros::Error::MaxSqlRetries(err), + )))); } use sqlx::Error::*; match &err { - // Retry transaction errors in a tight loop - Database(db_err) - if db_err - .message() - .contains("TransactionRetryWithProtoRefreshError") => - { - tracing::warn!(message=%db_err.message(), "transaction retry"); - tokio::time::sleep(TXN_RETRY).await; - } - // Retry other errors with a backoff + // Retry all errors with a backoff Database(_) | Io(_) | Tls(_) | Protocol(_) | PoolTimedOut | PoolClosed | WorkerCrashed => { - tracing::warn!(?err, "query retry"); + tracing::warn!(?err, "txn retry"); backoff.tick().await; } // Throw error @@ -479,21 +467,17 @@ impl Database for DatabaseCrdbNats { ) }; - let (actual_workflow_id,) = self - .query(|| async { - sql_fetch_one!( - [self, (Uuid,)] - query, - workflow_id, - workflow_name, - rivet_util::timestamp::now(), - ray_id, - tags, - sqlx::types::Json(input), - ) - .await - }) - .await?; + let (actual_workflow_id,) = sql_fetch_one!( + [self, (Uuid,)] + query, + workflow_id, + workflow_name, + rivet_util::timestamp::now(), + ray_id, + tags, + sqlx::types::Json(input), + ) + .await?; if workflow_id == actual_workflow_id { self.wake_worker(); @@ -544,126 +528,122 @@ impl Database for DatabaseCrdbNats { let start_instant = Instant::now(); // Select all workflows that have a wake condition - let workflow_rows = self - .query(|| async { - sql_fetch_all!( - [self, PulledWorkflowRow] - " - WITH select_pending_workflows AS ( - SELECT workflow_id - FROM db_workflow.workflows@workflows_pred_standard - WHERE - -- Filter - workflow_name = ANY($2) AND - -- Not already complete - output IS NULL AND - -- No assigned node (not running) - worker_instance_id IS NULL AND - -- Not silenced - silence_ts IS NULL AND - -- Check for wake condition - ( - -- Immediate - wake_immediate OR - -- After deadline - ( - wake_deadline_ts IS NOT NULL AND - $3 > wake_deadline_ts - $4 - ) - ) - UNION - SELECT workflow_id - FROM db_workflow.workflows@workflows_pred_signals AS w - WHERE - -- Filter - workflow_name = ANY($2) AND - -- Not already complete - output IS NULL AND - -- No assigned node (not running) - worker_instance_id IS NULL AND - -- Not silenced - silence_ts IS NULL AND - -- Has signals to listen to - array_length(wake_signals, 1) != 0 AND - -- Signal exists - ( - SELECT true - FROM db_workflow.signals@signals_partial AS s - WHERE - s.workflow_id = w.workflow_id AND - s.signal_name = ANY(w.wake_signals) AND - s.ack_ts IS NULL AND - s.silence_ts IS NULL - LIMIT 1 - ) - UNION - SELECT workflow_id - FROM db_workflow.workflows@workflows_pred_signals AS w - WHERE - -- Filter - workflow_name = ANY($2) AND - -- Not already complete - output IS NULL AND - -- No assigned node (not running) - worker_instance_id IS NULL AND - -- Not silenced - silence_ts IS NULL AND - -- Has signals to listen to - array_length(wake_signals, 1) != 0 AND - -- Tagged signal exists - ( - SELECT true - FROM db_workflow.tagged_signals@tagged_signals_partial AS s - WHERE - s.signal_name = ANY(w.wake_signals) AND - s.tags <@ w.tags AND - s.ack_ts IS NULL AND - s.silence_ts IS NULL - LIMIT 1 - ) - UNION - SELECT workflow_id - FROM db_workflow.workflows@workflows_pred_sub_workflow AS w - WHERE - -- Filter - workflow_name = ANY($2) AND - -- Not already complete - output IS NULL AND - -- No assigned node (not running) - worker_instance_id IS NULL AND - -- Not silenced - silence_ts IS NULL AND - wake_sub_workflow_id IS NOT NULL AND - -- Sub workflow completed + let workflow_rows = sql_fetch_all!( + [self, PulledWorkflowRow] + " + WITH select_pending_workflows AS ( + SELECT workflow_id + FROM db_workflow.workflows@workflows_pred_standard + WHERE + -- Filter + workflow_name = ANY($2) AND + -- Not already complete + output IS NULL AND + -- No assigned node (not running) + worker_instance_id IS NULL AND + -- Not silenced + silence_ts IS NULL AND + -- Check for wake condition + ( + -- Immediate + wake_immediate OR + -- After deadline ( - SELECT true - FROM db_workflow.workflows@workflows_pred_sub_workflow_internal AS w2 - WHERE - w2.workflow_id = w.wake_sub_workflow_id AND - output IS NOT NULL + wake_deadline_ts IS NOT NULL AND + $3 > wake_deadline_ts - $4 ) - LIMIT $5 - ) - UPDATE db_workflow.workflows@workflows_pkey AS w - -- Assign current node to this workflow - SET - worker_instance_id = $1, - last_pull_ts = $3 - FROM select_pending_workflows AS pw - WHERE w.workflow_id = pw.workflow_id - RETURNING w.workflow_id, workflow_name, create_ts, ray_id, input, wake_deadline_ts - ", - worker_instance_id, - filter, - rivet_util::timestamp::now(), - // Add padding to the tick interval so that the workflow deadline is never passed before its pulled. - // The worker sleeps internally to handle this - self.worker_poll_interval().as_millis() as i64 + 1, - MAX_PULLED_WORKFLOWS, + ) + UNION + SELECT workflow_id + FROM db_workflow.workflows@workflows_pred_signals AS w + WHERE + -- Filter + workflow_name = ANY($2) AND + -- Not already complete + output IS NULL AND + -- No assigned node (not running) + worker_instance_id IS NULL AND + -- Not silenced + silence_ts IS NULL AND + -- Has signals to listen to + array_length(wake_signals, 1) != 0 AND + -- Signal exists + ( + SELECT true + FROM db_workflow.signals@signals_partial AS s + WHERE + s.workflow_id = w.workflow_id AND + s.signal_name = ANY(w.wake_signals) AND + s.ack_ts IS NULL AND + s.silence_ts IS NULL + LIMIT 1 + ) + UNION + SELECT workflow_id + FROM db_workflow.workflows@workflows_pred_signals AS w + WHERE + -- Filter + workflow_name = ANY($2) AND + -- Not already complete + output IS NULL AND + -- No assigned node (not running) + worker_instance_id IS NULL AND + -- Not silenced + silence_ts IS NULL AND + -- Has signals to listen to + array_length(wake_signals, 1) != 0 AND + -- Tagged signal exists + ( + SELECT true + FROM db_workflow.tagged_signals@tagged_signals_partial AS s + WHERE + s.signal_name = ANY(w.wake_signals) AND + s.tags <@ w.tags AND + s.ack_ts IS NULL AND + s.silence_ts IS NULL + LIMIT 1 + ) + UNION + SELECT workflow_id + FROM db_workflow.workflows@workflows_pred_sub_workflow AS w + WHERE + -- Filter + workflow_name = ANY($2) AND + -- Not already complete + output IS NULL AND + -- No assigned node (not running) + worker_instance_id IS NULL AND + -- Not silenced + silence_ts IS NULL AND + wake_sub_workflow_id IS NOT NULL AND + -- Sub workflow completed + ( + SELECT true + FROM db_workflow.workflows@workflows_pred_sub_workflow_internal AS w2 + WHERE + w2.workflow_id = w.wake_sub_workflow_id AND + output IS NOT NULL + ) + LIMIT $5 ) - .await - }) - .await?; + UPDATE db_workflow.workflows@workflows_pkey AS w + -- Assign current node to this workflow + SET + worker_instance_id = $1, + last_pull_ts = $3 + FROM select_pending_workflows AS pw + WHERE w.workflow_id = pw.workflow_id + RETURNING w.workflow_id, workflow_name, create_ts, ray_id, input, wake_deadline_ts + ", + worker_instance_id, + filter, + rivet_util::timestamp::now(), + // Add padding to the tick interval so that the workflow deadline is never passed before its pulled. + // The worker sleeps internally to handle this + self.worker_poll_interval().as_millis() as i64 + 1, + MAX_PULLED_WORKFLOWS, + ) + .await?; let worker_instance_id_str = worker_instance_id.to_string(); let dt = start_instant.elapsed().as_secs_f64(); @@ -689,236 +669,236 @@ impl Database for DatabaseCrdbNats { let events = sql_fetch_all!( [self, AmalgamEventRow] " - -- Activity events - SELECT - workflow_id, - location, - location2, - version, - 0 AS event_type, -- EventType - activity_name AS name, - NULL AS auxiliary_id, - input_hash AS hash, - NULL AS input, - output AS output, - create_ts AS create_ts, - ( - SELECT COUNT(*) - FROM db_workflow.workflow_activity_errors AS err - WHERE - ev.workflow_id = err.workflow_id AND - ev.location2 = err.location2 - ) AS error_count, - NULL AS iteration, - NULL AS deadline_ts, - NULL AS state, - NULL AS inner_event_type - FROM db_workflow.workflow_activity_events AS ev - WHERE ev.workflow_id = ANY($1) AND forgotten = FALSE - -- Should only require `workflow_id` and `location2` but because `location2` is nullable the - -- database can't determine uniqueness - GROUP BY - ev.workflow_id, - ev.location, - ev.location2, - ev.version, - ev.activity_name, - ev.input_hash, - ev.output, - ev.create_ts - UNION ALL - -- Signal listen events - SELECT - workflow_id, - location, - location2, - version, - 1 AS event_type, -- EventType - signal_name AS name, - NULL AS auxiliary_id, - NULL AS hash, - NULL AS input, - body AS output, - NULL AS create_ts, - NULL AS error_count, - NULL AS iteration, - NULL AS deadline_ts, - NULL AS state, - NULL AS inner_event_type - FROM db_workflow.workflow_signal_events - WHERE workflow_id = ANY($1) AND forgotten = FALSE - UNION ALL - -- Signal send events - SELECT - workflow_id, - location, - location2, - version, - 2 AS event_type, -- EventType - signal_name AS name, - signal_id AS auxiliary_id, - NULL AS hash, - NULL AS input, - NULL AS output, - NULL AS create_ts, - NULL AS error_count, - NULL AS iteration, - NULL AS deadline_ts, - NULL AS state, - NULL AS inner_event_type - FROM db_workflow.workflow_signal_send_events - WHERE workflow_id = ANY($1) AND forgotten = FALSE - UNION ALL - -- Message send events - SELECT - workflow_id, - location, - location2, - version, - 3 AS event_type, -- EventType - message_name AS name, - NULL AS auxiliary_id, - NULL AS hash, - NULL AS input, - NULL AS output, - NULL AS create_ts, - NULL AS error_count, - NULL AS iteration, - NULL AS deadline_ts, - NULL AS state, - NULL AS inner_event_type - FROM db_workflow.workflow_message_send_events - WHERE workflow_id = ANY($1) AND forgotten = FALSE - UNION ALL - -- Sub workflow events - SELECT - sw.workflow_id, - sw.location, - sw.location2, - version, - 4 AS event_type, -- crdb_nats::types::EventType - w.workflow_name AS name, - sw.sub_workflow_id AS auxiliary_id, - NULL AS hash, - NULL AS input, - NULL AS output, - NULL AS create_ts, - NULL AS error_count, - NULL AS iteration, - NULL AS deadline_ts, - NULL AS state, - NULL AS inner_event_type - FROM db_workflow.workflow_sub_workflow_events AS sw - JOIN db_workflow.workflows AS w - ON sw.sub_workflow_id = w.workflow_id - WHERE sw.workflow_id = ANY($1) AND forgotten = FALSE - UNION ALL - -- Loop events - SELECT - workflow_id, - location, - location2, - version, - 5 AS event_type, -- crdb_nats::types::EventType - NULL AS name, - NULL AS auxiliary_id, - NULL AS hash, - state AS input, - output, - NULL AS create_ts, - NULL AS error_count, - iteration, - NULL AS deadline_ts, - NULL AS state, - NULL AS inner_event_type - FROM db_workflow.workflow_loop_events - WHERE workflow_id = ANY($1) AND forgotten = FALSE - UNION ALL - -- Sleep events - SELECT - workflow_id, - location, - location2, - version, - 6 AS event_type, -- crdb_nats::types::EventType - NULL AS name, - NULL AS auxiliary_id, - NULL AS hash, - NULL AS input, - NULL AS output, - NULL AS create_ts, - NULL AS error_count, - NULL AS iteration, - deadline_ts, - state, - NULL AS inner_event_type - FROM db_workflow.workflow_sleep_events - WHERE workflow_id = ANY($1) AND forgotten = FALSE - UNION ALL - -- Branch events - SELECT - workflow_id, - ARRAY[] AS location, - location AS location2, - version, - 7 AS event_type, -- crdb_nats::types::EventType - NULL AS name, - NULL AS auxiliary_id, - NULL AS hash, - NULL AS input, - NULL AS output, - NULL AS create_ts, - NULL AS error_count, - NULL AS iteration, - NULL AS deadline_ts, - NULL AS state, - NULL AS inner_event_type - FROM db_workflow.workflow_branch_events - WHERE workflow_id = ANY($1) AND forgotten = FALSE - UNION ALL - -- Removed events - SELECT - workflow_id, - ARRAY[] AS location, - location AS location2, - 1 AS version, -- Default - 8 AS event_type, -- crdb_nats::types::EventType - event_name AS name, - NULL AS auxiliary_id, - NULL AS hash, - NULL AS input, - NULL AS output, - NULL AS create_ts, - NULL AS error_count, - NULL AS iteration, - NULL AS deadline_ts, - NULL AS state, - event_type AS inner_event_type - FROM db_workflow.workflow_removed_events - WHERE workflow_id = ANY($1) AND forgotten = FALSE - UNION ALL - -- Version check events - SELECT - workflow_id, - ARRAY[] AS location, - location AS location2, - version, - 9 AS event_type, -- crdb_nats::types::EventType - NULL AS name, - NULL AS auxiliary_id, - NULL AS hash, - NULL AS input, - NULL AS output, - NULL AS create_ts, - NULL AS error_count, - NULL AS iteration, - NULL AS deadline_ts, - NULL AS state, - NULL AS inner_event_type - FROM db_workflow.workflow_version_check_events - WHERE workflow_id = ANY($1) AND forgotten = FALSE - ORDER BY workflow_id ASC, location2 ASC - ", + -- Activity events + SELECT + workflow_id, + location, + location2, + version, + 0 AS event_type, -- EventType + activity_name AS name, + NULL AS auxiliary_id, + input_hash AS hash, + NULL AS input, + output AS output, + create_ts AS create_ts, + ( + SELECT COUNT(*) + FROM db_workflow.workflow_activity_errors AS err + WHERE + ev.workflow_id = err.workflow_id AND + ev.location2 = err.location2 + ) AS error_count, + NULL AS iteration, + NULL AS deadline_ts, + NULL AS state, + NULL AS inner_event_type + FROM db_workflow.workflow_activity_events AS ev + WHERE ev.workflow_id = ANY($1) AND forgotten = FALSE + -- Should only require `workflow_id` and `location2` but because `location2` is nullable the + -- database can't determine uniqueness + GROUP BY + ev.workflow_id, + ev.location, + ev.location2, + ev.version, + ev.activity_name, + ev.input_hash, + ev.output, + ev.create_ts + UNION ALL + -- Signal listen events + SELECT + workflow_id, + location, + location2, + version, + 1 AS event_type, -- EventType + signal_name AS name, + NULL AS auxiliary_id, + NULL AS hash, + NULL AS input, + body AS output, + NULL AS create_ts, + NULL AS error_count, + NULL AS iteration, + NULL AS deadline_ts, + NULL AS state, + NULL AS inner_event_type + FROM db_workflow.workflow_signal_events + WHERE workflow_id = ANY($1) AND forgotten = FALSE + UNION ALL + -- Signal send events + SELECT + workflow_id, + location, + location2, + version, + 2 AS event_type, -- EventType + signal_name AS name, + signal_id AS auxiliary_id, + NULL AS hash, + NULL AS input, + NULL AS output, + NULL AS create_ts, + NULL AS error_count, + NULL AS iteration, + NULL AS deadline_ts, + NULL AS state, + NULL AS inner_event_type + FROM db_workflow.workflow_signal_send_events + WHERE workflow_id = ANY($1) AND forgotten = FALSE + UNION ALL + -- Message send events + SELECT + workflow_id, + location, + location2, + version, + 3 AS event_type, -- EventType + message_name AS name, + NULL AS auxiliary_id, + NULL AS hash, + NULL AS input, + NULL AS output, + NULL AS create_ts, + NULL AS error_count, + NULL AS iteration, + NULL AS deadline_ts, + NULL AS state, + NULL AS inner_event_type + FROM db_workflow.workflow_message_send_events + WHERE workflow_id = ANY($1) AND forgotten = FALSE + UNION ALL + -- Sub workflow events + SELECT + sw.workflow_id, + sw.location, + sw.location2, + version, + 4 AS event_type, -- crdb_nats::types::EventType + w.workflow_name AS name, + sw.sub_workflow_id AS auxiliary_id, + NULL AS hash, + NULL AS input, + NULL AS output, + NULL AS create_ts, + NULL AS error_count, + NULL AS iteration, + NULL AS deadline_ts, + NULL AS state, + NULL AS inner_event_type + FROM db_workflow.workflow_sub_workflow_events AS sw + JOIN db_workflow.workflows AS w + ON sw.sub_workflow_id = w.workflow_id + WHERE sw.workflow_id = ANY($1) AND forgotten = FALSE + UNION ALL + -- Loop events + SELECT + workflow_id, + location, + location2, + version, + 5 AS event_type, -- crdb_nats::types::EventType + NULL AS name, + NULL AS auxiliary_id, + NULL AS hash, + state AS input, + output, + NULL AS create_ts, + NULL AS error_count, + iteration, + NULL AS deadline_ts, + NULL AS state, + NULL AS inner_event_type + FROM db_workflow.workflow_loop_events + WHERE workflow_id = ANY($1) AND forgotten = FALSE + UNION ALL + -- Sleep events + SELECT + workflow_id, + location, + location2, + version, + 6 AS event_type, -- crdb_nats::types::EventType + NULL AS name, + NULL AS auxiliary_id, + NULL AS hash, + NULL AS input, + NULL AS output, + NULL AS create_ts, + NULL AS error_count, + NULL AS iteration, + deadline_ts, + state, + NULL AS inner_event_type + FROM db_workflow.workflow_sleep_events + WHERE workflow_id = ANY($1) AND forgotten = FALSE + UNION ALL + -- Branch events + SELECT + workflow_id, + ARRAY[] AS location, + location AS location2, + version, + 7 AS event_type, -- crdb_nats::types::EventType + NULL AS name, + NULL AS auxiliary_id, + NULL AS hash, + NULL AS input, + NULL AS output, + NULL AS create_ts, + NULL AS error_count, + NULL AS iteration, + NULL AS deadline_ts, + NULL AS state, + NULL AS inner_event_type + FROM db_workflow.workflow_branch_events + WHERE workflow_id = ANY($1) AND forgotten = FALSE + UNION ALL + -- Removed events + SELECT + workflow_id, + ARRAY[] AS location, + location AS location2, + 1 AS version, -- Default + 8 AS event_type, -- crdb_nats::types::EventType + event_name AS name, + NULL AS auxiliary_id, + NULL AS hash, + NULL AS input, + NULL AS output, + NULL AS create_ts, + NULL AS error_count, + NULL AS iteration, + NULL AS deadline_ts, + NULL AS state, + event_type AS inner_event_type + FROM db_workflow.workflow_removed_events + WHERE workflow_id = ANY($1) AND forgotten = FALSE + UNION ALL + -- Version check events + SELECT + workflow_id, + ARRAY[] AS location, + location AS location2, + version, + 9 AS event_type, -- crdb_nats::types::EventType + NULL AS name, + NULL AS auxiliary_id, + NULL AS hash, + NULL AS input, + NULL AS output, + NULL AS create_ts, + NULL AS error_count, + NULL AS iteration, + NULL AS deadline_ts, + NULL AS state, + NULL AS inner_event_type + FROM db_workflow.workflow_version_check_events + WHERE workflow_id = ANY($1) AND forgotten = FALSE + ORDER BY workflow_id ASC, location2 ASC + ", &workflow_ids, ) .await?; @@ -952,20 +932,16 @@ impl Database for DatabaseCrdbNats { ) -> WorkflowResult<()> { let start_instant = Instant::now(); - self.query(|| async { - sqlx::query(indoc!( - " - UPDATE db_workflow.workflows - SET output = $2 - WHERE workflow_id = $1 - ", - )) - .bind(workflow_id) - .bind(sqlx::types::Json(output)) - .execute(&mut *self.conn().await?) - .await - .map_err(WorkflowError::Sqlx) - }) + sql_execute!( + [self] + " + UPDATE db_workflow.workflows + SET output = $2 + WHERE workflow_id = $1 + ", + workflow_id, + sqlx::types::Json(output), + ) .await?; self.wake_worker(); @@ -991,30 +967,26 @@ impl Database for DatabaseCrdbNats { ) -> WorkflowResult<()> { let start_instant = Instant::now(); - self.query(|| async { - sqlx::query(indoc!( - " - UPDATE db_workflow.workflows - SET - worker_instance_id = NULL, - wake_immediate = $2, - wake_deadline_ts = $3, - wake_signals = $4, - wake_sub_workflow_id = $5, - error = $6 - WHERE workflow_id = $1 - ", - )) - .bind(workflow_id) - .bind(immediate) - .bind(wake_deadline_ts) - .bind(wake_signals) - .bind(wake_sub_workflow_id) - .bind(error) - .execute(&mut *self.conn().await?) - .await - .map_err(WorkflowError::Sqlx) - }) + sql_execute!( + [self] + " + UPDATE db_workflow.workflows + SET + worker_instance_id = NULL, + wake_immediate = $2, + wake_deadline_ts = $3, + wake_signals = $4, + wake_sub_workflow_id = $5, + error = $6 + WHERE workflow_id = $1 + ", + workflow_id, + immediate, + wake_deadline_ts, + wake_signals, + wake_sub_workflow_id, + error, + ) .await?; // Wake worker again if the deadline is before the next tick @@ -1045,94 +1017,90 @@ impl Database for DatabaseCrdbNats { loop_location: Option<&Location>, _last_try: bool, ) -> WorkflowResult> { - let signal = self - .query(|| async { - sql_fetch_optional!( - [self, SignalRow] - " - WITH - -- Finds the oldest signal matching the signal name filter in either the normal signals table - -- or tagged signals table - next_signal AS ( - SELECT false AS tagged, signal_id, create_ts, signal_name, body - FROM db_workflow.signals@signals_partial - WHERE - workflow_id = $1 AND - signal_name = ANY($2) AND - ack_ts IS NULL AND - silence_ts IS NULL - UNION ALL - SELECT true AS tagged, signal_id, s.create_ts, signal_name, body - FROM db_workflow.tagged_signals@tagged_signals_partial AS s - JOIN db_workflow.workflows AS w - ON s.tags <@ w.tags - WHERE - w.workflow_id = $1 AND - s.signal_name = ANY($2) AND - s.ack_ts IS NULL AND - s.silence_ts IS NULL - ORDER BY create_ts ASC - LIMIT 1 - ), - -- If the next signal is not tagged, acknowledge it with this statement - ack_signal AS ( - UPDATE db_workflow.signals - SET ack_ts = $5 - WHERE signal_id = ( - SELECT signal_id - FROM next_signal - WHERE tagged = false - ) - RETURNING 1 - ), - -- If the next signal is tagged, acknowledge it with this statement - ack_tagged_signal AS ( - UPDATE db_workflow.tagged_signals - SET ack_ts = $5 - WHERE signal_id = ( - SELECT signal_id - FROM next_signal - WHERE tagged = true - ) - RETURNING 1 - ), - -- After acking the signal, add it to the events table - insert_event AS ( - INSERT INTO db_workflow.workflow_signal_events ( - workflow_id, - location2, - version, - signal_id, - signal_name, - body, - ack_ts, - loop_location2 - ) - SELECT - $1 AS workflow_id, - $3 AS location2, - $4 AS version, - signal_id, - signal_name, - body, - $5 AS ack_ts, - $6 AS loop_location2 - FROM next_signal - RETURNING 1 - ) - SELECT * FROM next_signal - ", - workflow_id, - filter, - location, - version as i64, - rivet_util::timestamp::now(), - loop_location, + let signal = sql_fetch_optional!( + [self, SignalRow] + " + WITH + -- Finds the oldest signal matching the signal name filter in either the normal signals table + -- or tagged signals table + next_signal AS ( + SELECT false AS tagged, signal_id, create_ts, signal_name, body + FROM db_workflow.signals@signals_partial + WHERE + workflow_id = $1 AND + signal_name = ANY($2) AND + ack_ts IS NULL AND + silence_ts IS NULL + UNION ALL + SELECT true AS tagged, signal_id, s.create_ts, signal_name, body + FROM db_workflow.tagged_signals@tagged_signals_partial AS s + JOIN db_workflow.workflows AS w + ON s.tags <@ w.tags + WHERE + w.workflow_id = $1 AND + s.signal_name = ANY($2) AND + s.ack_ts IS NULL AND + s.silence_ts IS NULL + ORDER BY create_ts ASC + LIMIT 1 + ), + -- If the next signal is not tagged, acknowledge it with this statement + ack_signal AS ( + UPDATE db_workflow.signals + SET ack_ts = $5 + WHERE signal_id = ( + SELECT signal_id + FROM next_signal + WHERE tagged = false + ) + RETURNING 1 + ), + -- If the next signal is tagged, acknowledge it with this statement + ack_tagged_signal AS ( + UPDATE db_workflow.tagged_signals + SET ack_ts = $5 + WHERE signal_id = ( + SELECT signal_id + FROM next_signal + WHERE tagged = true + ) + RETURNING 1 + ), + -- After acking the signal, add it to the events table + insert_event AS ( + INSERT INTO db_workflow.workflow_signal_events ( + workflow_id, + location2, + version, + signal_id, + signal_name, + body, + ack_ts, + loop_location2 + ) + SELECT + $1 AS workflow_id, + $3 AS location2, + $4 AS version, + signal_id, + signal_name, + body, + $5 AS ack_ts, + $6 AS loop_location2 + FROM next_signal + RETURNING 1 ) - .await - .map(|row| row.map(Into::into)) - }) - .await?; + SELECT * FROM next_signal + ", + workflow_id, + filter, + location, + version as i64, + rivet_util::timestamp::now(), + loop_location, + ) + .await? + .map(Into::into); Ok(signal) } @@ -1156,24 +1124,21 @@ impl Database for DatabaseCrdbNats { signal_name: &str, body: &serde_json::value::RawValue, ) -> WorkflowResult<()> { - self.query(|| async { - sql_execute!( - [self] - " - INSERT INTO db_workflow.signals ( - signal_id, workflow_id, signal_name, body, ray_id, create_ts - ) - VALUES ($1, $2, $3, $4, $5, $6) - ", - signal_id, - workflow_id, - signal_name, - sqlx::types::Json(body), - ray_id, - rivet_util::timestamp::now(), - ) - .await - }) + sql_execute!( + [self] + " + INSERT INTO db_workflow.signals ( + signal_id, workflow_id, signal_name, body, ray_id, create_ts + ) + VALUES ($1, $2, $3, $4, $5, $6) + ", + signal_id, + workflow_id, + signal_name, + sqlx::types::Json(body), + ray_id, + rivet_util::timestamp::now(), + ) .await?; self.wake_worker(); @@ -1190,24 +1155,21 @@ impl Database for DatabaseCrdbNats { signal_name: &str, body: &serde_json::value::RawValue, ) -> WorkflowResult<()> { - self.query(|| async { - sql_execute!( - [self] - " - INSERT INTO db_workflow.tagged_signals ( - signal_id, tags, signal_name, body, ray_id, create_ts - ) - VALUES ($1, $2, $3, $4, $5, $6) - ", - signal_id, - tags, - signal_name, - sqlx::types::Json(body), - ray_id, - rivet_util::timestamp::now(), - ) - .await - }) + sql_execute!( + [self] + " + INSERT INTO db_workflow.tagged_signals ( + signal_id, tags, signal_name, body, ray_id, create_ts + ) + VALUES ($1, $2, $3, $4, $5, $6) + ", + signal_id, + tags, + signal_name, + sqlx::types::Json(body), + ray_id, + rivet_util::timestamp::now(), + ) .await?; self.wake_worker(); @@ -1228,40 +1190,37 @@ impl Database for DatabaseCrdbNats { body: &serde_json::value::RawValue, loop_location: Option<&Location>, ) -> WorkflowResult<()> { - self.query(|| async { - sql_execute!( - [self] - " - WITH - signal AS ( - INSERT INTO db_workflow.signals ( - signal_id, workflow_id, signal_name, body, ray_id, create_ts - ) - VALUES ($1, $2, $3, $4, $5, $6) - RETURNING 1 - ), - send_event AS ( - INSERT INTO db_workflow.workflow_signal_send_events( - workflow_id, location2, version, signal_id, signal_name, body, loop_location2 - ) - VALUES($7, $8, $9, $1, $3, $4, $10) - RETURNING 1 + sql_execute!( + [self] + " + WITH + signal AS ( + INSERT INTO db_workflow.signals ( + signal_id, workflow_id, signal_name, body, ray_id, create_ts + ) + VALUES ($1, $2, $3, $4, $5, $6) + RETURNING 1 + ), + send_event AS ( + INSERT INTO db_workflow.workflow_signal_send_events( + workflow_id, location2, version, signal_id, signal_name, body, loop_location2 ) - SELECT 1 - ", - signal_id, - to_workflow_id, - signal_name, - sqlx::types::Json(body), - ray_id, - rivet_util::timestamp::now(), - from_workflow_id, - location, - version as i64, - loop_location, - ) - .await - }) + VALUES($7, $8, $9, $1, $3, $4, $10) + RETURNING 1 + ) + SELECT 1 + ", + signal_id, + to_workflow_id, + signal_name, + sqlx::types::Json(body), + ray_id, + rivet_util::timestamp::now(), + from_workflow_id, + location, + version as i64, + loop_location, + ) .await?; self.wake_worker(); @@ -1282,40 +1241,37 @@ impl Database for DatabaseCrdbNats { body: &serde_json::value::RawValue, loop_location: Option<&Location>, ) -> WorkflowResult<()> { - self.query(|| async { - sql_execute!( - [self] - " - WITH - signal AS ( - INSERT INTO db_workflow.tagged_signals ( - signal_id, tags, signal_name, body, ray_id, create_ts - ) - VALUES ($1, $2, $3, $4, $5, $6) - RETURNING 1 - ), - send_event AS ( - INSERT INTO db_workflow.workflow_signal_send_events ( - workflow_id, location2, version, signal_id, signal_name, body, loop_location2 - ) - VALUES($7, $8, $9, $1, $3, $4, $10) - RETURNING 1 + sql_execute!( + [self] + " + WITH + signal AS ( + INSERT INTO db_workflow.tagged_signals ( + signal_id, tags, signal_name, body, ray_id, create_ts + ) + VALUES ($1, $2, $3, $4, $5, $6) + RETURNING 1 + ), + send_event AS ( + INSERT INTO db_workflow.workflow_signal_send_events ( + workflow_id, location2, version, signal_id, signal_name, body, loop_location2 ) - SELECT 1 - ", - signal_id, - tags, - signal_name, - sqlx::types::Json(body), - ray_id, - rivet_util::timestamp::now(), - from_workflow_id, - location, - version as i64, - loop_location, - ) - .await - }) + VALUES($7, $8, $9, $1, $3, $4, $10) + RETURNING 1 + ) + SELECT 1 + ", + signal_id, + tags, + signal_name, + sqlx::types::Json(body), + ray_id, + rivet_util::timestamp::now(), + from_workflow_id, + location, + version as i64, + loop_location, + ) .await?; self.wake_worker(); @@ -1396,24 +1352,21 @@ impl Database for DatabaseCrdbNats { ) }; - let (actual_sub_workflow_id,) = self - .query(|| async { - sqlx::query_as::<_, (Uuid,)>(query) - .bind(workflow_id) - .bind(sub_workflow_name) - .bind(rivet_util::timestamp::now()) - .bind(ray_id) - .bind(tags) - .bind(sqlx::types::Json(input)) - .bind(location) - .bind(version as i64) - .bind(sub_workflow_id) - .bind(loop_location) - .fetch_one(&mut *self.conn().await?) - .await - .map_err(WorkflowError::Sqlx) - }) - .await?; + let (actual_sub_workflow_id,) = sql_fetch_one!( + [self, (Uuid,)] + query, + workflow_id, + sub_workflow_name, + rivet_util::timestamp::now(), + ray_id, + tags, + sqlx::types::Json(input), + location, + version as i64, + sub_workflow_id, + loop_location, + ) + .await?; if sub_workflow_id == actual_sub_workflow_id { self.wake_worker(); @@ -1429,20 +1382,16 @@ impl Database for DatabaseCrdbNats { _workflow_name: &str, tags: &serde_json::Value, ) -> WorkflowResult<()> { - self.query(|| async { - sqlx::query(indoc!( - " - UPDATE db_workflow.workflows - SET tags = $2 - WHERE workflow_id = $1 - ", - )) - .bind(workflow_id) - .bind(tags) - .execute(&mut *self.conn().await?) - .await - .map_err(WorkflowError::Sqlx) - }) + sql_execute!( + [self] + " + UPDATE db_workflow.workflows + SET tags = $2 + WHERE workflow_id = $1 + ", + workflow_id, + tags, + ) .await?; Ok(()) @@ -1462,84 +1411,76 @@ impl Database for DatabaseCrdbNats { ) -> WorkflowResult<()> { match res { Ok(output) => { - self.query(|| async { - sqlx::query(indoc!( - " - INSERT INTO db_workflow.workflow_activity_events ( - workflow_id, - location2, - version, - activity_name, - input_hash, - input, - output, - create_ts, - loop_location2 - ) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) - ON CONFLICT (workflow_id, location2_hash) DO UPDATE - SET output = EXCLUDED.output - ", - )) - .bind(workflow_id) - .bind(location) - .bind(version as i64) - .bind(&event_id.name) - .bind(event_id.input_hash.to_le_bytes()) - .bind(sqlx::types::Json(input)) - .bind(sqlx::types::Json(output)) - .bind(create_ts) - .bind(loop_location) - .execute(&mut *self.conn().await?) - .await - .map_err(WorkflowError::Sqlx) - }) + sql_execute!( + [self] + " + INSERT INTO db_workflow.workflow_activity_events ( + workflow_id, + location2, + version, + activity_name, + input_hash, + input, + output, + create_ts, + loop_location2 + ) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) + ON CONFLICT (workflow_id, location2_hash) DO UPDATE + SET output = EXCLUDED.output + ", + workflow_id, + location, + version as i64, + &event_id.name, + event_id.input_hash.to_le_bytes(), + sqlx::types::Json(input), + sqlx::types::Json(output), + create_ts, + loop_location, + ) .await?; } Err(err) => { - self.query(|| async { - sqlx::query(indoc!( - " - WITH - event AS ( - INSERT INTO db_workflow.workflow_activity_events ( - workflow_id, - location2, - version, - activity_name, - input_hash, - input, - create_ts, - loop_location2 - ) - VALUES ($1, $2, $3, $4, $5, $6, $8, $9) - ON CONFLICT (workflow_id, location2_hash) DO NOTHING - RETURNING 1 - ), - err AS ( - INSERT INTO db_workflow.workflow_activity_errors ( - workflow_id, location2, activity_name, error, ts - ) - VALUES ($1, $2, $4, $7, $10) - RETURNING 1 + sql_execute!( + [self] + " + WITH + event AS ( + INSERT INTO db_workflow.workflow_activity_events ( + workflow_id, + location2, + version, + activity_name, + input_hash, + input, + create_ts, + loop_location2 ) - SELECT 1 - ", - )) - .bind(workflow_id) - .bind(location) - .bind(version as i64) - .bind(&event_id.name) - .bind(event_id.input_hash.to_le_bytes()) - .bind(sqlx::types::Json(input)) - .bind(err) - .bind(create_ts) - .bind(loop_location) - .bind(rivet_util::timestamp::now()) - .execute(&mut *self.conn().await?) - .await - .map_err(WorkflowError::Sqlx) - }) + VALUES ($1, $2, $3, $4, $5, $6, $8, $9) + ON CONFLICT (workflow_id, location2_hash) DO NOTHING + RETURNING 1 + ), + err AS ( + INSERT INTO db_workflow.workflow_activity_errors ( + workflow_id, location2, activity_name, error, ts + ) + VALUES ($1, $2, $4, $7, $10) + RETURNING 1 + ) + SELECT 1 + ", + workflow_id, + location, + version as i64, + &event_id.name, + event_id.input_hash.to_le_bytes(), + sqlx::types::Json(input), + err, + create_ts, + loop_location, + rivet_util::timestamp::now(), + ) .await?; } } @@ -1558,27 +1499,23 @@ impl Database for DatabaseCrdbNats { body: &serde_json::value::RawValue, loop_location: Option<&Location>, ) -> WorkflowResult<()> { - self.query(|| async { - sqlx::query(indoc!( - " - INSERT INTO db_workflow.workflow_message_send_events ( - workflow_id, location2, version, tags, message_name, body, loop_location2 - ) - VALUES($1, $2, $3, $4, $5, $6, $7) - RETURNING 1 - ", - )) - .bind(from_workflow_id) - .bind(location) - .bind(version as i64) - .bind(tags) - .bind(message_name) - .bind(sqlx::types::Json(body)) - .bind(loop_location) - .execute(&mut *self.conn().await?) - .await - .map_err(WorkflowError::Sqlx) - }) + sql_execute!( + [self] + " + INSERT INTO db_workflow.workflow_message_send_events ( + workflow_id, location2, version, tags, message_name, body, loop_location2 + ) + VALUES($1, $2, $3, $4, $5, $6, $7) + RETURNING 1 + ", + from_workflow_id, + location, + version as i64, + tags, + message_name, + sqlx::types::Json(body), + loop_location, + ) .await?; Ok(()) @@ -1596,7 +1533,7 @@ impl Database for DatabaseCrdbNats { output: Option<&serde_json::value::RawValue>, loop_location: Option<&Location>, ) -> WorkflowResult<()> { - self.query(|| async { + self.txn(|| async { let mut conn = self.conn().await?; let mut tx = conn.begin().await.map_err(WorkflowError::Sqlx)?; @@ -1752,26 +1689,22 @@ impl Database for DatabaseCrdbNats { deadline_ts: i64, loop_location: Option<&Location>, ) -> WorkflowResult<()> { - self.query(|| async { - sqlx::query(indoc!( - " - INSERT INTO db_workflow.workflow_sleep_events ( - workflow_id, location2, version, deadline_ts, loop_location2, state - ) - VALUES($1, $2, $3, $4, $5, $6) - RETURNING 1 - ", - )) - .bind(from_workflow_id) - .bind(location) - .bind(version as i64) - .bind(deadline_ts) - .bind(loop_location) - .bind(SleepState::Normal as i64) - .execute(&mut *self.conn().await?) - .await - .map_err(WorkflowError::Sqlx) - }) + sql_execute!( + [self] + " + INSERT INTO db_workflow.workflow_sleep_events ( + workflow_id, location2, version, deadline_ts, loop_location2, state + ) + VALUES($1, $2, $3, $4, $5, $6) + RETURNING 1 + ", + from_workflow_id, + location, + version as i64, + deadline_ts, + loop_location, + SleepState::Normal as i64, + ) .await?; Ok(()) @@ -1784,21 +1717,17 @@ impl Database for DatabaseCrdbNats { location: &Location, state: SleepState, ) -> WorkflowResult<()> { - self.query(|| async { - sqlx::query(indoc!( - " - UPDATE db_workflow.workflow_sleep_events - SET state = $3 - WHERE workflow_id = $1 AND location2 = $2 - ", - )) - .bind(from_workflow_id) - .bind(location) - .bind(state as i64) - .execute(&mut *self.conn().await?) - .await - .map_err(WorkflowError::Sqlx) - }) + sql_execute!( + [self] + " + UPDATE db_workflow.workflow_sleep_events + SET state = $3 + WHERE workflow_id = $1 AND location2 = $2 + ", + from_workflow_id, + location, + state as i64, + ) .await?; Ok(()) @@ -1812,24 +1741,20 @@ impl Database for DatabaseCrdbNats { version: usize, loop_location: Option<&Location>, ) -> WorkflowResult<()> { - self.query(|| async { - sqlx::query(indoc!( - " - INSERT INTO db_workflow.workflow_branch_events ( - workflow_id, location, version, loop_location - ) - VALUES($1, $2, $3, $4) - RETURNING 1 - ", - )) - .bind(from_workflow_id) - .bind(location) - .bind(version as i64) - .bind(loop_location) - .execute(&mut *self.conn().await?) - .await - .map_err(WorkflowError::Sqlx) - }) + sql_execute!( + [self] + " + INSERT INTO db_workflow.workflow_branch_events ( + workflow_id, location, version, loop_location + ) + VALUES($1, $2, $3, $4) + RETURNING 1 + ", + from_workflow_id, + location, + version as i64, + loop_location, + ) .await?; Ok(()) @@ -1844,25 +1769,21 @@ impl Database for DatabaseCrdbNats { event_name: Option<&str>, loop_location: Option<&Location>, ) -> WorkflowResult<()> { - self.query(|| async { - sqlx::query(indoc!( - " - INSERT INTO db_workflow.workflow_removed_events ( - workflow_id, location, event_type, event_name, loop_location - ) - VALUES($1, $2, $3, $4, $5) - RETURNING 1 - ", - )) - .bind(from_workflow_id) - .bind(location) - .bind(event_type as i32) - .bind(event_name) - .bind(loop_location) - .execute(&mut *self.conn().await?) - .await - .map_err(WorkflowError::Sqlx) - }) + sql_execute!( + [self] + " + INSERT INTO db_workflow.workflow_removed_events ( + workflow_id, location, event_type, event_name, loop_location + ) + VALUES($1, $2, $3, $4, $5) + RETURNING 1 + ", + from_workflow_id, + location, + event_type as i32, + event_name, + loop_location, + ) .await?; Ok(()) @@ -1876,24 +1797,20 @@ impl Database for DatabaseCrdbNats { version: usize, loop_location: Option<&Location>, ) -> WorkflowResult<()> { - self.query(|| async { - sqlx::query(indoc!( - " - INSERT INTO db_workflow.workflow_version_check_events ( - workflow_id, location, version, loop_location - ) - VALUES($1, $2, $3, $4) - RETURNING 1 - ", - )) - .bind(from_workflow_id) - .bind(location) - .bind(version as i64) - .bind(loop_location) - .execute(&mut *self.conn().await?) - .await - .map_err(WorkflowError::Sqlx) - }) + sql_execute!( + [self] + " + INSERT INTO db_workflow.workflow_version_check_events ( + workflow_id, location, version, loop_location + ) + VALUES($1, $2, $3, $4) + RETURNING 1 + ", + from_workflow_id, + location, + version as i64, + loop_location, + ) .await?; Ok(()) diff --git a/packages/common/chirp-workflow/core/src/db/fdb_sqlite_nats/mod.rs b/packages/common/chirp-workflow/core/src/db/fdb_sqlite_nats/mod.rs index f812a3be52..544fc03532 100644 --- a/packages/common/chirp-workflow/core/src/db/fdb_sqlite_nats/mod.rs +++ b/packages/common/chirp-workflow/core/src/db/fdb_sqlite_nats/mod.rs @@ -92,9 +92,9 @@ impl DatabaseFdbSqliteNats { // MARK: Sqlite impl DatabaseFdbSqliteNats { - /// Executes SQL queries and explicitly handles retry errors. + /// Executes queries while explicitly handling txn retry errors. #[tracing::instrument(skip_all)] - async fn query<'a, F, Fut, T>(&self, mut cb: F) -> WorkflowResult + async fn txn<'a, F, Fut, T>(&self, mut cb: F) -> WorkflowResult where F: FnMut() -> Fut, Fut: std::future::Future> + 'a, @@ -108,14 +108,18 @@ impl DatabaseFdbSqliteNats { Err(WorkflowError::Sqlx(err)) => { i += 1; if i > MAX_QUERY_RETRIES { - return Err(WorkflowError::MaxSqlRetries(err)); + return Err(WorkflowError::Sqlx(sqlx::Error::Io(std::io::Error::new( + std::io::ErrorKind::Other, + rivet_pools::utils::sql_query_macros::Error::MaxSqlRetries(err), + )))); } use sqlx::Error::*; match &err { + // Retry all errors with a backoff Database(_) | Io(_) | Tls(_) | Protocol(_) | PoolTimedOut | PoolClosed | WorkerCrashed => { - tracing::warn!(?err, "query retry"); + tracing::warn!(?err, "txn retry"); backoff.tick().await; } // Throw error @@ -1995,49 +1999,44 @@ impl Database for DatabaseFdbSqliteNats { // In the event of an FDB txn retry, we have to delete the previously inserted row if is_retrying.load(Ordering::Relaxed) { - self.query(|| async { - sql_execute!( - [self, &pool] - " - DELETE FROM workflow_signal_events - WHERE location = jsonb(?1) - ", - location, - ) - .await - }) + sql_execute!( + [self, &pool] + " + DELETE FROM workflow_signal_events + WHERE location = jsonb(?1) + ", + location, + ) .await .map_err(|x| fdb::FdbBindingError::CustomError(x.into()))?; } // Insert history event - self.query(|| async { - sql_execute!( - [self, &pool] - " - INSERT INTO workflow_signal_events ( - location, - version, - signal_id, - signal_name, - body, - create_ts, - loop_location - ) - VALUES (jsonb(?1), ?2, ?3, ?4, jsonb(?5), ?6, jsonb(?7)) - ", + sql_execute!( + [self, &pool] + " + INSERT INTO workflow_signal_events ( location, - version as i64, + version, signal_id, - &signal_name, - sqlx::types::Json(&body), - rivet_util::timestamp::now(), - loop_location, + signal_name, + body, + create_ts, + loop_location ) - .await - }) + VALUES (jsonb(?1), ?2, ?3, ?4, jsonb(?5), ?6, jsonb(?7)) + ", + location, + version as i64, + signal_id, + &signal_name, + sqlx::types::Json(&body), + rivet_util::timestamp::now(), + loop_location, + ) .await .map_err(|x| fdb::FdbBindingError::CustomError(x.into()))?; + is_retrying.store(true, Ordering::Relaxed); Ok(Some(SignalData { @@ -2328,26 +2327,23 @@ impl Database for DatabaseFdbSqliteNats { .await?; // Insert history event - self.query(|| async { - sql_execute!( - [self, pool] - " - INSERT INTO workflow_signal_send_events ( - location, version, signal_id, signal_name, body, workflow_id, create_ts, loop_location - ) - VALUES (jsonb(?1), ?2, ?3, ?4, jsonb(?5), ?6, ?7, jsonb(?8)) - ", - location, - version as i64, - signal_id, - signal_name, - sqlx::types::Json(body), - to_workflow_id, - rivet_util::timestamp::now(), - loop_location, + sql_execute!( + [self, pool] + " + INSERT INTO workflow_signal_send_events ( + location, version, signal_id, signal_name, body, workflow_id, create_ts, loop_location ) - .await - }) + VALUES (jsonb(?1), ?2, ?3, ?4, jsonb(?5), ?6, ?7, jsonb(?8)) + ", + location, + version as i64, + signal_id, + signal_name, + sqlx::types::Json(body), + to_workflow_id, + rivet_util::timestamp::now(), + loop_location, + ) .await?; // Block while flushing databases in order ensure listeners have the latest data @@ -2367,17 +2363,14 @@ impl Database for DatabaseFdbSqliteNats { .await { // Undo history if FDB failed - self.query(|| async { - sql_execute!( - [self, pool] - " - DELETE FROM workflow_signal_send_events - WHERE location = jsonb(?1) - ", - location, - ) - .await - }) + sql_execute!( + [self, pool] + " + DELETE FROM workflow_signal_send_events + WHERE location = jsonb(?1) + ", + location, + ) .await?; self.flush_wf_sqlite(from_workflow_id)?; @@ -2424,33 +2417,30 @@ impl Database for DatabaseFdbSqliteNats { .await?; // Insert history event - self.query(|| async { - sql_execute!( - [self, pool] - " - INSERT INTO workflow_sub_workflow_events ( - location, - version, - sub_workflow_id, - sub_workflow_name, - tags, - input, - create_ts, - loop_location - ) - VALUES (jsonb(?1), ?2, ?3, ?4, jsonb(?5), jsonb(?6), ?7, jsonb(?8)) - ", + sql_execute!( + [self, pool] + " + INSERT INTO workflow_sub_workflow_events ( location, - version as i64, + version, sub_workflow_id, sub_workflow_name, tags, - sqlx::types::Json(input), - rivet_util::timestamp::now(), - loop_location, + input, + create_ts, + loop_location ) - .await - }) + VALUES (jsonb(?1), ?2, ?3, ?4, jsonb(?5), jsonb(?6), ?7, jsonb(?8)) + ", + location, + version as i64, + sub_workflow_id, + sub_workflow_name, + tags, + sqlx::types::Json(input), + rivet_util::timestamp::now(), + loop_location, + ) .await?; // Block while flushing databases in order ensure sub workflow have the latest data @@ -2479,17 +2469,14 @@ impl Database for DatabaseFdbSqliteNats { Ok(workflow_id) => Ok(workflow_id), Err(err) => { // Undo history if FDB failed - self.query(|| async { - sql_execute!( - [self, pool] - " - DELETE FROM workflow_sub_workflow_events - WHERE location = jsonb(?1) - ", - location, - ) - .await - }) + sql_execute!( + [self, pool] + " + DELETE FROM workflow_sub_workflow_events + WHERE location = jsonb(?1) + ", + location, + ) .await?; self.flush_wf_sqlite(workflow_id)?; @@ -2615,39 +2602,36 @@ impl Database for DatabaseFdbSqliteNats { match res { Ok(output) => { - self.query(|| async { - sql_execute!( - [self, pool] - " - INSERT INTO workflow_activity_events ( - location, - version, - activity_name, - input_hash, - input, - output, - create_ts, - loop_location - ) - VALUES (jsonb(?1), ?2, ?3, ?4, jsonb(?5), jsonb(?6), ?7, jsonb(?8)) - ON CONFLICT (location) DO UPDATE - SET output = EXCLUDED.output - ", + sql_execute!( + [self, pool] + " + INSERT INTO workflow_activity_events ( location, - version as i64, - &event_id.name, - input_hash.as_slice(), - sqlx::types::Json(input), - sqlx::types::Json(output), + version, + activity_name, + input_hash, + input, + output, create_ts, - loop_location, + loop_location ) - .await - }) + VALUES (jsonb(?1), ?2, ?3, ?4, jsonb(?5), jsonb(?6), ?7, jsonb(?8)) + ON CONFLICT (location) DO UPDATE + SET output = EXCLUDED.output + ", + location, + version as i64, + &event_id.name, + input_hash.as_slice(), + sqlx::types::Json(input), + sqlx::types::Json(output), + create_ts, + loop_location, + ) .await?; } Err(err) => { - self.query(|| async { + self.txn(|| async { let mut conn = pool.conn().await?; let mut tx = conn.begin().await?; @@ -2730,25 +2714,22 @@ impl Database for DatabaseFdbSqliteNats { .sqlite(sqlite::db_name_internal(from_workflow_id), false) .await?; - self.query(|| async { - sql_execute!( - [self, pool] - " - INSERT INTO workflow_message_send_events ( - location, version, tags, message_name, body, create_ts, loop_location - ) - VALUES (jsonb(?1), ?2, jsonb(?3), ?4, jsonb(?5), ?6, jsonb(?7)) - ", - location, - version as i64, - tags, - message_name, - sqlx::types::Json(body), - rivet_util::timestamp::now(), - loop_location, + sql_execute!( + [self, pool] + " + INSERT INTO workflow_message_send_events ( + location, version, tags, message_name, body, create_ts, loop_location ) - .await - }) + VALUES (jsonb(?1), ?2, jsonb(?3), ?4, jsonb(?5), ?6, jsonb(?7)) + ", + location, + version as i64, + tags, + message_name, + sqlx::types::Json(body), + rivet_util::timestamp::now(), + loop_location, + ) .await?; self.flush_wf_sqlite(from_workflow_id)?; @@ -2773,7 +2754,7 @@ impl Database for DatabaseFdbSqliteNats { .sqlite(sqlite::db_name_internal(workflow_id), false) .await?; - self.query(|| async { + self.txn(|| async { let mut conn = pool.conn().await?; let mut tx = conn.begin().await?; @@ -3062,24 +3043,21 @@ impl Database for DatabaseFdbSqliteNats { .sqlite(sqlite::db_name_internal(from_workflow_id), false) .await?; - self.query(|| async { - sql_execute!( - [self, pool] - " - INSERT INTO workflow_sleep_events ( - location, version, deadline_ts, create_ts, state, loop_location - ) - VALUES (jsonb(?1), ?2, ?3, ?4, ?5, jsonb(?6)) - ", - location, - version as i64, - deadline_ts, - rivet_util::timestamp::now(), - SleepState::Normal as i64, - loop_location, + sql_execute!( + [self, pool] + " + INSERT INTO workflow_sleep_events ( + location, version, deadline_ts, create_ts, state, loop_location ) - .await - }) + VALUES (jsonb(?1), ?2, ?3, ?4, ?5, jsonb(?6)) + ", + location, + version as i64, + deadline_ts, + rivet_util::timestamp::now(), + SleepState::Normal as i64, + loop_location, + ) .await?; Ok(()) @@ -3097,19 +3075,16 @@ impl Database for DatabaseFdbSqliteNats { .sqlite(sqlite::db_name_internal(from_workflow_id), false) .await?; - self.query(|| async { - sql_execute!( - [self, pool] - " - UPDATE workflow_sleep_events - SET state = ?1 - WHERE location = jsonb(?2) - ", - state as i64, - location, - ) - .await - }) + sql_execute!( + [self, pool] + " + UPDATE workflow_sleep_events + SET state = ?1 + WHERE location = jsonb(?2) + ", + state as i64, + location, + ) .await?; Ok(()) @@ -3128,22 +3103,19 @@ impl Database for DatabaseFdbSqliteNats { .sqlite(sqlite::db_name_internal(from_workflow_id), false) .await?; - self.query(|| async { - sql_execute!( - [self, pool] - " - INSERT INTO workflow_branch_events ( - location, version, create_ts, loop_location - ) - VALUES (jsonb(?1), ?2, ?3, jsonb(?4)) - ", - location, - version as i64, - rivet_util::timestamp::now(), - loop_location, + sql_execute!( + [self, pool] + " + INSERT INTO workflow_branch_events ( + location, version, create_ts, loop_location ) - .await - }) + VALUES (jsonb(?1), ?2, ?3, jsonb(?4)) + ", + location, + version as i64, + rivet_util::timestamp::now(), + loop_location, + ) .await?; Ok(()) @@ -3163,23 +3135,20 @@ impl Database for DatabaseFdbSqliteNats { .sqlite(sqlite::db_name_internal(from_workflow_id), false) .await?; - self.query(|| async { - sql_execute!( - [self, pool] - " - INSERT INTO workflow_removed_events ( - location, event_type, event_name, create_ts, loop_location - ) - VALUES (jsonb(?1), ?2, ?3, ?4, jsonb(?5)) - ", - location, - event_type as i64, - event_name, - rivet_util::timestamp::now(), - loop_location, + sql_execute!( + [self, pool] + " + INSERT INTO workflow_removed_events ( + location, event_type, event_name, create_ts, loop_location ) - .await - }) + VALUES (jsonb(?1), ?2, ?3, ?4, jsonb(?5)) + ", + location, + event_type as i64, + event_name, + rivet_util::timestamp::now(), + loop_location, + ) .await?; Ok(()) @@ -3198,22 +3167,19 @@ impl Database for DatabaseFdbSqliteNats { .sqlite(sqlite::db_name_internal(from_workflow_id), false) .await?; - self.query(|| async { - sql_execute!( - [self, pool] - " - INSERT INTO workflow_version_check_events ( - location, version, create_ts, loop_location - ) - VALUES (jsonb(?1), ?2, ?3, jsonb(?4)) - ", - location, - version as i64, - rivet_util::timestamp::now(), - loop_location, + sql_execute!( + [self, pool] + " + INSERT INTO workflow_version_check_events ( + location, version, create_ts, loop_location ) - .await - }) + VALUES (jsonb(?1), ?2, ?3, jsonb(?4)) + ", + location, + version as i64, + rivet_util::timestamp::now(), + loop_location, + ) .await?; Ok(()) diff --git a/packages/common/chirp-workflow/core/src/error.rs b/packages/common/chirp-workflow/core/src/error.rs index 541f2ed1c4..7a20f286a8 100644 --- a/packages/common/chirp-workflow/core/src/error.rs +++ b/packages/common/chirp-workflow/core/src/error.rs @@ -9,10 +9,6 @@ use crate::ctx::common::RETRY_TIMEOUT_MS; pub type WorkflowResult = Result; -/// Throwing this will eject from the workflow scope back in to the engine. -/// -/// This error should not be touched by the user and is only intended to be handled by the workflow -/// engine. #[derive(thiserror::Error, Debug)] pub enum WorkflowError { #[error("workflow failure: {0:?}")] @@ -151,9 +147,6 @@ pub enum WorkflowError { #[error("fdb error: {0}")] Fdb(#[from] fdb::FdbBindingError), - #[error("max sql retries, last error: {0}")] - MaxSqlRetries(sqlx::Error), - #[error("pools error: {0}")] Pools(#[from] rivet_pools::Error), diff --git a/packages/common/pools/src/lib.rs b/packages/common/pools/src/lib.rs index f6ac3dd8c4..4fc463834b 100644 --- a/packages/common/pools/src/lib.rs +++ b/packages/common/pools/src/lib.rs @@ -10,3 +10,7 @@ pub use crate::{ db::clickhouse::ClickHousePool, db::crdb::CrdbPool, db::fdb::FdbPool, db::nats::NatsPool, db::redis::RedisPool, db::sqlite::SqlitePool, error::Error, pools::Pools, }; + +// Re-export for macros +#[doc(hidden)] +pub use rivet_util as __rivet_util; diff --git a/packages/common/pools/src/utils/sql_query_macros.rs b/packages/common/pools/src/utils/sql_query_macros.rs index 6f80e4f5c5..2a67aa02c1 100644 --- a/packages/common/pools/src/utils/sql_query_macros.rs +++ b/packages/common/pools/src/utils/sql_query_macros.rs @@ -1,3 +1,14 @@ +/// Base retry for query retry backoff. +pub const QUERY_RETRY_MS: usize = 100; +/// Maximum times a query is retried. +pub const MAX_QUERY_RETRIES: usize = 6; + +#[derive(thiserror::Error, Debug)] +pub enum Error { + #[error("max sql retries, last error: {0}")] + MaxSqlRetries(sqlx::Error), +} + lazy_static::lazy_static! { /// Rate limit used to limit creating a stampede of connections to the database. pub static ref CONN_ACQUIRE_RATE_LIMIT: governor::RateLimiter< @@ -132,22 +143,61 @@ macro_rules! __sql_query { async { use sqlx::Acquire; - let query = sqlx::query($crate::__opt_indoc!($sql)) - $( - .bind($bind) - )*; - // Acquire connection $crate::__sql_query_metrics_acquire!(_acquire); let driver = $driver; let mut conn = $crate::__sql_acquire!($ctx, driver); - // Execute query $crate::__sql_query_metrics_start!($ctx, execute, _acquire, _start); - let res = query.execute(&mut *conn).await.map_err(Into::::into); + + let mut backoff = $crate::__rivet_util::Backoff::new( + 4, + None, + $crate::utils::sql_query_macros::QUERY_RETRY_MS, + 50 + ); + let mut i = 0; + + // Retry loop + let res = loop { + let query = sqlx::query($crate::__opt_indoc!($sql)) + $( + .bind($bind) + )*; + + match query.execute(&mut *conn).await { + Err(err) => { + i += 1; + if i > $crate::utils::sql_query_macros::MAX_QUERY_RETRIES { + break Err( + sqlx::Error::Io( + std::io::Error::new( + std::io::ErrorKind::Other, + $crate::utils::sql_query_macros::Error::MaxSqlRetries(err), + ) + ) + ); + } + + use sqlx::Error::*; + match &err { + // Retry other errors with a backoff + Database(_) | Io(_) | Tls(_) | Protocol(_) | PoolTimedOut | PoolClosed + | WorkerCrashed => { + tracing::warn!(?err, "query retry ({i}/{})", $crate::utils::sql_query_macros::MAX_QUERY_RETRIES); + backoff.tick().await; + } + // Throw error + _ => break Err(err), + } + } + x => break x, + } + }; + $crate::__sql_query_metrics_finish!($ctx, execute, _start); - res + res.map_err(Into::::into) } .instrument(tracing::info_span!("sql_query")) }; diff --git a/packages/core/services/build/ops/create/src/lib.rs b/packages/core/services/build/ops/create/src/lib.rs index 8514654c26..d0bc01c1d9 100644 --- a/packages/core/services/build/ops/create/src/lib.rs +++ b/packages/core/services/build/ops/create/src/lib.rs @@ -129,7 +129,7 @@ async fn handle( env_id, upload_id, &ctx.display_name, - image_tag, + &image_tag, ctx.ts(), kind as i32, compression as i32, diff --git a/packages/core/services/build/src/ops/create.rs b/packages/core/services/build/src/ops/create.rs index c30d11ed59..50d2a6cf35 100644 --- a/packages/core/services/build/src/ops/create.rs +++ b/packages/core/services/build/src/ops/create.rs @@ -168,7 +168,7 @@ pub async fn get(ctx: &OperationCtx, input: &Input) -> GlobalResult { env_id, upload_id, &input.display_name, - image_tag, + &image_tag, ctx.ts(), input.kind as i32, input.compression as i32, diff --git a/packages/core/services/cdn/ops/version-publish/src/lib.rs b/packages/core/services/cdn/ops/version-publish/src/lib.rs index e5483106eb..e0277469d3 100644 --- a/packages/core/services/cdn/ops/version-publish/src/lib.rs +++ b/packages/core/services/cdn/ops/version-publish/src/lib.rs @@ -90,11 +90,11 @@ async fn handle( ) SELECT * FROM UNNEST($1, $2, $3, $4, $5) ", - version_ids, - globs, - priorities, - header_names, - header_values, + &version_ids, + &globs, + &priorities, + &header_names, + &header_values, ) .await?; } diff --git a/packages/core/services/cluster/src/workflows/datacenter/mod.rs b/packages/core/services/cluster/src/workflows/datacenter/mod.rs index 401a7a80ab..c08505a0ba 100644 --- a/packages/core/services/cluster/src/workflows/datacenter/mod.rs +++ b/packages/core/services/cluster/src/workflows/datacenter/mod.rs @@ -369,8 +369,8 @@ async fn update_db(ctx: &ActivityCtx, input: &UpdateDbInput) -> GlobalResult<()> input.datacenter_id, serde_json::to_string(&pools)?, input.prebakes_enabled, - gph_dns_parent, - gph_static + &gph_dns_parent, + &gph_static ) .await?; diff --git a/packages/core/services/mm-config/ops/version-publish/src/lib.rs b/packages/core/services/mm-config/ops/version-publish/src/lib.rs index 72cb91ce43..7a41e4ab46 100644 --- a/packages/core/services/mm-config/ops/version-publish/src/lib.rs +++ b/packages/core/services/mm-config/ops/version-publish/src/lib.rs @@ -40,7 +40,7 @@ async fn handle( ) VALUES ($1, $2, $3)", version_id, - captcha_buf, + &captcha_buf, util_mm::version_migrations::all(), ) .await?; diff --git a/packages/core/services/token/ops/create/src/lib.rs b/packages/core/services/token/ops/create/src/lib.rs index 4fce17dc00..0850d06b79 100644 --- a/packages/core/services/token/ops/create/src/lib.rs +++ b/packages/core/services/token/ops/create/src/lib.rs @@ -145,7 +145,7 @@ async fn handle( [ctx] "INSERT INTO db_token.sessions (session_id, entitlements, entitlement_tags, exp) VALUES ($1, $2, $3, $4)", new_session_id, - ent_bufs, + &ent_bufs, &tags, ctx.ts() + token_config.ttl, ).await?; diff --git a/packages/core/services/upload/ops/prepare/src/lib.rs b/packages/core/services/upload/ops/prepare/src/lib.rs index 89c85c2878..a686a8eb82 100644 --- a/packages/core/services/upload/ops/prepare/src/lib.rs +++ b/packages/core/services/upload/ops/prepare/src/lib.rs @@ -113,9 +113,9 @@ async fn handle( user_id, // Hardcoded to AWS since we don't use this feature anymore backend::upload::Provider::Aws as i64, - paths, - mimes, - content_lengths, + &paths, + &mimes, + &content_lengths, ) .await?; @@ -165,8 +165,8 @@ async fn handle( upload_files.path = v.path ", upload_id, - multipart_paths, - multipart_upload_ids, + &multipart_paths, + &multipart_upload_ids, ) .await .map_err(Into::::into) diff --git a/packages/edge/services/pegboard/src/workflows/actor/runtime.rs b/packages/edge/services/pegboard/src/workflows/actor/runtime.rs index 7943d0804a..3fcec4197b 100644 --- a/packages/edge/services/pegboard/src/workflows/actor/runtime.rs +++ b/packages/edge/services/pegboard/src/workflows/actor/runtime.rs @@ -82,7 +82,7 @@ async fn update_client(ctx: &ActivityCtx, input: &UpdateClientInput) -> GlobalRe ", input.client_id, input.client_workflow_id, - client_wan_hostname, + &client_wan_hostname, ) .await?;