Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 112 additions & 3 deletions libsql/src/sync/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,73 @@ async fn test_sync_context_retry_on_error() {
assert_eq!(server.frame_count(), 1);
}

#[tokio::test]
async fn test_bootstrap_db_downloads_export() {
let server = MockServer::start();
let temp_dir = tempdir().unwrap();
let db_path = temp_dir.path().join("bootstrap.db");

// Seed metadata so SyncContext can be constructed (generation=1)
gen_metadata_file(&db_path, 3278479626, 0, 0, 1);

let mut sync_ctx = SyncContext::new(
server.connector(),
db_path.to_str().unwrap().to_string(),
server.url(),
None,
None,
)
.await
.unwrap();


let _ = std::fs::remove_file(&db_path);
let _ = std::fs::remove_file(format!("{}-info", db_path.to_str().unwrap()));

// Bootstrap should fetch /info and then /export/{generation}
crate::sync::bootstrap_db(&mut sync_ctx).await.unwrap();

assert!(std::path::Path::new(&db_path).exists());
assert!(std::path::Path::new(&format!("{}-info", db_path.to_str().unwrap())).exists());

assert_eq!(sync_ctx.durable_generation(), 1);
assert_eq!(sync_ctx.durable_frame_num(), 0);

assert!(server.request_count() >= 2);
}

#[tokio::test]
async fn test_bootstrap_db_is_idempotent() {
let server = MockServer::start();
let temp_dir = tempdir().unwrap();
let db_path = temp_dir.path().join("bootstrap2.db");


gen_metadata_file(&db_path, 3278479626, 0, 0, 1);

let mut sync_ctx = SyncContext::new(
server.connector(),
db_path.to_str().unwrap().to_string(),
server.url(),
None,
None,
)
.await
.unwrap();

let _ = std::fs::remove_file(&db_path);
let _ = std::fs::remove_file(format!("{}-info", db_path.to_str().unwrap()));


crate::sync::bootstrap_db(&mut sync_ctx).await.unwrap();
let first_requests = server.request_count();

// Second bootstrap should be a no-op (no new network calls)
crate::sync::bootstrap_db(&mut sync_ctx).await.unwrap();
let second_requests = server.request_count();
assert_eq!(first_requests, second_requests);
}

#[test]
fn test_hash_verification() {
let mut metadata = MetadataJson {
Expand Down Expand Up @@ -328,12 +395,14 @@ impl Service<http::Uri> for MockConnector {
}
}

#[allow(dead_code)]
struct MockServer {
url: String,
frame_count: Arc<AtomicU32>,
connector: ConnectorService,
return_error: Arc<AtomicBool>,
request_count: Arc<AtomicU32>,
export_bytes: Arc<Vec<u8>>, // bytes returned by /export/{generation}
}

impl MockServer {
Expand All @@ -342,6 +411,25 @@ impl MockServer {
let return_error = Arc::new(AtomicBool::new(false));
let request_count = Arc::new(AtomicU32::new(0));

let export_bytes: Arc<Vec<u8>> = {
use crate::local::Database;
use crate::database::OpenFlags;
use std::fs;
use tempfile::NamedTempFile;

let tmp = NamedTempFile::new().expect("temp file for export db");
let path = tmp.path().to_path_buf();
let db = Database::open(path.to_str().unwrap().to_string(), OpenFlags::default())
.expect("open export db");
let conn = db.connect().expect("connect export db");

let _ = conn.query("CREATE TABLE IF NOT EXISTS t(x INTEGER);", crate::params::Params::None);
drop(conn);
drop(db);
let bytes = fs::read(&path).expect("read export db bytes");
Arc::new(bytes)
};

// Create the mock connector with Some(client_stream)
let (tx, mut rx) = tokio::sync::mpsc::channel(1);
let mock_connector = MockConnector { tx };
Expand All @@ -353,18 +441,21 @@ impl MockServer {
connector,
return_error: return_error.clone(),
request_count: request_count.clone(),
export_bytes: export_bytes.clone(),
};

// Spawn the server handler
let frame_count_clone = frame_count.clone();
let return_error_clone = return_error.clone();
let request_count_clone = request_count.clone();
let export_bytes_clone = export_bytes.clone();

tokio::spawn(async move {
while let Some(server_stream) = rx.recv().await {
let frame_count_clone = frame_count_clone.clone();
let return_error_clone = return_error_clone.clone();
let request_count_clone = request_count_clone.clone();
let export_bytes_clone = export_bytes_clone.clone();

tokio::spawn(async move {
use hyper::server::conn::Http;
Expand All @@ -377,6 +468,7 @@ impl MockServer {
let frame_count = frame_count_clone.clone();
let return_error = return_error_clone.clone();
let request_count = request_count_clone.clone();
let export_bytes = export_bytes_clone.clone();
async move {
request_count.fetch_add(1, Ordering::SeqCst);
if return_error.load(Ordering::SeqCst) {
Expand All @@ -388,9 +480,9 @@ impl MockServer {
);
}

let current_count = frame_count.fetch_add(1, Ordering::SeqCst);

if req.uri().path().contains("/sync/") {
// Count only sync requests as frames to keep tests stable.
let current_count = frame_count.fetch_add(1, Ordering::SeqCst);
// Return the max_frame_no that has been accepted
let response = serde_json::json!({
"status": "ok",
Expand All @@ -404,6 +496,23 @@ impl MockServer {
.body(Body::from(response.to_string()))
.unwrap(),
)
} else if req.uri().path().eq("/info") {
let response = serde_json::json!({
"current_generation": 1
});
Ok::<_, hyper::Error>(
http::Response::builder()
.status(200)
.body(Body::from(response.to_string()))
.unwrap(),
)
} else if req.uri().path().starts_with("/export/") {
Ok::<_, hyper::Error>(
http::Response::builder()
.status(200)
.body(Body::from(export_bytes.as_ref().clone()))
.unwrap(),
)
} else {
Ok(http::Response::builder()
.status(404)
Expand Down Expand Up @@ -489,4 +598,4 @@ fn gen_metadata_file(db_path: &Path, hash: u32, version: u32, durable_frame_num:
.as_bytes(),
)
.unwrap();
}
}
Loading