diff --git a/Cargo.lock b/Cargo.lock index 375d0e8a1..7fbd99098 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1269,6 +1269,14 @@ dependencies = [ "litrs", ] +[[package]] +name = "domain-validator" +version = "0.1.0" +dependencies = [ + "async-trait", + "thiserror 2.0.14", +] + [[package]] name = "dotenvy" version = "0.15.7" @@ -1513,6 +1521,19 @@ version = "0.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99" +[[package]] +name = "fly-api" +version = "0.1.0" +dependencies = [ + "async-trait", + "domain-validator", + "reqwest", + "serde", + "serde_json", + "thiserror 2.0.14", + "tokio", +] + [[package]] name = "fnv" version = "1.0.7" diff --git a/Cargo.toml b/Cargo.toml index d4886c191..bcae47a15 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,10 +11,12 @@ members = [ "crates/breez-sdk/bindings", "crates/breez-sdk/cli", "crates/breez-sdk/common", - "crates/breez-sdk/core", + "crates/breez-sdk/core", "crates/breez-sdk/lnurl-models", "crates/breez-sdk/breez-itest", "crates/breez-sdk/wasm", + "crates/domain-validator", + "crates/fly-api", "crates/macros", "crates/macro_test", "crates/spark", diff --git a/crates/breez-sdk/lnurl/Cargo.lock b/crates/breez-sdk/lnurl/Cargo.lock index 49bd0981d..79a610a36 100644 --- a/crates/breez-sdk/lnurl/Cargo.lock +++ b/crates/breez-sdk/lnurl/Cargo.lock @@ -1011,6 +1011,14 @@ dependencies = [ "litrs", ] +[[package]] +name = "domain-validator" +version = "0.1.0" +dependencies = [ + "async-trait", + "thiserror 2.0.16", +] + [[package]] name = "dotenvy" version = "0.15.7" @@ -1215,6 +1223,19 @@ dependencies = [ "spin", ] +[[package]] +name = "fly-api" +version = "0.1.0" +dependencies = [ + "async-trait", + "domain-validator", + "reqwest", + "serde", + "serde_json", + "thiserror 2.0.16", + "tokio", +] + [[package]] name = "fnv" version = "1.0.7" @@ -2210,7 +2231,9 @@ dependencies = [ "base64", "bitcoin", "clap", + "domain-validator", "figment", + "fly-api", "hex", "lightning-invoice", "lnurl-models", diff --git a/crates/breez-sdk/lnurl/Cargo.toml b/crates/breez-sdk/lnurl/Cargo.toml index bce1b3dee..f51026783 100644 --- a/crates/breez-sdk/lnurl/Cargo.toml +++ b/crates/breez-sdk/lnurl/Cargo.toml @@ -15,6 +15,8 @@ bitcoin = { version = "0.32.6", features = ["serde"] } clap = { version = "4.5.40", features = ["derive"] } figment = { version = "0.10.19", features = ["env", "toml"] } hex = "0.4.3" +domain-validator = { path = "../../domain-validator" } +fly-api = { path = "../../fly-api" } lightning-invoice = { version = "0.33.2", features = ["std"] } lnurl-models = { path = "../lnurl-models" } nostr = { version = "0.43.1", default-features = false, features = ["std", "nip57"] } diff --git a/crates/breez-sdk/lnurl/src/main.rs b/crates/breez-sdk/lnurl/src/main.rs index 552df90f3..b458ad806 100644 --- a/crates/breez-sdk/lnurl/src/main.rs +++ b/crates/breez-sdk/lnurl/src/main.rs @@ -9,10 +9,12 @@ use axum::{ }; use base64::{Engine, prelude::BASE64_STANDARD}; use clap::Parser; +use domain_validator::ListDomainValidator; use figment::{ Figment, providers::{Env, Format, Serialized, Toml}, }; +use fly_api::FlyDomainValidator; use serde::{Deserialize, Serialize}; use spark::operator::rpc::DefaultConnectionManager; use spark::session_manager::InMemorySessionManager; @@ -97,6 +99,16 @@ struct Args { /// If set, the server will use this certificate to validate api keys. #[arg(long)] pub ca_cert: Option, + + /// Fly.io app name for certificate-based domain validation. + /// If set along with --fly-api-token, enables Fly.io certificate validation. + #[arg(long)] + pub fly_app_name: Option, + + /// Fly.io API token for certificate-based domain validation. + /// If set along with --fly-app-name, enables Fly.io certificate validation. + #[arg(long)] + pub fly_api_token: Option, } #[tokio::main] @@ -191,12 +203,17 @@ where ) .await?, ); - - let domains = args - .domains - .split(',') - .map(|d| d.trim().to_lowercase()) - .collect(); + let domain_validator: Arc = + if let (Some(app_name), Some(api_token)) = (args.fly_app_name, args.fly_api_token) { + Arc::new(FlyDomainValidator::new(app_name, api_token)) + } else { + let domains = args + .domains + .split(',') + .map(|d| d.trim().to_lowercase()) + .collect(); + Arc::new(ListDomainValidator::new(domains)) + }; let ca_cert = args .ca_cert @@ -250,7 +267,7 @@ where min_sendable: args.min_sendable, max_sendable: args.max_sendable, include_spark_address: args.include_spark_address, - domains, + domain_validator: domain_validator, nostr_keys, ca_cert, connection_manager, diff --git a/crates/breez-sdk/lnurl/src/routes.rs b/crates/breez-sdk/lnurl/src/routes.rs index 702470e91..4d195635a 100644 --- a/crates/breez-sdk/lnurl/src/routes.rs +++ b/crates/breez-sdk/lnurl/src/routes.rs @@ -95,9 +95,10 @@ where Extension(state): Extension>, ) -> Result, (StatusCode, Json)> { let username = sanitize_username(&identifier); + let domain = host.trim().to_lowercase(); let user = state .db - .get_user_by_name(&sanitize_domain(&state, &host)?, &username) + .get_user_by_name(&domain, &username) .await .map_err(|e| { error!("failed to execute query: {}", e); @@ -127,8 +128,15 @@ where Json(Value::String("description too long".into())), )); } + let domain = host.trim().to_lowercase(); + + // Use domain validator from state + if let Err(e) = state.domain_validator.validate_domain(&domain).await { + warn!("domain not allowed for registration: {} - {}", domain, e); + return Err((StatusCode::NOT_FOUND, Json(Value::String(String::new())))); + } let user = User { - domain: sanitize_domain(&state, &host)?, + domain, pubkey: pubkey.to_string(), name: username, description: payload.description, @@ -167,9 +175,10 @@ where let username = sanitize_username(&payload.username); let pubkey = validate(&pubkey, &payload.signature, &username, &state).await?; + let domain = host.trim().to_lowercase(); state .db - .delete_user(&sanitize_domain(&state, &host)?, &pubkey.to_string()) + .delete_user(&domain, &pubkey.to_string()) .await .map_err(|e| { error!("failed to execute query: {}", e); @@ -190,9 +199,10 @@ where ) -> Result, (StatusCode, Json)> { let pubkey = validate(&pubkey, &payload.signature, &pubkey, &state).await?; + let domain = host.trim().to_lowercase(); let user = state .db - .get_user_by_pubkey(&sanitize_domain(&state, &host)?, &pubkey.to_string()) + .get_user_by_pubkey(&domain, &pubkey.to_string()) .await .map_err(|e| { error!("failed to execute query: {}", e); @@ -229,9 +239,10 @@ where } let username = sanitize_username(&identifier); + let domain = host.trim().to_lowercase(); let user = state .db - .get_user_by_name(&sanitize_domain(&state, &host)?, &username) + .get_user_by_name(&domain, &username) .await .map_err(|e| { error!("failed to execute query: {}", e); @@ -273,7 +284,7 @@ where } let username = sanitize_username(&identifier); - let domain = sanitize_domain(&state, &host)?; + let domain = host.trim().to_lowercase(); let user = state .db .get_user_by_name(&domain, &username) @@ -583,15 +594,3 @@ fn lnurl_error(message: &str) -> (StatusCode, Json) { )), ) } - -fn sanitize_domain( - state: &State, - domain: &str, -) -> Result)> { - let domain = domain.trim().to_lowercase(); - if !state.domains.contains(&domain) { - warn!("domain not allowed: {}", domain); - return Err((StatusCode::NOT_FOUND, Json(Value::String(String::new())))); - } - Ok(domain) -} diff --git a/crates/breez-sdk/lnurl/src/state.rs b/crates/breez-sdk/lnurl/src/state.rs index 974913f5c..65715f5cd 100644 --- a/crates/breez-sdk/lnurl/src/state.rs +++ b/crates/breez-sdk/lnurl/src/state.rs @@ -1,9 +1,9 @@ -use spark::operator::OperatorConfig; use spark::operator::rpc::ConnectionManager; +use spark::operator::OperatorConfig; use spark::session_manager::InMemorySessionManager; use spark::ssp::ServiceProvider; use spark_wallet::DefaultSigner; -use std::{collections::HashSet, sync::Arc}; +use std::sync::Arc; use tokio::sync::Mutex; pub struct State { @@ -13,7 +13,7 @@ pub struct State { pub min_sendable: u64, pub max_sendable: u64, pub include_spark_address: bool, - pub domains: HashSet, + pub domain_validator: Arc, pub nostr_keys: Option, pub ca_cert: Option>, pub connection_manager: Arc, @@ -21,7 +21,7 @@ pub struct State { pub signer: Arc, pub session_manager: Arc, pub service_provider: Arc, - pub subscribed_keys: Arc>>, + pub subscribed_keys: Arc>>, } impl Clone for State @@ -36,7 +36,7 @@ where min_sendable: self.min_sendable, max_sendable: self.max_sendable, include_spark_address: self.include_spark_address, - domains: self.domains.clone(), + domain_validator: Arc::clone(&self.domain_validator), nostr_keys: self.nostr_keys.clone(), ca_cert: self.ca_cert.clone(), connection_manager: self.connection_manager.clone(), diff --git a/crates/domain-validator/Cargo.toml b/crates/domain-validator/Cargo.toml new file mode 100644 index 000000000..dcc20953f --- /dev/null +++ b/crates/domain-validator/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "domain-validator" +edition = "2024" +version = "0.1.0" + +[dependencies] +async-trait = "0.1.88" + +thiserror = "2.0.12" + +[lints] +clippy.suspicious = { level = "warn", priority = -1 } +clippy.complexity = { level = "warn", priority = -1 } +clippy.perf = { level = "warn", priority = -1 } +clippy.style = { level = "warn", priority = -1 } +clippy.pedantic = { level = "warn", priority = -1 } +clippy.missing_errors_doc = "allow" +clippy.missing_panics_doc = "allow" +clippy.must_use_candidate = "allow" +clippy.struct_field_names = "allow" +clippy.arithmetic_side_effects = "warn" diff --git a/crates/domain-validator/src/lib.rs b/crates/domain-validator/src/lib.rs new file mode 100644 index 000000000..df08fbc4c --- /dev/null +++ b/crates/domain-validator/src/lib.rs @@ -0,0 +1,54 @@ +use std::collections::HashSet; + +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum DomainValidatorError { + #[error("Domain {0} is not allowed")] + DomainNotAllowed(String), +} + +#[async_trait::async_trait] +pub trait DomainValidator: Send + Sync { + async fn validate_domain(&self, domain: &str) -> Result<(), DomainValidatorError>; +} + +pub struct ListDomainValidator { + allowed_domains: HashSet, +} + +impl ListDomainValidator { + pub fn new(domains: HashSet) -> Self { + Self { + allowed_domains: domains, + } + } +} + +#[async_trait::async_trait] +impl DomainValidator for ListDomainValidator { + async fn validate_domain(&self, domain: &str) -> Result<(), DomainValidatorError> { + let domain_lower = domain.to_lowercase(); + if self.allowed_domains.contains(&domain_lower) { + Ok(()) + } else { + Err(DomainValidatorError::DomainNotAllowed(domain.to_string())) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_list_domain_validator() { + let domains = HashSet::from(["example.com".to_string(), "test.org".to_string()]); + let validator = ListDomainValidator::new(domains); + + assert!(validator.validate_domain("example.com").await.is_ok()); + assert!(validator.validate_domain("EXAMPLE.COM").await.is_ok()); + assert!(validator.validate_domain("test.org").await.is_ok()); + assert!(validator.validate_domain("invalid.com").await.is_err()); + } +} diff --git a/crates/fly-api/Cargo.toml b/crates/fly-api/Cargo.toml new file mode 100644 index 000000000..af1c2fab7 --- /dev/null +++ b/crates/fly-api/Cargo.toml @@ -0,0 +1,25 @@ +[package] +name = "fly-api" +edition = "2024" +version = "0.1.0" + +[dependencies] +reqwest = { version = "0.12.10", features = ["json"] } +serde = { version = "1.0.219", features = ["derive"] } +serde_json = "1.0.140" +thiserror = "2.0.12" +tokio = { version = "1.45.1", features = ["rt-multi-thread", "macros", "signal"] } +domain-validator = { path = "../domain-validator" } +async-trait = "0.1.88" + +[lints] +clippy.suspicious = { level = "warn", priority = -1 } +clippy.complexity = { level = "warn", priority = -1 } +clippy.perf = { level = "warn", priority = -1 } +clippy.style = { level = "warn", priority = -1 } +clippy.pedantic = { level = "warn", priority = -1 } +clippy.missing_errors_doc = "allow" +clippy.missing_panics_doc = "allow" +clippy.must_use_candidate = "allow" +clippy.struct_field_names = "allow" +clippy.arithmetic_side_effects = "warn" diff --git a/crates/fly-api/src/lib.rs b/crates/fly-api/src/lib.rs new file mode 100644 index 000000000..ac4b7c7e0 --- /dev/null +++ b/crates/fly-api/src/lib.rs @@ -0,0 +1,128 @@ +use std::collections::HashSet; + +use async_trait::async_trait; +use domain_validator::{DomainValidator, DomainValidatorError}; +use reqwest::Client; +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum FlyApiError { + #[error("Failed to fetch certificates: {0}")] + FetchError(String), +} + +pub struct FlyDomainValidator { + app_name: String, + api_token: String, + client: Client, +} + +impl FlyDomainValidator { + pub fn new(app_name: String, api_token: String) -> Self { + Self { + app_name, + api_token, + client: Client::new(), + } + } + + async fn get_certificate_domains(&self) -> Result, FlyApiError> { + let graphql_query = serde_json::json!({ + "query": r#" + query($appName: String!) { + app(name: $appName) { + certificates { + nodes { + hostname + clientStatus + } + } + } + } + "#, + "variables": { + "appName": self.app_name + } + }); + + let response = self + .client + .post("https://api.fly.io/graphql") + .header("Authorization", format!("Bearer {}", self.api_token)) + .header("Content-Type", "application/json") + .json(&graphql_query) + .send() + .await + .map_err(|e| FlyApiError::FetchError(e.to_string()))?; + + if !response.status().is_success() { + return Err(FlyApiError::FetchError(format!( + "HTTP {}: {}", + response.status(), + response.text().await.unwrap_or_default() + ))); + } + + let response_data: serde_json::Value = response + .json() + .await + .map_err(|e| FlyApiError::FetchError(e.to_string()))?; + + let mut allowed_domains = HashSet::new(); + + if let Some(certificates) = response_data + .get("data") + .and_then(|d| d.get("app")) + .and_then(|a| a.get("certificates")) + .and_then(|c| c.get("nodes")) + .and_then(|n| n.as_array()) + { + for cert in certificates { + if let Some(hostname) = cert.get("hostname").and_then(|h| h.as_str()) { + allowed_domains.insert(hostname.to_lowercase()); + } + } + } + + if allowed_domains.is_empty() { + return Err(FlyApiError::FetchError( + "No domains found in Fly.io certificates".to_string(), + )); + } + + Ok(allowed_domains) + } +} + +#[async_trait] +impl DomainValidator for FlyDomainValidator { + async fn validate_domain(&self, domain: &str) -> Result<(), DomainValidatorError> { + let allowed_domains = self + .get_certificate_domains() + .await + .map_err(|e| DomainValidatorError::DomainNotAllowed(e.to_string()))?; + + let domain_lower = domain.to_lowercase(); + + if allowed_domains.contains(&domain_lower) { + Ok(()) + } else { + Err(DomainValidatorError::DomainNotAllowed(format!( + "Domain {} not found in Fly.io certificates", + domain + ))) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_fly_domain_validator_creation() { + let fly_validator = FlyDomainValidator::new("test-app".to_string(), "test-token".to_string()); + assert_eq!(fly_validator.app_name, "test-app"); + assert_eq!(fly_validator.api_token, "test-token"); + } +}