diff --git a/libsql/src/sync/test.rs b/libsql/src/sync/test.rs index c7fb0b65f3..104df12a31 100644 --- a/libsql/src/sync/test.rs +++ b/libsql/src/sync/test.rs @@ -251,6 +251,73 @@ async fn test_sync_context_retry_on_error() { assert_eq!(server.frame_count(), 1); } +#[tokio::test] +async fn test_bootstrap_db_downloads_export() { + let server = MockServer::start(); + let temp_dir = tempdir().unwrap(); + let db_path = temp_dir.path().join("bootstrap.db"); + + // Seed metadata so SyncContext can be constructed (generation=1) + gen_metadata_file(&db_path, 3278479626, 0, 0, 1); + + let mut sync_ctx = SyncContext::new( + server.connector(), + db_path.to_str().unwrap().to_string(), + server.url(), + None, + None, + ) + .await + .unwrap(); + + + let _ = std::fs::remove_file(&db_path); + let _ = std::fs::remove_file(format!("{}-info", db_path.to_str().unwrap())); + + // Bootstrap should fetch /info and then /export/{generation} + crate::sync::bootstrap_db(&mut sync_ctx).await.unwrap(); + + assert!(std::path::Path::new(&db_path).exists()); + assert!(std::path::Path::new(&format!("{}-info", db_path.to_str().unwrap())).exists()); + + assert_eq!(sync_ctx.durable_generation(), 1); + assert_eq!(sync_ctx.durable_frame_num(), 0); + + assert!(server.request_count() >= 2); +} + +#[tokio::test] +async fn test_bootstrap_db_is_idempotent() { + let server = MockServer::start(); + let temp_dir = tempdir().unwrap(); + let db_path = temp_dir.path().join("bootstrap2.db"); + + + gen_metadata_file(&db_path, 3278479626, 0, 0, 1); + + let mut sync_ctx = SyncContext::new( + server.connector(), + db_path.to_str().unwrap().to_string(), + server.url(), + None, + None, + ) + .await + .unwrap(); + + let _ = std::fs::remove_file(&db_path); + let _ = std::fs::remove_file(format!("{}-info", db_path.to_str().unwrap())); + + + crate::sync::bootstrap_db(&mut sync_ctx).await.unwrap(); + let first_requests = server.request_count(); + + // Second bootstrap should be a no-op (no new network calls) + crate::sync::bootstrap_db(&mut sync_ctx).await.unwrap(); + let second_requests = server.request_count(); + assert_eq!(first_requests, second_requests); +} + #[test] fn test_hash_verification() { let mut metadata = MetadataJson { @@ -328,12 +395,14 @@ impl Service for MockConnector { } } +#[allow(dead_code)] struct MockServer { url: String, frame_count: Arc, connector: ConnectorService, return_error: Arc, request_count: Arc, + export_bytes: Arc>, // bytes returned by /export/{generation} } impl MockServer { @@ -342,6 +411,25 @@ impl MockServer { let return_error = Arc::new(AtomicBool::new(false)); let request_count = Arc::new(AtomicU32::new(0)); + let export_bytes: Arc> = { + use crate::local::Database; + use crate::database::OpenFlags; + use std::fs; + use tempfile::NamedTempFile; + + let tmp = NamedTempFile::new().expect("temp file for export db"); + let path = tmp.path().to_path_buf(); + let db = Database::open(path.to_str().unwrap().to_string(), OpenFlags::default()) + .expect("open export db"); + let conn = db.connect().expect("connect export db"); + + let _ = conn.query("CREATE TABLE IF NOT EXISTS t(x INTEGER);", crate::params::Params::None); + drop(conn); + drop(db); + let bytes = fs::read(&path).expect("read export db bytes"); + Arc::new(bytes) + }; + // Create the mock connector with Some(client_stream) let (tx, mut rx) = tokio::sync::mpsc::channel(1); let mock_connector = MockConnector { tx }; @@ -353,18 +441,21 @@ impl MockServer { connector, return_error: return_error.clone(), request_count: request_count.clone(), + export_bytes: export_bytes.clone(), }; // Spawn the server handler let frame_count_clone = frame_count.clone(); let return_error_clone = return_error.clone(); let request_count_clone = request_count.clone(); + let export_bytes_clone = export_bytes.clone(); tokio::spawn(async move { while let Some(server_stream) = rx.recv().await { let frame_count_clone = frame_count_clone.clone(); let return_error_clone = return_error_clone.clone(); let request_count_clone = request_count_clone.clone(); + let export_bytes_clone = export_bytes_clone.clone(); tokio::spawn(async move { use hyper::server::conn::Http; @@ -377,6 +468,7 @@ impl MockServer { let frame_count = frame_count_clone.clone(); let return_error = return_error_clone.clone(); let request_count = request_count_clone.clone(); + let export_bytes = export_bytes_clone.clone(); async move { request_count.fetch_add(1, Ordering::SeqCst); if return_error.load(Ordering::SeqCst) { @@ -388,9 +480,9 @@ impl MockServer { ); } - let current_count = frame_count.fetch_add(1, Ordering::SeqCst); - if req.uri().path().contains("/sync/") { + // Count only sync requests as frames to keep tests stable. + let current_count = frame_count.fetch_add(1, Ordering::SeqCst); // Return the max_frame_no that has been accepted let response = serde_json::json!({ "status": "ok", @@ -404,6 +496,23 @@ impl MockServer { .body(Body::from(response.to_string())) .unwrap(), ) + } else if req.uri().path().eq("/info") { + let response = serde_json::json!({ + "current_generation": 1 + }); + Ok::<_, hyper::Error>( + http::Response::builder() + .status(200) + .body(Body::from(response.to_string())) + .unwrap(), + ) + } else if req.uri().path().starts_with("/export/") { + Ok::<_, hyper::Error>( + http::Response::builder() + .status(200) + .body(Body::from(export_bytes.as_ref().clone())) + .unwrap(), + ) } else { Ok(http::Response::builder() .status(404) @@ -489,4 +598,4 @@ fn gen_metadata_file(db_path: &Path, hash: u32, version: u32, durable_frame_num: .as_bytes(), ) .unwrap(); -} \ No newline at end of file +}