diff --git a/crates/orchestrator/src/cli.rs b/crates/orchestrator/src/cli.rs
new file mode 100644
index 00000000..be93a657
--- /dev/null
+++ b/crates/orchestrator/src/cli.rs
@@ -0,0 +1,464 @@
+use std::sync::Arc;
+
+use alloy::providers::Provider;
+use anyhow::Result;
+use clap::Parser;
+use futures::FutureExt;
+use log::{debug, error, info};
+use shared::{
+    utils::google_cloud::GcsStorageProvider,
+    web3::{contracts::core::builder::ContractBuilder, wallet::Wallet},
+};
+use tokio::task::JoinSet;
+use tokio_util::sync::CancellationToken;
+use url::Url;
+
+use crate::{
+    start_server, DiscoveryMonitor, LoopHeartbeats, MetricsContext, MetricsSyncService,
+    MetricsWebhookSender, NodeGroupConfiguration, NodeGroupsPlugin, NodeInviter, NodeStatusUpdater,
+    P2PService, RedisStore, Scheduler, SchedulerPlugin, ServerMode, StatusUpdatePlugin,
+    StoreContext, WebhookConfig, WebhookPlugin,
+};
+
+#[derive(Parser)]
+pub struct Cli {
+    /// Server mode
+    #[arg(long, default_value = "full")]
+    // TODO: directly parse into `ServerMode`
+    pub mode: String,
+
+    /// RPC URL
+    #[arg(short = 'r', long, default_value = "http://localhost:8545")]
+    pub rpc_url: String,
+
+    /// Owner key
+    #[arg(short = 'k', long)]
+    pub coordinator_key: String,
+
+    /// Compute pool id
+    #[arg(long, default_value = "0")]
+    pub compute_pool_id: u32,
+
+    /// Domain id
+    #[arg(short = 'd', long, default_value = "0")]
+    pub domain_id: u32,
+
+    /// External ip - advertised to workers
+    #[arg(short = 'e', long)]
+    pub host: Option<String>,
+
+    /// Port
+    #[arg(short = 'p', long, default_value = "8090")]
+    pub port: u16,
+
+    /// External url - advertised to workers
+    #[arg(short = 'u', long)]
+    pub url: Option<String>,
+
+    /// Discovery refresh interval
+    #[arg(short = 'i', long, default_value = "10")]
+    pub discovery_refresh_interval: u64,
+
+    /// Redis store url
+    #[arg(short = 's', long, default_value = "redis://localhost:6380")]
+    pub redis_store_url: String,
+
+    /// Admin api key
+    #[arg(short = 'a', long, default_value = "admin")]
+    pub admin_api_key: String,
+
+    /// Disable instance ejection from chain
+    #[arg(long)]
+    pub disable_ejection: bool,
+
+    /// Hourly s3 upload limit
+    #[arg(long, default_value = "2")]
+    pub hourly_s3_upload_limit: i64,
+
+    /// S3 bucket name
+    #[arg(long)]
+    pub bucket_name: Option<String>,
+
+    /// Log level
+    #[arg(short = 'l', long, default_value = "info")]
+    pub log_level: String,
+
+    /// Node group management interval
+    #[arg(long, default_value = "10")]
+    pub node_group_management_interval: u64,
+
+    /// Max healthy nodes with same endpoint
+    #[arg(long, default_value = "1")]
+    pub max_healthy_nodes_with_same_endpoint: u32,
+
+    /// Libp2p port
+    #[arg(long, default_value = "4004")]
+    pub libp2p_port: u16,
+
+    /// Comma-separated list of libp2p bootnode multiaddresses
+    /// Example: `/ip4/104.131.131.82/tcp/4001/p2p/QmaCpDMGvV2BGHeYERUEnRQAwe3N8SzbUtfsmvsqQLuvuJ,/ip4/104.131.131.82/udp/4001/quic-v1/p2p/QmaCpDMGvV2BGHeYERUEnRQAwe3N8SzbUtfsmvsqQLuvuJ`
+    #[arg(long, default_value = "")]
+    pub bootnodes: String,
+
+    /// Location service URL (e.g., https://ipapi.co). If not provided, location services are disabled.
+    #[arg(long)]
+    pub location_service_url: Option<String>,
+
+    /// Location service API key
+    #[arg(long)]
+    pub location_service_api_key: Option<String>,
+}
+
+impl Cli {
+    pub async fn run(self, cancellation_token: CancellationToken) -> anyhow::Result<()> {
+        let server_mode = match self.mode.as_str() {
+            "api" => ServerMode::ApiOnly,
+            "processor" => ServerMode::ProcessorOnly,
+            "full" => ServerMode::Full,
+            _ => anyhow::bail!("invalid server mode: {}", self.mode),
+        };
+
+        debug!("Server mode: {server_mode:?}");
+
+        let metrics_context = Arc::new(MetricsContext::new(self.compute_pool_id.to_string()));
+
+        let heartbeats = Arc::new(LoopHeartbeats::new(&server_mode));
+
+        let compute_pool_id = self.compute_pool_id;
+        let domain_id = self.domain_id;
+        let coordinator_key = self.coordinator_key;
+        let rpc_url: Url = self.rpc_url.parse().unwrap();
+
+        let mut tasks: JoinSet<Result<()>> = JoinSet::new();
+
+        let wallet = Wallet::new(&coordinator_key, rpc_url).unwrap_or_else(|err| {
+            error!("Error creating wallet: {err:?}");
+            std::process::exit(1);
+        });
+
+        let store = Arc::new(RedisStore::new(&self.redis_store_url));
+        let store_context = Arc::new(StoreContext::new(store.clone()));
+
+        let keypair = p2p::Keypair::generate_ed25519();
+        let bootnodes: Vec<p2p::Multiaddr> = self
+            .bootnodes
+            .split(',')
+            .filter_map(|addr| match addr.to_string().try_into() {
+                Ok(multiaddr) => Some(multiaddr),
+                Err(e) => {
+                    error!("Invalid bootnode address '{addr}': {e}");
+                    None
+                }
+            })
+            .collect();
+        if bootnodes.is_empty() {
+            error!(
+                "No valid bootnodes provided. Please provide at least one valid bootnode address."
+            );
+            std::process::exit(1);
+        }
+
+        let (p2p_service, invite_tx, get_task_logs_tx, restart_task_tx, kademlia_action_tx) = {
+            match P2PService::new(
+                keypair,
+                self.libp2p_port,
+                bootnodes,
+                cancellation_token.clone(),
+                wallet.clone(),
+            ) {
+                Ok(res) => {
+                    info!("p2p service initialized successfully");
+                    res
+                }
+                Err(e) => {
+                    error!("failed to initialize p2p service: {e}");
+                    std::process::exit(1);
+                }
+            }
+        };
+
+        tokio::task::spawn(p2p_service.run());
+
+        let contracts = ContractBuilder::new(wallet.provider().root().clone())
+            .with_compute_registry()
+            .with_ai_token()
+            .with_prime_network()
+            .with_compute_pool()
+            .build()
+            .unwrap();
+
+        let contracts_with_wallet = ContractBuilder::new(wallet.provider())
+            .with_compute_registry()
+            .with_ai_token()
+            .with_prime_network()
+            .with_compute_pool()
+            .build()
+            .unwrap();
+
+        let group_store_context = store_context.clone();
+        let mut scheduler_plugins: Vec<SchedulerPlugin> = Vec::new();
+        let mut status_update_plugins: Vec<StatusUpdatePlugin> = vec![];
+        let mut node_groups_plugin: Option<Arc<NodeGroupsPlugin>> = None;
+        let mut webhook_plugins: Vec<WebhookPlugin> = vec![];
+
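+        // `WEBHOOK_CONFIGS` holds a JSON array of `WebhookConfig` values. Illustrative
+        // shape only; the authoritative fields are on `WebhookConfig` in this crate:
+        //   WEBHOOK_CONFIGS='[{"url": "https://example.com/hook"}]'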
+        let configs = std::env::var("WEBHOOK_CONFIGS").unwrap_or_default();
+        if !configs.is_empty() {
+            match serde_json::from_str::<Vec<WebhookConfig>>(&configs) {
+                Ok(configs) => {
+                    for config in configs {
+                        let plugin = WebhookPlugin::new(config);
+                        let plugin_clone = plugin.clone();
+                        webhook_plugins.push(plugin_clone);
+                        status_update_plugins.push(plugin.into());
+                        info!("Plugin: Webhook plugin initialized");
+                    }
+                }
+                Err(e) => {
+                    error!("Failed to parse webhook configs from environment: {e}");
+                }
+            }
+        } else {
+            info!("No webhook configurations provided");
+        }
+
+        let webhook_sender_store = store_context.clone();
+        let webhook_plugins_clone = webhook_plugins.clone();
+        if !webhook_plugins_clone.is_empty() && server_mode != ServerMode::ApiOnly {
+            tasks.spawn(async move {
+                let mut webhook_sender = MetricsWebhookSender::new(
+                    webhook_sender_store.clone(),
+                    webhook_plugins_clone.clone(),
+                    compute_pool_id,
+                );
+                if let Err(e) = webhook_sender.run().await {
+                    error!("Error running webhook sender: {e}");
+                }
+                Ok(())
+            });
+        }
+
+        // Load node group configurations from environment variable
+        let node_group_configs = std::env::var("NODE_GROUP_CONFIGS").unwrap_or_default();
+        if !node_group_configs.is_empty() {
+            match serde_json::from_str::<Vec<NodeGroupConfiguration>>(&node_group_configs) {
+                Ok(configs) if !configs.is_empty() => {
+                    let node_groups_heartbeats = heartbeats.clone();
+
+                    let group_plugin = Arc::new(NodeGroupsPlugin::new(
+                        configs,
+                        store.clone(),
+                        group_store_context.clone(),
+                        Some(node_groups_heartbeats.clone()),
+                        Some(webhook_plugins.clone()),
+                    ));
+
+                    // Register the plugin as a task observer
+                    group_store_context
+                        .task_store
+                        .add_observer(group_plugin.clone())
+                        .await;
+
+                    let status_group_plugin = group_plugin.clone();
+                    let group_plugin_for_server = group_plugin.clone();
+
+                    node_groups_plugin = Some(group_plugin_for_server);
+                    scheduler_plugins.push(group_plugin.into());
+                    status_update_plugins.push(status_group_plugin.into());
+                    info!("Plugin: Node group plugin initialized");
+                }
+                Ok(_) => {
+                    info!(
+                        "No node group configurations provided in environment, skipping plugin setup"
+                    );
+                }
+                Err(e) => {
+                    error!("Failed to parse node group configurations from environment: {e}");
+                    std::process::exit(1);
+                }
+            }
+        }
+
+        let scheduler = Scheduler::new(store_context.clone(), scheduler_plugins);
+
+        // Only spawn processor tasks if in ProcessorOnly or Full mode
+        if matches!(server_mode, ServerMode::ProcessorOnly | ServerMode::Full) {
+            // Start metrics sync service to centralize metrics from Redis to Prometheus
+            let metrics_sync_store_context = store_context.clone();
+            let metrics_sync_context = metrics_context.clone();
+            let metrics_sync_node_groups = node_groups_plugin.clone();
+            tasks.spawn(async move {
+                let sync_service = MetricsSyncService::new(
+                    metrics_sync_store_context,
+                    metrics_sync_context,
+                    server_mode,
+                    10,
+                    metrics_sync_node_groups,
+                );
+                sync_service.run().await
+            });
+
+            if let Some(group_plugin) = node_groups_plugin.clone() {
+                tasks.spawn(async move {
+                    group_plugin
+                        .run_group_management_loop(self.node_group_management_interval)
+                        .await
+                });
+            }
+
+            // Create status_update_plugins for discovery monitor
+            let mut discovery_status_update_plugins: Vec<StatusUpdatePlugin> = vec![];
+
+            // Add webhook plugins to discovery status update plugins
+            for plugin in &webhook_plugins {
+                discovery_status_update_plugins.push(plugin.into());
+            }
+
+            // Add node groups plugin if available
+            if let Some(group_plugin) = node_groups_plugin.clone() {
+                discovery_status_update_plugins.push(group_plugin.into());
+            }
+
+            let discovery_store_context = store_context.clone();
+            let discovery_heartbeats = heartbeats.clone();
+            let monitor = match DiscoveryMonitor::new(
+                compute_pool_id,
+                self.discovery_refresh_interval,
+                discovery_store_context.clone(),
+                discovery_heartbeats.clone(),
+                discovery_status_update_plugins,
+                kademlia_action_tx,
+                wallet.provider().root().clone(),
+                contracts.clone(),
+                self.location_service_url,
+                self.location_service_api_key,
+            ) {
+                Ok(monitor) => {
+                    info!("Discovery monitor initialized successfully");
+                    monitor
+                }
+                Err(e) => {
+                    error!("Failed to initialize discovery monitor: {e}");
+                    std::process::exit(1);
+                }
+            };
+
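+            // `DiscoveryMonitor::run` returns `()`, so `.map(|_| Ok(()))` adapts it to
+            // the `JoinSet<Result<()>>` item type shared by the other background tasks.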
+            tasks.spawn(monitor.run(cancellation_token.clone()).map(|_| Ok(())));
+
+            let inviter_store_context = store_context.clone();
+            let inviter_heartbeats = heartbeats.clone();
+            let wallet = wallet.clone();
+            let inviter = match NodeInviter::new(
+                wallet,
+                compute_pool_id,
+                domain_id,
+                self.host.as_deref(),
+                Some(&self.port),
+                self.url.as_deref(),
+                inviter_store_context.clone(),
+                inviter_heartbeats.clone(),
+                invite_tx,
+            ) {
+                Ok(inviter) => {
+                    info!("Node inviter initialized successfully");
+                    inviter
+                }
+                Err(e) => {
+                    error!("Failed to initialize node inviter: {e}");
+                    std::process::exit(1);
+                }
+            };
+
+            tasks.spawn(async move { inviter.run().await });
+
+            // Create status_update_plugins for status updater
+            let mut status_updater_plugins: Vec<StatusUpdatePlugin> = vec![];
+
+            // Add webhook plugins to status updater plugins
+            for plugin in &webhook_plugins {
+                status_updater_plugins.push(plugin.into());
+            }
+
+            // Add node groups plugin if available
+            if let Some(group_plugin) = node_groups_plugin.clone() {
+                status_updater_plugins.push(group_plugin.into());
+            }
+
+            let status_update_store_context = store_context.clone();
+            let status_update_heartbeats = heartbeats.clone();
+            let status_update_metrics = metrics_context.clone();
+            tasks.spawn({
+                let contracts = contracts_with_wallet.clone();
+                async move {
+                    let status_updater = NodeStatusUpdater::new(
+                        status_update_store_context.clone(),
+                        15,
+                        None,
+                        contracts,
+                        compute_pool_id,
+                        self.disable_ejection,
+                        status_update_heartbeats.clone(),
+                        status_updater_plugins,
+                        status_update_metrics,
+                    );
+                    status_updater.run().await
+                }
+            });
+        }
+
+        let port = self.port;
+        let server_store_context = store_context.clone();
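+        // Uploads are enabled only when both `--bucket-name` and the `S3_CREDENTIALS`
+        // env var are present and non-empty; otherwise the server runs without storage.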
+        let s3_credentials = std::env::var("S3_CREDENTIALS").ok();
+        let storage_provider: Option> =
+            match (self.bucket_name.as_ref(), s3_credentials) {
+                (Some(bucket_name), Some(s3_credentials))
+                    if !bucket_name.is_empty() && !s3_credentials.is_empty() =>
+                {
+                    let gcs_storage = GcsStorageProvider::new(bucket_name, &s3_credentials)
+                        .await
+                        .unwrap_or_else(|_| panic!("Failed to create GCS storage provider"));
+                    Some(Arc::new(gcs_storage) as Arc)
+                }
+                _ => {
+                    info!("Bucket name or S3 credentials not provided, storage provider disabled");
+                    None
+                }
+            };
+
+        // Always start server regardless of mode
+        tokio::select! {
+            res = start_server(
+                "0.0.0.0",
+                port,
+                server_store_context.clone(),
+                self.admin_api_key,
+                storage_provider,
+                heartbeats.clone(),
+                store.clone(),
+                self.hourly_s3_upload_limit,
+                Some(contracts_with_wallet),
+                compute_pool_id,
+                server_mode,
+                scheduler,
+                node_groups_plugin,
+                metrics_context,
+                get_task_logs_tx,
+                restart_task_tx,
+            ) => {
+                if let Err(e) = res {
+                    error!("Server error: {e}");
+                }
+            }
+            Some(res) = tasks.join_next() => {
+                if let Err(e) = res? {
+                    error!("Task error: {e}");
+                }
+            }
+            _ = cancellation_token.cancelled() => {
+                error!("Shutdown signal received");
+            }
+        }
+
+        tasks.shutdown().await;
+        Ok(())
+    }
+}
diff --git a/crates/orchestrator/src/discovery/monitor.rs b/crates/orchestrator/src/discovery/monitor.rs
index 54156371..7f6c850e 100644
--- a/crates/orchestrator/src/discovery/monitor.rs
+++ b/crates/orchestrator/src/discovery/monitor.rs
@@ -18,6 +18,7 @@ use std::sync::Arc;
 use std::time::Duration;
 use tokio::sync::mpsc::Sender;
 use tokio::time::interval;
+use tokio_util::sync::CancellationToken;
 
 #[derive(Clone)]
 struct NodeFetcher {
@@ -362,7 +363,7 @@ impl DiscoveryMonitor {
         })
     }
 
-    pub async fn run(self) {
+    pub async fn run(self, cancellation_token: CancellationToken) {
         use futures::StreamExt as _;
 
         let Self {
@@ -444,6 +445,10 @@ impl DiscoveryMonitor {
                     }
                 }
             }
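+            // Graceful-shutdown arm: the CLI passes its shared `CancellationToken` into
+            // `run`, so the monitor exits this loop as soon as shutdown is requested.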
+            _ = cancellation_token.cancelled() => {
+                error!("Shutdown signal received, stopping discovery monitor");
+                break;
+            }
         }
     }
 }
diff --git a/crates/orchestrator/src/lib.rs b/crates/orchestrator/src/lib.rs
index 19d13eba..81be7f0b 100644
--- a/crates/orchestrator/src/lib.rs
+++ b/crates/orchestrator/src/lib.rs
@@ -1,4 +1,5 @@
 mod api;
+mod cli;
 mod discovery;
 mod metrics;
 mod models;
@@ -11,6 +12,7 @@ mod store;
 mod utils;
 
 pub use api::server::start_server;
+pub use cli::Cli;
 pub use discovery::monitor::DiscoveryMonitor;
 pub use metrics::sync_service::MetricsSyncService;
 pub use metrics::webhook_sender::MetricsWebhookSender;
diff --git a/crates/orchestrator/src/main.rs b/crates/orchestrator/src/main.rs
index 4ea7a1bd..62239adf 100644
--- a/crates/orchestrator/src/main.rs
+++ b/crates/orchestrator/src/main.rs
@@ -1,124 +1,22 @@
-use alloy::providers::Provider;
 use anyhow::Result;
 use clap::Parser;
-use futures::FutureExt;
 use log::debug;
-use log::error;
-use log::info;
 use log::LevelFilter;
-use shared::utils::google_cloud::GcsStorageProvider;
-use shared::web3::contracts::core::builder::ContractBuilder;
-use shared::web3::wallet::Wallet;
-use std::sync::Arc;
-use tokio::task::JoinSet;
-use tokio_util::sync::CancellationToken;
-use url::Url;
-
-use orchestrator::{
-    start_server, LoopHeartbeats, MetricsContext, MetricsSyncService, MetricsWebhookSender,
-    NodeGroupConfiguration, NodeGroupsPlugin, NodeInviter, NodeStatusUpdater, P2PService,
-    RedisStore, Scheduler, SchedulerPlugin, ServerMode, StatusUpdatePlugin, StoreContext,
-    WebhookConfig, WebhookPlugin,
-};
-
-#[derive(Parser)]
-struct Args {
-    // Server mode
-    #[arg(long, default_value = "full")]
-    mode: String,
-
-    /// RPC URL
-    #[arg(short = 'r', long, default_value = "http://localhost:8545")]
-    rpc_url: String,
-
-    /// Owner key
-    #[arg(short = 'k', long)]
-    coordinator_key: String,
-
-    /// Compute pool id
-    #[arg(long, default_value = "0")]
-    compute_pool_id: u32,
-
-    /// Domain id
-    #[arg(short = 'd', long, default_value = "0")]
-    domain_id: u32,
-
-    /// External ip - advertised to workers
-    #[arg(short = 'e', long)]
-    host: Option<String>,
-
-    /// Port
-    #[arg(short = 'p', long, default_value = "8090")]
-    port: u16,
-
-    /// External url - advertised to workers
-    #[arg(short = 'u', long)]
-    url: Option<String>,
-
-    /// Discovery refresh interval
-    #[arg(short = 'i', long, default_value = "10")]
-    discovery_refresh_interval: u64,
-
-    /// Redis store url
-    #[arg(short = 's', long, default_value = "redis://localhost:6380")]
-    redis_store_url: String,
-
-    /// Admin api key
-    #[arg(short = 'a', long, default_value = "admin")]
-    admin_api_key: String,
-
-    /// Disable instance ejection from chain
-    #[arg(long)]
-    disable_ejection: bool,
-
-    /// Hourly s3 upload limit
-    #[arg(long, default_value = "2")]
-    hourly_s3_upload_limit: i64,
-
-    /// S3 bucket name
-    #[arg(long)]
-    bucket_name: Option<String>,
-
-    /// Log level
-    #[arg(short = 'l', long, default_value = "info")]
-    log_level: String,
-
-    /// Node group management interval
-    #[arg(long, default_value = "10")]
-    node_group_management_interval: u64,
-
-    /// Max healthy nodes with same endpoint
-    #[arg(long, default_value = "1")]
-    max_healthy_nodes_with_same_endpoint: u32,
-
-    /// Libp2p port
-    #[arg(long, default_value = "4004")]
-    libp2p_port: u16,
-
-    /// Comma-separated list of libp2p bootnode multiaddresses
-    /// Example: `/ip4/104.131.131.82/tcp/4001/p2p/QmaCpDMGvV2BGHeYERUEnRQAwe3N8SzbUtfsmvsqQLuvuJ,/ip4/104.131.131.82/udp/4001/quic-v1/p2p/QmaCpDMGvV2BGHeYERUEnRQAwe3N8SzbUtfsmvsqQLuvuJ`
-    #[arg(long, default_value = "")]
-    bootnodes: String,
-
-    /// Location service URL (e.g., https://ipapi.co). If not provided, location services are disabled.
-    #[arg(long)]
-    location_service_url: Option<String>,
-
-    /// Location service API key
-    #[arg(long)]
-    location_service_api_key: Option<String>,
-}
+use orchestrator::Cli;
+use shared::utils::signal::trigger_cancellation_on_signal;
+use tokio_util::sync::CancellationToken;
 
 #[tokio::main]
 async fn main() -> Result<()> {
-    let args = Args::parse();
-    let log_level = match args.log_level.as_str() {
+    let cli = Cli::parse();
+    let log_level = match cli.log_level.as_str() {
         "error" => LevelFilter::Error,
         "warn" => LevelFilter::Warn,
         "info" => LevelFilter::Info,
         "debug" => LevelFilter::Debug,
         "trace" => LevelFilter::Trace,
-        _ => anyhow::bail!("invalid log level: {}", args.log_level),
+        _ => anyhow::bail!("invalid log level: {}", cli.log_level),
     };
     env_logger::Builder::new()
         .filter_level(log_level)
@@ -126,358 +24,22 @@ async fn main() -> Result<()> {
         .filter_module("tracing::span", log::LevelFilter::Warn)
         .init();
 
-    let server_mode = match args.mode.as_str() {
-        "api" => ServerMode::ApiOnly,
-        "processor" => ServerMode::ProcessorOnly,
-        "full" => ServerMode::Full,
-        _ => anyhow::bail!("invalid server mode: {}", args.mode),
-    };
-
     debug!("Log level: {log_level}");
-    debug!("Server mode: {server_mode:?}");
-
-    let metrics_context = Arc::new(MetricsContext::new(args.compute_pool_id.to_string()));
-
-    let heartbeats = Arc::new(LoopHeartbeats::new(&server_mode));
-
-    let compute_pool_id = args.compute_pool_id;
-    let domain_id = args.domain_id;
-    let coordinator_key = args.coordinator_key;
-    let rpc_url: Url = args.rpc_url.parse().unwrap();
-
-    let mut tasks: JoinSet<Result<()>> = JoinSet::new();
-
-    let wallet = Wallet::new(&coordinator_key, rpc_url).unwrap_or_else(|err| {
-        error!("Error creating wallet: {err:?}");
-        std::process::exit(1);
-    });
-
-    let store = Arc::new(RedisStore::new(&args.redis_store_url));
-    let store_context = Arc::new(StoreContext::new(store.clone()));
-
-    let keypair = p2p::Keypair::generate_ed25519();
-    let bootnodes: Vec<p2p::Multiaddr> = args
-        .bootnodes
-        .split(',')
-        .filter_map(|addr| match addr.to_string().try_into() {
-            Ok(multiaddr) => Some(multiaddr),
-            Err(e) => {
-                error!("Invalid bootnode address '{addr}': {e}");
-                None
-            }
-        })
-        .collect();
-    if bootnodes.is_empty() {
-        error!("No valid bootnodes provided. Please provide at least one valid bootnode address.");
-        std::process::exit(1);
-    }
 
     let cancellation_token = CancellationToken::new();
 
-    let (p2p_service, invite_tx, get_task_logs_tx, restart_task_tx, kademlia_action_tx) = {
-        match P2PService::new(
-            keypair,
-            args.libp2p_port,
-            bootnodes,
-            cancellation_token.clone(),
-            wallet.clone(),
-        ) {
-            Ok(res) => {
-                info!("p2p service initialized successfully");
-                res
-            }
-            Err(e) => {
-                error!("failed to initialize p2p service: {e}");
-                std::process::exit(1);
-            }
-        }
-    };
-
-    tokio::task::spawn(p2p_service.run());
-
-    let contracts = ContractBuilder::new(wallet.provider().root().clone())
-        .with_compute_registry()
-        .with_ai_token()
-        .with_prime_network()
-        .with_compute_pool()
-        .build()
-        .unwrap();
+    let _signal_handle = trigger_cancellation_on_signal(cancellation_token.clone())?;
 
-    let contracts_with_wallet = ContractBuilder::new(wallet.provider())
-        .with_compute_registry()
-        .with_ai_token()
-        .with_prime_network()
-        .with_compute_pool()
-        .build()
-        .unwrap();
-
-    let group_store_context = store_context.clone();
-    let mut scheduler_plugins: Vec<SchedulerPlugin> = Vec::new();
-    let mut status_update_plugins: Vec<StatusUpdatePlugin> = vec![];
-    let mut node_groups_plugin: Option<Arc<NodeGroupsPlugin>> = None;
-    let mut webhook_plugins: Vec<WebhookPlugin> = vec![];
-
-    let configs = std::env::var("WEBHOOK_CONFIGS").unwrap_or_default();
-    if !configs.is_empty() {
-        match serde_json::from_str::<Vec<WebhookConfig>>(&configs) {
-            Ok(configs) => {
-                for config in configs {
-                    let plugin = WebhookPlugin::new(config);
-                    let plugin_clone = plugin.clone();
-                    webhook_plugins.push(plugin_clone);
-                    status_update_plugins.push(plugin.into());
-                    info!("Plugin: Webhook plugin initialized");
-                }
-            }
-            Err(e) => {
-                error!("Failed to parse webhook configs from environment: {e}");
-            }
-        }
-    } else {
-        info!("No webhook configurations provided");
-    }
-
-    let webhook_sender_store = store_context.clone();
-    let webhook_plugins_clone = webhook_plugins.clone();
-    if !webhook_plugins_clone.is_empty() && server_mode != ServerMode::ApiOnly {
-        tasks.spawn(async move {
-            let mut webhook_sender = MetricsWebhookSender::new(
-                webhook_sender_store.clone(),
-                webhook_plugins_clone.clone(),
-                compute_pool_id,
-            );
-            if let Err(e) = webhook_sender.run().await {
-                error!("Error running webhook sender: {e}");
-            }
-            Ok(())
-        });
-    }
-
-    // Load node group configurations from environment variable
-    let node_group_configs = std::env::var("NODE_GROUP_CONFIGS").unwrap_or_default();
-    if !node_group_configs.is_empty() {
-        match serde_json::from_str::<Vec<NodeGroupConfiguration>>(&node_group_configs) {
-            Ok(configs) if !configs.is_empty() => {
-                let node_groups_heartbeats = heartbeats.clone();
-
-                let group_plugin = Arc::new(NodeGroupsPlugin::new(
-                    configs,
-                    store.clone(),
-                    group_store_context.clone(),
-                    Some(node_groups_heartbeats.clone()),
-                    Some(webhook_plugins.clone()),
-                ));
-
-                // Register the plugin as a task observer
-                group_store_context
-                    .task_store
-                    .add_observer(group_plugin.clone())
-                    .await;
-
-                let status_group_plugin = group_plugin.clone();
-                let group_plugin_for_server = group_plugin.clone();
-
-                node_groups_plugin = Some(group_plugin_for_server);
-                scheduler_plugins.push(group_plugin.into());
-                status_update_plugins.push(status_group_plugin.into());
-                info!("Plugin: Node group plugin initialized");
-            }
-            Ok(_) => {
-                info!(
-                    "No node group configurations provided in environment, skipping plugin setup"
-                );
-            }
-            Err(e) => {
-                error!("Failed to parse node group configurations from environment: {e}");
-                std::process::exit(1);
-            }
-        }
-    }
-
-    let scheduler = Scheduler::new(store_context.clone(), scheduler_plugins);
-
-    // Only spawn processor tasks if in ProcessorOnly or Full mode
-    if matches!(server_mode, ServerMode::ProcessorOnly | ServerMode::Full) {
-        // Start metrics sync service to centralize metrics from Redis to Prometheus
-        let metrics_sync_store_context = store_context.clone();
-        let metrics_sync_context = metrics_context.clone();
-        let metrics_sync_node_groups = node_groups_plugin.clone();
-        tasks.spawn(async move {
-            let sync_service = MetricsSyncService::new(
-                metrics_sync_store_context,
-                metrics_sync_context,
-                server_mode,
-                10,
-                metrics_sync_node_groups,
-            );
-            sync_service.run().await
-        });
-
-        if let Some(group_plugin) = node_groups_plugin.clone() {
-            tasks.spawn(async move {
-                group_plugin
-                    .run_group_management_loop(args.node_group_management_interval)
-                    .await
-            });
-        }
-
-        // Create status_update_plugins for discovery monitor
-        let mut discovery_status_update_plugins: Vec<StatusUpdatePlugin> = vec![];
-
-        // Add webhook plugins to discovery status update plugins
-        for plugin in &webhook_plugins {
-            discovery_status_update_plugins.push(plugin.into());
-        }
-
-        // Add node groups plugin if available
-        if let Some(group_plugin) = node_groups_plugin.clone() {
-            discovery_status_update_plugins.push(group_plugin.into());
-        }
-
-        let discovery_store_context = store_context.clone();
-        let discovery_heartbeats = heartbeats.clone();
-        let monitor = match orchestrator::DiscoveryMonitor::new(
-            compute_pool_id,
-            args.discovery_refresh_interval,
-            discovery_store_context.clone(),
-            discovery_heartbeats.clone(),
-            discovery_status_update_plugins,
-            kademlia_action_tx,
-            wallet.provider().root().clone(),
-            contracts.clone(),
-            args.location_service_url,
-            args.location_service_api_key,
-        ) {
-            Ok(monitor) => {
-                info!("Discovery monitor initialized successfully");
-                monitor
-            }
-            Err(e) => {
-                error!("Failed to initialize discovery monitor: {e}");
-                std::process::exit(1);
-            }
-        };
-
-        tasks.spawn(
-            // TODO: refactor task handling (https://github.com/PrimeIntellect-ai/protocol/issues/627)
-            monitor.run().map(|_| Ok(())),
-        );
-
-        let inviter_store_context = store_context.clone();
-        let inviter_heartbeats = heartbeats.clone();
-        let wallet = wallet.clone();
-        let inviter = match NodeInviter::new(
-            wallet,
-            compute_pool_id,
-            domain_id,
-            args.host.as_deref(),
-            Some(&args.port),
-            args.url.as_deref(),
-            inviter_store_context.clone(),
-            inviter_heartbeats.clone(),
-            invite_tx,
-        ) {
-            Ok(inviter) => {
-                info!("Node inviter initialized successfully");
-                inviter
-            }
-            Err(e) => {
-                error!("Failed to initialize node inviter: {e}");
-                std::process::exit(1);
-            }
-        };
-
-        tasks.spawn(async move { inviter.run().await });
-
-        // Create status_update_plugins for status updater
-        let mut status_updater_plugins: Vec<StatusUpdatePlugin> = vec![];
-
-        // Add webhook plugins to status updater plugins
-        for plugin in &webhook_plugins {
-            status_updater_plugins.push(plugin.into());
-        }
-
-        // Add node groups plugin if available
-        if let Some(group_plugin) = node_groups_plugin.clone() {
-            status_updater_plugins.push(group_plugin.into());
-        }
-
-        let status_update_store_context = store_context.clone();
-        let status_update_heartbeats = heartbeats.clone();
-        let status_update_metrics = metrics_context.clone();
-        tasks.spawn({
-            let contracts = contracts_with_wallet.clone();
-            async move {
-                let status_updater = NodeStatusUpdater::new(
-                    status_update_store_context.clone(),
-                    15,
-                    None,
-                    contracts,
-                    compute_pool_id,
-                    args.disable_ejection,
-                    status_update_heartbeats.clone(),
-                    status_updater_plugins,
-                    status_update_metrics,
-                );
-                status_updater.run().await
-            }
-        });
-    }
-
-    let port = args.port;
-    let server_store_context = store_context.clone();
-    let s3_credentials = std::env::var("S3_CREDENTIALS").ok();
-    let storage_provider: Option> =
-        match (args.bucket_name.as_ref(), s3_credentials) {
-            (Some(bucket_name), Some(s3_credentials))
-                if !bucket_name.is_empty() && !s3_credentials.is_empty() =>
-            {
-                let gcs_storage = GcsStorageProvider::new(bucket_name, &s3_credentials)
-                    .await
-                    .unwrap_or_else(|_| panic!("Failed to create GCS storage provider"));
-                Some(Arc::new(gcs_storage) as Arc)
-            }
-            _ => {
-                info!("Bucket name or S3 credentials not provided, storage provider disabled");
-                None
-            }
-        };
-
-    // Always start server regardless of mode
     tokio::select! {
-        res = start_server(
-            "0.0.0.0",
-            port,
-            server_store_context.clone(),
-            args.admin_api_key,
-            storage_provider,
-            heartbeats.clone(),
-            store.clone(),
-            args.hourly_s3_upload_limit,
-            Some(contracts_with_wallet),
-            compute_pool_id,
-            server_mode,
-            scheduler,
-            node_groups_plugin,
-            metrics_context,
-            get_task_logs_tx,
-            restart_task_tx,
-        ) => {
-            if let Err(e) = res {
-                error!("Server error: {e}");
-            }
-        }
-        Some(res) = tasks.join_next() => {
-            if let Err(e) = res? {
-                error!("Task error: {e}");
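+        // `Cli::run` now drives the orchestrator; a CLI error cancels the shared token
+        // so the signal task and everything spawned inside `run` winds down with it.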
+        cmd_result = cli.run(cancellation_token.clone()) => {
+            if let Err(e) = cmd_result {
+                log::error!("Command execution error: {e}");
+                cancellation_token.cancel();
             }
         }
-        _ = tokio::signal::ctrl_c() => {
-            error!("Shutdown signal received");
+        _ = cancellation_token.cancelled() => {
+            log::info!("Received cancellation request");
         }
     }
 
-    // TODO: use cancellation token to gracefully shutdown tasks (https://github.com/PrimeIntellect-ai/protocol/issues/627)
-    cancellation_token.cancel();
-    tasks.shutdown().await;
-
     Ok(())
 }
diff --git a/crates/shared/src/utils/mod.rs b/crates/shared/src/utils/mod.rs
index 290f1ae5..be83994c 100644
--- a/crates/shared/src/utils/mod.rs
+++ b/crates/shared/src/utils/mod.rs
@@ -4,6 +4,7 @@ use async_trait::async_trait;
 use std::sync::Arc;
 use tokio::sync::Mutex;
 pub mod google_cloud;
+pub mod signal;
 use anyhow::Result;
 
 #[async_trait]
diff --git a/crates/shared/src/utils/signal.rs b/crates/shared/src/utils/signal.rs
new file mode 100644
index 00000000..0a10c76d
--- /dev/null
+++ b/crates/shared/src/utils/signal.rs
@@ -0,0 +1,39 @@
+use tokio::{
+    io,
+    signal::unix::{signal, SignalKind},
+    task::JoinHandle,
+};
+use tokio_util::sync::CancellationToken;
+
+// Spawn a task to listen for signals and call `cancel` on the cancellation token
+// when a signal is received.
+//
+// Returns a handle to the spawned task.
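+//
+// Example (illustrative sketch; any async caller that owns the token works):
+//
+//   let token = CancellationToken::new();
+//   let _handle = trigger_cancellation_on_signal(token.clone())?;
+//   token.cancelled().await; // resolves after SIGTERM/SIGINT/SIGHUP/SIGQUIT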
+pub fn trigger_cancellation_on_signal(
+    cancellation_token: CancellationToken,
+) -> io::Result<JoinHandle<()>> {
+    let mut sigterm = signal(SignalKind::terminate())?;
+    let mut sigint = signal(SignalKind::interrupt())?;
+    let mut sighup = signal(SignalKind::hangup())?;
+    let mut sigquit = signal(SignalKind::quit())?;
+
+    let signal_handle = tokio::spawn(async move {
+        tokio::select! {
+            _ = sigterm.recv() => {
+                log::info!("Received termination signal");
+            }
+            _ = sigint.recv() => {
+                log::info!("Received interrupt signal");
+            }
+            _ = sighup.recv() => {
+                log::info!("Received hangup signal");
+            }
+            _ = sigquit.recv() => {
+                log::info!("Received quit signal");
+            }
+        }
+        cancellation_token.cancel();
+    });
+
+    Ok(signal_handle)
+}
diff --git a/crates/validator/src/cli.rs b/crates/validator/src/cli.rs
new file mode 100644
index 00000000..492e1e66
--- /dev/null
+++ b/crates/validator/src/cli.rs
@@ -0,0 +1,331 @@
+use std::str::FromStr;
+use std::sync::Arc;
+
+use actix_web::{web, App, HttpResponse, HttpServer};
+use alloy::primitives::{utils::Unit, U256};
+use clap::Parser;
+use log::{error, info};
+use shared::{
+    security::api_key_middleware::ApiKeyMiddleware,
+    utils::google_cloud::GcsStorageProvider,
+    web3::{contracts::core::builder::ContractBuilder, wallet::Wallet},
+};
+use tokio_util::sync::CancellationToken;
+use url::Url;
+
+use crate::{
+    export_metrics,
+    handler::{get_rejections, health_check, State},
+    HardwareValidator, InvalidationType, MetricsContext, P2PService, RedisStore,
+    SyntheticDataValidator, Validator,
+};
+
+#[derive(Parser)]
+pub struct Cli {
+    /// RPC URL
+    #[arg(short = 'r', long, default_value = "http://localhost:8545")]
+    pub rpc_url: String,
+
+    /// Owner key
+    #[arg(short = 'k', long)]
+    pub validator_key: String,
+
+    /// Ability to disable hardware validation
+    #[arg(long, default_value = "false")]
+    pub disable_hardware_validation: bool,
+
+    /// Optional: Pool Id for work validation
+    /// If not provided, the validator will not validate work
+    #[arg(long, default_value = None)]
+    pub pool_id: Option<String>,
+
+    /// Optional: Toploc Grace Interval in seconds between work validation requests
+    #[arg(long, default_value = "15")]
+    pub toploc_grace_interval: u64,
+
+    /// Optional: interval in minutes of max age of work on chain
+    #[arg(long, default_value = "15")]
+    pub toploc_work_validation_interval: u64,
+
+    /// Optional: expiry in seconds for work whose toploc validation status is still unknown
+    #[arg(long, default_value = "120")]
+    pub toploc_work_validation_unknown_status_expiry_seconds: u64,
+
+    /// Disable toploc ejection
+    /// If true, the validator will not invalidate work on toploc
+    #[arg(long, default_value = "false")]
+    pub disable_toploc_invalidation: bool,
+
+    /// Optional: batch trigger size
+    #[arg(long, default_value = "10")]
+    pub batch_trigger_size: usize,
+
+    /// Grouping
+    #[arg(long, default_value = "false")]
+    pub use_grouping: bool,
+
+    /// Grace period in minutes for incomplete groups to recover (0 = disabled)
+    #[arg(long, default_value = "0")]
+    pub incomplete_group_grace_period_minutes: u64,
+
+    /// Optional: toploc invalidation type
+    #[arg(long, default_value = "hard")]
+    pub toploc_invalidation_type: InvalidationType,
+
+    /// Optional: work unit invalidation type
+    #[arg(long, default_value = "hard")]
+    pub work_unit_invalidation_type: InvalidationType,
+
+    /// Optional: Validator penalty in whole tokens
+    /// Note: This value will be multiplied by 10^18 (1 token = 10^18 wei)
+    #[arg(long, default_value = "200")]
+    pub validator_penalty: u64,
+
+    /// Temporary: S3 bucket name
+    #[arg(long, default_value = None)]
+    pub bucket_name: Option<String>,
+
+    /// Log level
+    #[arg(short = 'l', long, default_value = "info")]
+    pub log_level: String,
+
+    /// Redis URL
+    #[arg(long, default_value = "redis://localhost:6380")]
+    pub redis_url: String,
+
+    /// Libp2p port
+    #[arg(long, default_value = "4003")]
+    pub libp2p_port: u16,
+
+    /// Comma-separated list of libp2p bootnode multiaddresses
+    /// Example: `/ip4/104.131.131.82/tcp/4001/p2p/QmaCpDMGvV2BGHeYERUEnRQAwe3N8SzbUtfsmvsqQLuvuJ,/ip4/104.131.131.82/udp/4001/quic-v1/p2p/QmaCpDMGvV2BGHeYERUEnRQAwe3N8SzbUtfsmvsqQLuvuJ`
+    #[arg(long, default_value = "")]
+    pub bootnodes: String,
+}
+
+impl Cli {
+    pub async fn run(self, cancellation_token: CancellationToken) -> anyhow::Result<()> {
+        let private_key_validator = self.validator_key;
+        let rpc_url: Url = self.rpc_url.parse().unwrap();
+
+        let redis_store = RedisStore::new(&self.redis_url);
+
+        let validator_wallet = Wallet::new(&private_key_validator, rpc_url).unwrap_or_else(|err| {
+            error!("Error creating wallet: {err:?}");
+            std::process::exit(1);
+        });
+
+        let mut contract_builder = ContractBuilder::new(validator_wallet.provider())
+            .with_compute_registry()
+            .with_ai_token()
+            .with_prime_network()
+            .with_compute_pool()
+            .with_domain_registry()
+            .with_stake_manager();
+
+        let contracts = contract_builder.build_partial().unwrap();
+        let metrics_ctx =
+            MetricsContext::new(validator_wallet.address().to_string(), self.pool_id.clone());
+
+        let keypair = p2p::Keypair::generate_ed25519();
+        let bootnodes: Vec<p2p::Multiaddr> = self
+            .bootnodes
+            .split(',')
+            .filter_map(|addr| match addr.to_string().try_into() {
+                Ok(multiaddr) => Some(multiaddr),
+                Err(e) => {
+                    error!("Invalid bootnode address '{addr}': {e}");
+                    None
+                }
+            })
+            .collect();
+        if bootnodes.is_empty() {
+            error!(
+                "No valid bootnodes provided. Please provide at least one valid bootnode address."
+            );
+            std::process::exit(1);
+        }
+
+        let (p2p_service, hardware_challenge_tx, kademlia_action_tx) = {
+            match P2PService::new(
+                keypair,
+                self.libp2p_port,
+                bootnodes,
+                cancellation_token.clone(),
+                validator_wallet.clone(),
+            ) {
+                Ok(res) => {
+                    info!("p2p service initialized successfully");
+                    res
+                }
+                Err(e) => {
+                    error!("failed to initialize p2p service: {e}");
+                    std::process::exit(1);
+                }
+            }
+        };
+
+        tokio::task::spawn(p2p_service.run());
+
+        if let Some(pool_id) = self.pool_id.clone() {
+            let pool = match contracts
+                .compute_pool
+                .get_pool_info(U256::from_str(&pool_id).unwrap())
+                .await
+            {
+                Ok(pool_info) => pool_info,
+                Err(e) => {
+                    error!("Failed to get pool info: {e:?}");
+                    std::process::exit(1);
+                }
+            };
+            let domain_id: u32 = pool.domain_id.try_into().unwrap();
+            let domain = contracts
+                .domain_registry
+                .as_ref()
+                .unwrap()
+                .get_domain(domain_id)
+                .await
+                .unwrap();
+            contract_builder =
+                contract_builder.with_synthetic_data_validator(Some(domain.validation_logic));
+        }
+
+        let contracts = contract_builder.build().unwrap();
+
+        let hardware_validator = HardwareValidator::new(contracts.clone(), hardware_challenge_tx);
+
+        let synthetic_validator = if let Some(pool_id) = self.pool_id.clone() {
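+            // Whole tokens on the CLI become wei on-chain, e.g. the default
+            // `--validator-penalty 200` yields 200 * 10^18.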
+            let penalty = U256::from(self.validator_penalty) * Unit::ETHER.wei();
+            match contracts.synthetic_data_validator.clone() {
+                Some(validator) => {
+                    info!(
+                        "Synthetic validator has penalty: {} ({})",
+                        penalty, self.validator_penalty
+                    );
+
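+                    // `TOPLOC_CONFIGS` is a JSON array; its element type is inferred at
+                    // the `SyntheticDataValidator::new` call below, so the exact shape is
+                    // defined by that config type rather than sketched here.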
+                    let Ok(toploc_configs) = std::env::var("TOPLOC_CONFIGS") else {
+                        error!("Toploc configs are required but not provided in environment");
+                        std::process::exit(1);
+                    };
+                    info!("Toploc configs: {toploc_configs}");
+
+                    let configs = match serde_json::from_str(&toploc_configs) {
+                        Ok(configs) => configs,
+                        Err(e) => {
+                            error!("Failed to parse toploc configs: {e}");
+                            std::process::exit(1);
+                        }
+                    };
+                    let s3_credentials = std::env::var("S3_CREDENTIALS").ok();
+
+                    match (self.bucket_name.as_ref(), s3_credentials) {
+                        (Some(bucket_name), Some(s3_credentials))
+                            if !bucket_name.is_empty() && !s3_credentials.is_empty() =>
+                        {
+                            let gcs_storage = GcsStorageProvider::new(bucket_name, &s3_credentials)
+                                .await
+                                .unwrap_or_else(|_| {
+                                    panic!("Failed to create GCS storage provider")
+                                });
+                            let storage_provider = Arc::new(gcs_storage);
+
+                            Some(SyntheticDataValidator::new(
+                                pool_id,
+                                validator,
+                                contracts.prime_network.clone(),
+                                configs,
+                                penalty,
+                                storage_provider,
+                                redis_store,
+                                cancellation_token.clone(),
+                                self.toploc_work_validation_interval,
+                                self.toploc_work_validation_unknown_status_expiry_seconds,
+                                self.toploc_grace_interval,
+                                self.batch_trigger_size,
+                                self.use_grouping,
+                                self.disable_toploc_invalidation,
+                                self.incomplete_group_grace_period_minutes,
+                                self.toploc_invalidation_type,
+                                self.work_unit_invalidation_type,
+                                Some(metrics_ctx.clone()),
+                            ))
+                        }
+                        _ => {
+                            info!("Bucket name or S3 credentials not provided, skipping synthetic data validation");
+                            None
+                        }
+                    }
+                }
+                None => {
+                    error!("Synthetic data validator not found");
+                    std::process::exit(1);
+                }
+            }
+        } else {
+            None
+        };
+
+        let (validator, validator_health) = match Validator::new(
+            cancellation_token.clone(),
+            validator_wallet.provider(),
+            contracts,
+            hardware_validator,
+            synthetic_validator.clone(),
+            kademlia_action_tx,
+            self.disable_hardware_validation,
+            metrics_ctx,
+        ) {
+            Ok(v) => v,
+            Err(e) => {
+                error!("Failed to create validator: {e}");
+                std::process::exit(1);
+            }
+        };
+
+        // Start HTTP server with access to the validator
+        tokio::spawn(async move {
+            let key = std::env::var("VALIDATOR_API_KEY").unwrap_or_default();
+            let api_key_middleware = Arc::new(ApiKeyMiddleware::new(key));
+
+            if let Err(e) = HttpServer::new(move || {
+                App::new()
+                    .app_data(web::Data::new(State {
+                        synthetic_validator: synthetic_validator.clone(),
+                        validator_health: validator_health.clone(),
+                    }))
+                    .route("/health", web::get().to(health_check))
+                    .route(
+                        "/rejections",
+                        web::get()
+                            .to(get_rejections)
+                            .wrap(api_key_middleware.clone()),
+                    )
+                    .route(
+                        "/metrics",
+                        web::get().to(|| async {
+                            match export_metrics() {
+                                Ok(metrics) => {
+                                    HttpResponse::Ok().content_type("text/plain").body(metrics)
+                                }
+                                Err(e) => {
+                                    error!("Error exporting metrics: {e:?}");
+                                    HttpResponse::InternalServerError().finish()
+                                }
+                            }
+                        }),
+                    )
+            })
+            .bind("0.0.0.0:9879")
+            .expect("Failed to bind health check server")
+            .run()
+            .await
+            {
+                error!("Actix server error: {e:?}");
+            }
+        });
+
+        tokio::task::spawn(validator.run());
+        Ok(())
+    }
+}
diff --git a/crates/validator/src/handler.rs b/crates/validator/src/handler.rs
new file mode 100644
index 00000000..20cf5982
--- /dev/null
+++ b/crates/validator/src/handler.rs
@@ -0,0 +1,104 @@
+use std::{
+    sync::Arc,
+    time::{SystemTime, UNIX_EPOCH},
+};
+
+use actix_web::{web, HttpRequest, HttpResponse, Responder};
+use log::{error, info};
+use serde_json::json;
+use shared::models::api::ApiResponse;
+
+use crate::{validator, SyntheticDataValidator};
+
+pub(crate) struct State {
+    pub synthetic_validator: Option>,
+    pub validator_health: Arc>,
+}
+
+pub(crate) async fn health_check(_: HttpRequest, state: web::Data<State>) -> impl Responder {
+    // Maximum allowed time between validation loops (2 minutes)
+    const MAX_VALIDATION_INTERVAL_SECS: u64 = 120;
+
+    let validator_health = state.validator_health.lock().await;
+
+    let now = SystemTime::now()
+        .duration_since(UNIX_EPOCH)
+        .unwrap()
+        .as_secs();
+
+    if validator_health.last_validation_timestamp() == 0 {
+        // Validation hasn't run yet, but we're still starting up
+        return HttpResponse::Ok().json(json!({
+            "status": "starting",
+            "message": "Validation loop hasn't started yet"
+        }));
+    }
+
+    let elapsed = now - validator_health.last_validation_timestamp();
+
+    if elapsed > MAX_VALIDATION_INTERVAL_SECS {
+        return HttpResponse::ServiceUnavailable().json(json!({
+            "status": "error",
+            "message": format!("Validation loop hasn't run in {} seconds (max allowed: {})", elapsed, MAX_VALIDATION_INTERVAL_SECS),
+            "last_loop_duration_ms": validator_health.last_loop_duration_ms(),
+        }));
+    }
+
+    HttpResponse::Ok().json(json!({
+        "status": "ok",
+        "last_validation_seconds_ago": elapsed,
+        "last_loop_duration_ms": validator_health.last_loop_duration_ms(),
+    }))
+}
+
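+/// Example: `GET /rejections?limit=50` (with the API key header) returns the 50 most
+/// recent rejections; a `limit` of 0 or >= 1000 falls back to scanning all rejections.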
+pub(crate) async fn get_rejections(req: HttpRequest, state: web::Data<State>) -> impl Responder {
+    match state.synthetic_validator.as_ref() {
+        Some(synthetic_validator) => {
+            // Parse query parameters
+            let query = req.query_string();
+            let limit = parse_limit_param(query).unwrap_or(100); // Default limit of 100
+
+            let result = if limit > 0 && limit < 1000 {
+                // Use the optimized recent rejections method for reasonable limits
+                synthetic_validator
+                    .get_recent_rejections(limit as isize)
+                    .await
+            } else {
+                // Fallback to all rejections (but warn about potential performance impact)
+                if limit >= 1000 {
+                    info!("Large limit requested ({limit}), this may impact performance");
+                }
+                synthetic_validator.get_all_rejections().await
+            };
+
+            match result {
+                Ok(rejections) => HttpResponse::Ok().json(ApiResponse {
+                    success: true,
+                    data: rejections,
+                }),
+                Err(e) => {
+                    error!("Failed to get rejections: {e}");
+                    HttpResponse::InternalServerError().json(ApiResponse {
+                        success: false,
+                        data: format!("Failed to get rejections: {e}"),
+                    })
+                }
+            }
+        }
+        None => HttpResponse::ServiceUnavailable().json(ApiResponse {
+            success: false,
+            data: "Synthetic data validator not available",
+        }),
+    }
+}
+
+fn parse_limit_param(query: &str) -> Option<usize> {
+    for pair in query.split('&') {
+        if let Some((key, value)) = pair.split_once('=') {
+            if key == "limit" {
+                return value.parse::<usize>().ok();
+            }
+        }
+    }
+    None
+}
diff --git a/crates/validator/src/lib.rs b/crates/validator/src/lib.rs
index e80f711c..e8aef06f 100644
--- a/crates/validator/src/lib.rs
+++ b/crates/validator/src/lib.rs
@@ -1,9 +1,12 @@
+mod cli;
+mod handler;
 mod metrics;
 mod p2p;
 mod store;
 mod validator;
 mod validators;
 
+pub use cli::Cli;
 pub use metrics::export_metrics;
 pub use metrics::MetricsContext;
 pub use p2p::Service as P2PService;
diff --git a/crates/validator/src/main.rs b/crates/validator/src/main.rs
index 6c63e72a..3818510c 100644
--- a/crates/validator/src/main.rs
+++ b/crates/validator/src/main.rs
@@ -1,120 +1,20 @@
-use actix_web::{web, App, HttpRequest, HttpResponse, HttpServer, Responder};
-use alloy::primitives::utils::Unit;
-use alloy::primitives::U256;
 use clap::Parser;
 use log::LevelFilter;
-use log::{error, info};
-use serde_json::json;
-use shared::models::api::ApiResponse;
-use shared::security::api_key_middleware::ApiKeyMiddleware;
-use shared::utils::google_cloud::GcsStorageProvider;
-use shared::web3::contracts::core::builder::ContractBuilder;
-use shared::web3::wallet::Wallet;
-use std::str::FromStr;
-use std::sync::Arc;
-use std::time::{SystemTime, UNIX_EPOCH};
-use tokio::signal::unix::{signal, SignalKind};
+use shared::utils::signal::trigger_cancellation_on_signal;
 use tokio_util::sync::CancellationToken;
-use url::Url;
-use validator::{
-    export_metrics, HardwareValidator, InvalidationType, MetricsContext, P2PService, RedisStore,
-    SyntheticDataValidator, Validator,
-};
-
-#[derive(Parser)]
-struct Args {
-    /// RPC URL
-    #[arg(short = 'r', long, default_value = "http://localhost:8545")]
-    rpc_url: String,
-
-    /// Owner key
-    #[arg(short = 'k', long)]
-    validator_key: String,
-
-    /// Ability to disable hardware validation
-    #[arg(long, default_value = "false")]
-    disable_hardware_validation: bool,
-
-    /// Optional: Pool Id for work validation
-    /// If not provided, the validator will not validate work
-    #[arg(long, default_value = None)]
-    pool_id: Option<String>,
-
-    /// Optional: Toploc Grace Interval in seconds between work validation requests
-    #[arg(long, default_value = "15")]
-    toploc_grace_interval: u64,
-
-    /// Optional: interval in minutes of max age of work on chain
-    #[arg(long, default_value = "15")]
-    toploc_work_validation_interval: u64,
-
-    /// Optional: interval in minutes of max age of work on chain
-    #[arg(long, default_value = "120")]
-    toploc_work_validation_unknown_status_expiry_seconds: u64,
-
-    /// Disable toploc ejection
-    /// If true, the validator will not invalidate work on toploc
-    #[arg(long, default_value = "false")]
-    disable_toploc_invalidation: bool,
-
-    /// Optional: batch trigger size
-    #[arg(long, default_value = "10")]
-    batch_trigger_size: usize,
-
-    /// Grouping
-    #[arg(long, default_value = "false")]
-    use_grouping: bool,
-
-    /// Grace period in minutes for incomplete groups to recover (0 = disabled)
-    #[arg(long, default_value = "0")]
-    incomplete_group_grace_period_minutes: u64,
-
-    /// Optional: toploc invalidation type
-    #[arg(long, default_value = "hard")]
-    toploc_invalidation_type: InvalidationType,
-
-    /// Optional: work unit invalidation type
-    #[arg(long, default_value = "hard")]
-    work_unit_invalidation_type: InvalidationType,
-
-    /// Optional: Validator penalty in whole tokens
-    /// Note: This value will be multiplied by 10^18 (1 token = 10^18 wei)
-    #[arg(long, default_value = "200")]
-    validator_penalty: u64,
-
-    /// Temporary: S3 bucket name
-    #[arg(long, default_value = None)]
-    bucket_name: Option<String>,
-
-    /// Log level
-    #[arg(short = 'l', long, default_value = "info")]
-    log_level: String,
-
-    /// Redis URL
-    #[arg(long, default_value = "redis://localhost:6380")]
-    redis_url: String,
-
-    /// Libp2p port
-    #[arg(long, default_value = "4003")]
-    libp2p_port: u16,
-
-    /// Comma-separated list of libp2p bootnode multiaddresses
-    /// Example: `/ip4/104.131.131.82/tcp/4001/p2p/QmaCpDMGvV2BGHeYERUEnRQAwe3N8SzbUtfsmvsqQLuvuJ,/ip4/104.131.131.82/udp/4001/quic-v1/p2p/QmaCpDMGvV2BGHeYERUEnRQAwe3N8SzbUtfsmvsqQLuvuJ`
-    #[arg(long, default_value = "")]
-    bootnodes: String,
-}
+use validator::Cli;
 
 #[tokio::main]
 async fn main() -> anyhow::Result<()> {
-    let args = Args::parse();
-    let log_level = match args.log_level.as_str() {
+    let cli = Cli::parse();
+    let log_level = match cli.log_level.as_str() {
         "error" => LevelFilter::Error,
         "warn" => LevelFilter::Warn,
         "info" => LevelFilter::Info,
         "debug" => LevelFilter::Debug,
         "trace" => LevelFilter::Trace,
-        _ => anyhow::bail!("invalid log level: {}", args.log_level),
+        _ => anyhow::bail!("invalid log level: {}", cli.log_level),
     };
     env_logger::Builder::new()
         .filter_level(log_level)
@@ -124,342 +24,14 @@ async fn main() -> anyhow::Result<()> {
 
     let cancellation_token = CancellationToken::new();
 
-    let private_key_validator = args.validator_key;
-    let rpc_url: Url = args.rpc_url.parse().unwrap();
-
-    let redis_store = RedisStore::new(&args.redis_url);
-
-    let validator_wallet = Wallet::new(&private_key_validator, rpc_url).unwrap_or_else(|err| {
-        error!("Error creating wallet: {err:?}");
-        std::process::exit(1);
-    });
-
-    let mut contract_builder = ContractBuilder::new(validator_wallet.provider())
-        .with_compute_registry()
-        .with_ai_token()
-        .with_prime_network()
-        .with_compute_pool()
-        .with_domain_registry()
-        .with_stake_manager();
-
-    let contracts = contract_builder.build_partial().unwrap();
-    let metrics_ctx =
-        MetricsContext::new(validator_wallet.address().to_string(), args.pool_id.clone());
-
-    let keypair = p2p::Keypair::generate_ed25519();
-    let bootnodes: Vec<p2p::Multiaddr> = args
-        .bootnodes
-        .split(',')
-        .filter_map(|addr| match addr.to_string().try_into() {
-            Ok(multiaddr) => Some(multiaddr),
-            Err(e) => {
-                error!("Invalid bootnode address '{addr}': {e}");
-                None
-            }
-        })
-        .collect();
-    if bootnodes.is_empty() {
-        error!("No valid bootnodes provided. Please provide at least one valid bootnode address.");
-        std::process::exit(1);
-    }
-
-    let (p2p_service, hardware_challenge_tx, kademlia_action_tx) = {
-        match P2PService::new(
-            keypair,
-            args.libp2p_port,
-            bootnodes,
-            cancellation_token.clone(),
-            validator_wallet.clone(),
-        ) {
-            Ok(res) => {
-                info!("p2p service initialized successfully");
-                res
-            }
-            Err(e) => {
-                error!("failed to initialize p2p service: {e}");
-                std::process::exit(1);
-            }
-        }
-    };
-
-    tokio::task::spawn(p2p_service.run());
-
-    if let Some(pool_id) = args.pool_id.clone() {
-        let pool = match contracts
-            .compute_pool
-            .get_pool_info(U256::from_str(&pool_id).unwrap())
-            .await
-        {
-            Ok(pool_info) => pool_info,
-            Err(e) => {
-                error!("Failed to get pool info: {e:?}");
-                std::process::exit(1);
-            }
-        };
-        let domain_id: u32 = pool.domain_id.try_into().unwrap();
-        let domain = contracts
-            .domain_registry
-            .as_ref()
-            .unwrap()
-            .get_domain(domain_id)
-            .await
-            .unwrap();
-        contract_builder =
-            contract_builder.with_synthetic_data_validator(Some(domain.validation_logic));
-    }
-
-    let contracts = contract_builder.build().unwrap();
-
-    let hardware_validator = HardwareValidator::new(contracts.clone(), hardware_challenge_tx);
+    let _signal_handle = trigger_cancellation_on_signal(cancellation_token.clone())?;
 
-    let synthetic_validator = if let Some(pool_id) = args.pool_id.clone() {
-        let penalty = U256::from(args.validator_penalty) * Unit::ETHER.wei();
-        match contracts.synthetic_data_validator.clone() {
-            Some(validator) => {
-                info!(
-                    "Synthetic validator has penalty: {} ({})",
-                    penalty, args.validator_penalty
-                );
-
-                let Ok(toploc_configs) = std::env::var("TOPLOC_CONFIGS") else {
-                    error!("Toploc configs are required but not provided in environment");
-                    std::process::exit(1);
-                };
-                info!("Toploc configs: {toploc_configs}");
-
-                let configs = match serde_json::from_str(&toploc_configs) {
-                    Ok(configs) => configs,
-                    Err(e) => {
-                        error!("Failed to parse toploc configs: {e}");
-                        std::process::exit(1);
-                    }
-                };
-                let s3_credentials = std::env::var("S3_CREDENTIALS").ok();
-
-                match (args.bucket_name.as_ref(), s3_credentials) {
-                    (Some(bucket_name), Some(s3_credentials))
-                        if !bucket_name.is_empty() && !s3_credentials.is_empty() =>
-                    {
-                        let gcs_storage = GcsStorageProvider::new(bucket_name, &s3_credentials)
-                            .await
-                            .unwrap_or_else(|_| panic!("Failed to create GCS storage provider"));
-                        let storage_provider = Arc::new(gcs_storage);
-
-                        Some(SyntheticDataValidator::new(
-                            pool_id,
-                            validator,
-                            contracts.prime_network.clone(),
-                            configs,
-                            penalty,
-                            storage_provider,
-                            redis_store,
-                            cancellation_token.clone(),
-                            args.toploc_work_validation_interval,
-                            args.toploc_work_validation_unknown_status_expiry_seconds,
-                            args.toploc_grace_interval,
-                            args.batch_trigger_size,
-                            args.use_grouping,
-                            args.disable_toploc_invalidation,
-                            args.incomplete_group_grace_period_minutes,
-                            args.toploc_invalidation_type,
-                            args.work_unit_invalidation_type,
-                            Some(metrics_ctx.clone()),
-                        ))
-                    }
-                    _ => {
-                        info!("Bucket name or S3 credentials not provided, skipping synthetic data validation");
-                        None
-                    }
-                }
-            }
-            None => {
-                error!("Synthetic data validator not found");
-                std::process::exit(1);
-            }
-        }
-    } else {
-        None
-    };
-
-    let (validator, validator_health) = match Validator::new(
-        cancellation_token.clone(),
-        validator_wallet.provider(),
-        contracts,
-        hardware_validator,
-        synthetic_validator.clone(),
-        kademlia_action_tx,
-        args.disable_hardware_validation,
-        metrics_ctx,
-    ) {
-        Ok(v) => v,
-        Err(e) => {
-            error!("Failed to create validator: {e}");
-            std::process::exit(1);
-        }
-    };
-
-    // Start HTTP server with access to the validator
-    tokio::spawn(async move {
-        let key = std::env::var("VALIDATOR_API_KEY").unwrap_or_default();
-        let api_key_middleware = Arc::new(ApiKeyMiddleware::new(key));
-
-        if let Err(e) = HttpServer::new(move || {
-            App::new()
-                .app_data(web::Data::new(State {
-                    synthetic_validator: synthetic_validator.clone(),
-                    validator_health: validator_health.clone(),
-                }))
-                .route("/health", web::get().to(health_check))
-                .route(
-                    "/rejections",
-                    web::get()
-                        .to(get_rejections)
-                        .wrap(api_key_middleware.clone()),
-                )
-                .route(
-                    "/metrics",
-                    web::get().to(|| async {
-                        match export_metrics() {
-                            Ok(metrics) => {
-                                HttpResponse::Ok().content_type("text/plain").body(metrics)
-                            }
-                            Err(e) => {
-                                error!("Error exporting metrics: {e:?}");
-                                HttpResponse::InternalServerError().finish()
-                            }
-                        }
-                    }),
-                )
-        })
-        .bind("0.0.0.0:9879")
-        .expect("Failed to bind health check server")
-        .run()
-        .await
-        {
-            error!("Actix server error: {e:?}");
-        }
-    });
-
-    tokio::task::spawn(validator.run());
-
-    let mut sigterm = signal(SignalKind::terminate())?;
-    let mut sigint = signal(SignalKind::interrupt())?;
-    let mut sighup = signal(SignalKind::hangup())?;
-    let mut sigquit = signal(SignalKind::quit())?;
-
-    tokio::select! {
-        _ = sigterm.recv() => {
-            log::info!("Received termination signal");
-        }
-        _ = sigint.recv() => {
-            log::info!("Received interrupt signal");
-        }
-        _ = sighup.recv() => {
-            log::info!("Received hangup signal");
-        }
-        _ = sigquit.recv() => {
-            log::info!("Received quit signal");
-        }
-    }
-    cancellation_token.cancel();
+    cli.run(cancellation_token).await?;
 
     // TODO: handle spawn handles here (https://github.com/PrimeIntellect-ai/protocol/issues/627)
 
     Ok(())
 }
-
-struct State {
-    synthetic_validator: Option>,
-    validator_health: Arc>,
-}
-
-async fn health_check(_: HttpRequest, state: web::Data<State>) -> impl Responder {
-    // Maximum allowed time between validation loops (2 minutes)
-    const MAX_VALIDATION_INTERVAL_SECS: u64 = 120;
-
-    let validator_health = state.validator_health.lock().await;
-
-    let now = SystemTime::now()
-        .duration_since(UNIX_EPOCH)
-        .unwrap()
-        .as_secs();
-
-    if validator_health.last_validation_timestamp() == 0 {
-        // Validation hasn't run yet, but we're still starting up
-        return HttpResponse::Ok().json(json!({
-            "status": "starting",
-            "message": "Validation loop hasn't started yet"
-        }));
-    }
-
-    let elapsed = now - validator_health.last_validation_timestamp();
-
-    if elapsed > MAX_VALIDATION_INTERVAL_SECS {
-        return HttpResponse::ServiceUnavailable().json(json!({
-            "status": "error",
-            "message": format!("Validation loop hasn't run in {} seconds (max allowed: {})", elapsed, MAX_VALIDATION_INTERVAL_SECS),
-            "last_loop_duration_ms": validator_health.last_loop_duration_ms(),
-        }));
-    }
-
-    HttpResponse::Ok().json(json!({
-        "status": "ok",
-        "last_validation_seconds_ago": elapsed,
-        "last_loop_duration_ms": validator_health.last_loop_duration_ms(),
-    }))
-}
-
-async fn get_rejections(req: HttpRequest, state: web::Data<State>) -> impl Responder {
-    match state.synthetic_validator.as_ref() {
-        Some(synthetic_validator) => {
-            // Parse query parameters
-            let query = req.query_string();
-            let limit = parse_limit_param(query).unwrap_or(100); // Default limit of 100
-
-            let result = if limit > 0 && limit < 1000 {
-                // Use the optimized recent rejections method for reasonable limits
-                synthetic_validator
-                    .get_recent_rejections(limit as isize)
-                    .await
-            } else {
-                // Fallback to all rejections (but warn about potential performance impact)
-                if limit >= 1000 {
-                    info!("Large limit requested ({limit}), this may impact performance");
-                }
-                synthetic_validator.get_all_rejections().await
-            };
-
-            match result {
-                Ok(rejections) => HttpResponse::Ok().json(ApiResponse {
-                    success: true,
-                    data: rejections,
-                }),
-                Err(e) => {
-                    error!("Failed to get rejections: {e}");
-                    HttpResponse::InternalServerError().json(ApiResponse {
-                        success: false,
-                        data: format!("Failed to get rejections: {e}"),
-                    })
-                }
-            }
-        }
-        None => HttpResponse::ServiceUnavailable().json(ApiResponse {
-            success: false,
-            data: "Synthetic data validator not available",
-        }),
-    }
-}
-
-fn parse_limit_param(query: &str) -> Option<usize> {
-    for pair in query.split('&') {
-        if let Some((key, value)) = pair.split_once('=') {
-            if key == "limit" {
-                return value.parse::<usize>().ok();
-            }
-        }
-    }
-    None
-}
-
 #[cfg(test)]
 mod tests {
     use actix_web::{test, App};
diff --git a/crates/worker/src/cli.rs b/crates/worker/src/cli.rs
new file mode 100644
index 00000000..d712cad3
--- /dev/null
+++ b/crates/worker/src/cli.rs
@@ -0,0 +1,1176 @@
+use crate::checks::hardware::HardwareChecker;
+use crate::checks::issue::IssueReport;
+use crate::checks::software::SoftwareChecker;
+use crate::checks::stun::StunCheck;
+use crate::console::Console;
+use crate::docker::taskbridge::TaskBridge;
+use crate::docker::DockerService;
+use crate::metrics::store::MetricsStore;
+use crate::operations::compute_node::ComputeNodeOperations;
+use crate::operations::heartbeat::service::HeartbeatService;
+use crate::operations::provider::ProviderOperations;
+use crate::state::system_state::SystemState;
+use crate::TaskHandles;
+use alloy::primitives::utils::format_ether;
+use alloy::primitives::Address;
+use alloy::primitives::U256;
+use alloy::signers::local::PrivateKeySigner;
+use alloy::signers::Signer;
+use clap::{Parser, Subcommand};
+use log::{error, info};
+use p2p::KademliaAction;
+use shared::models::node::ComputeRequirements;
+use shared::models::node::Node;
+use shared::web3::contracts::core::builder::ContractBuilder;
+use shared::web3::contracts::core::builder::Contracts;
+use shared::web3::contracts::structs::compute_pool::PoolStatus;
+use shared::web3::wallet::Wallet;
+use shared::web3::wallet::WalletProvider;
+use std::str::FromStr;
+use std::sync::Arc;
+use std::time::Duration;
+use tokio::sync::RwLock;
+use tokio_util::sync::CancellationToken;
+use url::Url;
+
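+// Note: `WORKER_VERSION` here and the `WORKER_RPC_URL` defaults below are resolved
+// at *compile* time via `option_env!`; set them when building, not when running.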
+        #[arg(long)]
+        storage_path: Option<String>,
+
+        /// Disable host network mode
+        #[arg(long, default_value = "false")]
+        disable_host_network_mode: bool,
+
+        #[arg(long, default_value = "false")]
+        with_ipfs_upload: bool,
+
+        #[arg(long, default_value = "5001")]
+        ipfs_port: u16,
+    },
+    Check {},
+
+    /// Generate new wallets for provider and node
+    GenerateWallets {},
+
+    /// Generate new wallet for node only
+    GenerateNodeWallet {},
+
+    /// Get balance of provider and node
+    Balance {
+        /// Private key for the provider
+        #[arg(long)]
+        private_key: Option<String>,
+
+        /// RPC URL
+        #[arg(long, default_value = option_env!("WORKER_RPC_URL").unwrap_or("http://localhost:8545"))]
+        rpc_url: String,
+    },
+
+    /// Sign Message
+    SignMessage {
+        /// Message to sign
+        #[arg(long)]
+        message: String,
+
+        /// Private key for the provider
+        #[arg(long)]
+        private_key_provider: Option<String>,
+
+        /// Private key for the node
+        #[arg(long)]
+        private_key_node: Option<String>,
+    },
+
+    /// Deregister worker from compute pool
+    Deregister {
+        /// Private key for the provider
+        #[arg(long)]
+        private_key_provider: Option<String>,
+
+        /// Private key for the node
+        #[arg(long)]
+        private_key_node: Option<String>,
+
+        /// RPC URL
+        #[arg(long, default_value = option_env!("WORKER_RPC_URL").unwrap_or("http://localhost:8545"))]
+        rpc_url: String,
+
+        /// Compute pool ID
+        #[arg(long)]
+        compute_pool_id: u32,
+    },
+}
+
+impl Cli {
+    pub async fn run(
+        &self,
+        cancellation_token: CancellationToken,
+        task_handles: TaskHandles,
+    ) -> Result<(), Box<dyn std::error::Error>> {
+        match &self.command {
+            Commands::Run {
+                port: _,
+                libp2p_port,
+                bootnodes,
+                external_ip,
+                compute_pool_id,
+                dry_run: _,
+                rpc_url,
+                state_dir_overwrite,
+                disable_state_storing,
+                no_auto_recover,
+                private_key_provider,
+                private_key_node,
+                auto_accept,
+                funding_retry_count,
+                skip_system_checks,
+                loki_url: _,
+                log_level: _,
+                storage_path,
+                disable_host_network_mode,
+                with_ipfs_upload,
+                ipfs_port,
+            } => {
+                if *disable_state_storing && !(*no_auto_recover) {
+                    Console::user_error(
+                        "Cannot disable state storing and enable auto recover at the same time. Use --no-auto-recover to disable auto recover.",
+                    );
+                    std::process::exit(1);
+                }
+                let state = match SystemState::new(
+                    state_dir_overwrite.clone(),
+                    *disable_state_storing,
+                    *compute_pool_id,
+                ) {
+                    Ok(state) => state,
+                    Err(e) => {
+                        error!("❌ Failed to initialize system state: {e}");
+                        std::process::exit(1);
+                    }
+                };
+
+                let state = Arc::new(state);
+
+                let private_key_provider = if let Some(key) = private_key_provider {
+                    Console::warning("Using private key from command line is not recommended. Consider using PRIVATE_KEY_PROVIDER environment variable instead.");
+                    key.clone()
+                } else {
+                    std::env::var("PRIVATE_KEY_PROVIDER").expect("PRIVATE_KEY_PROVIDER must be set")
+                };
+
+                let private_key_node = if let Some(key) = private_key_node {
+                    Console::warning("Using private key from command line is not recommended. Consider using PRIVATE_KEY_NODE environment variable instead.");
+                    key.clone()
+                } else {
+                    std::env::var("PRIVATE_KEY_NODE").expect("PRIVATE_KEY_NODE must be set")
+                };
+
+                let mut recover_last_state = !(*no_auto_recover);
+                let version = APP_VERSION;
+                Console::section("🚀 PRIME WORKER INITIALIZATION - beta");
+                Console::info("Version", version);
+
+                /*
+                    Initialize Wallet instances
+                */
+                let provider_wallet_instance =
+                    match Wallet::new(&private_key_provider, Url::parse(rpc_url).unwrap()) {
+                        Ok(wallet) => wallet,
+                        Err(err) => {
+                            error!("Failed to create wallet: {err}");
+                            std::process::exit(1);
+                        }
+                    };
+
+                let node_wallet_instance =
+                    match Wallet::new(&private_key_node, Url::parse(rpc_url).unwrap()) {
+                        Ok(wallet) => wallet,
+                        Err(err) => {
+                            error!("❌ Failed to create wallet: {err}");
+                            std::process::exit(1);
+                        }
+                    };
+
+                /*
+                    Initialize dependencies - services, contracts, operations
+                */
+                let contracts = ContractBuilder::new(provider_wallet_instance.provider())
+                    .with_compute_registry()
+                    .with_ai_token()
+                    .with_prime_network()
+                    .with_compute_pool()
+                    .with_stake_manager()
+                    .build()
+                    .unwrap();
+
+                let provider_ops = ProviderOperations::new(
+                    provider_wallet_instance.clone(),
+                    contracts.clone(),
+                    *auto_accept,
+                );
+
+                let provider_ops_cancellation = cancellation_token.clone();
+
+                let compute_node_state = state.clone();
+                let compute_node_ops = ComputeNodeOperations::new(
+                    &provider_wallet_instance,
+                    &node_wallet_instance,
+                    contracts.clone(),
+                    compute_node_state,
+                );
+
+                let pool_id = U256::from(*compute_pool_id);
+                let pool_info = loop {
+                    match contracts.compute_pool.get_pool_info(pool_id).await {
+                        Ok(pool) if pool.status == PoolStatus::ACTIVE => break Arc::new(pool),
+                        Ok(_) => {
+                            Console::warning(
+                                "Pool is not active yet. Checking again in 15 seconds.",
+                            );
+                            tokio::select! {
+                                _ = tokio::time::sleep(tokio::time::Duration::from_secs(15)) => {},
+                                _ = cancellation_token.cancelled() => return Ok(()),
+                            }
+                        }
+                        Err(e) => {
+                            error!("Failed to get pool info: {e}");
+                            return Ok(());
+                        }
+                    }
+                };
+
+                let stun_check = StunCheck::new(Duration::from_secs(5), 0);
+                let detected_external_ip = match stun_check.get_public_ip().await {
+                    Ok(ip) => ip,
+                    Err(e) => {
+                        error!("❌ Failed to get public IP: {e}");
+                        std::process::exit(1);
+                    }
+                };
+
+                let node_config = Node {
+                    id: node_wallet_instance
+                        .wallet
+                        .default_signer()
+                        .address()
+                        .to_string(),
+                    ip_address: external_ip.clone().unwrap_or(detected_external_ip.clone()),
+                    port: 0,
+                    provider_address: provider_wallet_instance
+                        .wallet
+                        .default_signer()
+                        .address()
+                        .to_string(),
+                    compute_specs: None,
+                    compute_pool_id: *compute_pool_id,
+                    worker_p2p_id: state.get_p2p_id().to_string(),
+                    worker_p2p_addresses: None,
+                };
+
+                let issue_tracker = Arc::new(RwLock::new(IssueReport::new()));
+                let mut hardware_check = HardwareChecker::new(Some(issue_tracker.clone()));
+                let mut node_config = match hardware_check
+                    .check_hardware(node_config, storage_path.clone())
+                    .await
+                {
+                    Ok(config) => config,
+                    Err(e) => {
+                        Console::user_error(&format!("❌ Hardware check failed: {e}"));
+                        std::process::exit(1);
+                    }
+                };
+                let software_checker = SoftwareChecker::new(Some(issue_tracker.clone()));
+                if let Err(err) = software_checker.check_software(&node_config).await {
+                    Console::user_error(&format!("❌ Software check failed: {err}"));
+                    std::process::exit(1);
+                }
+
+                if let Some(external_ip) = external_ip {
+                    if *external_ip != detected_external_ip {
+                        Console::warning(&format!(
+                            "Automatically detected external IP {detected_external_ip} does not match the provided external IP {external_ip}"
+                        ));
+                    }
+                }
+
+                let issues = issue_tracker.read().await;
+                issues.print_issues();
+                if issues.has_critical_issues() {
+                    if !*skip_system_checks {
+                        Console::user_error("❌ Critical issues found. Exiting.");
+                        std::process::exit(1);
+                    } else {
+                        Console::warning("Critical issues found. Ignoring and continuing.");
+                    }
+                }
+                let required_specs = match ComputeRequirements::from_str(&pool_info.pool_data_uri) {
+                    Ok(specs) => Some(specs),
+                    Err(e) => {
+                        log::debug!("❌ Could not parse pool compute specs: {e}");
+                        None
+                    }
+                };
+
+                // Check if node meets the pool's compute requirements
+                if let Some(ref compute_specs) = node_config.compute_specs {
+                    if let Some(ref required_specs) = required_specs {
+                        if !compute_specs.meets(required_specs) {
+                            Console::user_error(
+                                "❌ Your node does not meet the compute requirements for this pool.",
+                            );
+                            info!("Required compute requirements:\n{required_specs}");
+                            if !*skip_system_checks {
+                                std::process::exit(1);
+                            } else {
+                                Console::warning(
+                                    "Ignoring compute requirements mismatch and continuing.",
+                                );
+                            }
+                        } else {
+                            Console::success(
+                                "✅ Your node meets the compute requirements for this pool.",
+                            );
+                        }
+                    } else {
+                        Console::success("✅ No specific compute requirements for this pool.");
+                    }
+                } else {
+                    Console::warning(
+                        "Cannot verify compute requirements: node specs not available.",
+                    );
+                    if !*skip_system_checks {
+                        std::process::exit(1);
+                    } else {
+                        Console::warning("Ignoring missing compute specs and continuing.");
+                    }
+                }
+
+                let metrics_store = Arc::new(MetricsStore::new());
+                let heartbeat_metrics_clone = metrics_store.clone();
+                let bridge_contracts = contracts.clone();
+                let bridge_wallet = node_wallet_instance.clone();
+
+                let ipfs = if *with_ipfs_upload {
+                    let conn_limits =
+                        rust_ipfs::ConnectionLimits::default().with_max_established(Some(100));
+                    let builder = rust_ipfs::UninitializedIpfsDefault::new()
+                        .set_default_listener()
+                        .with_default()
+                        .set_connection_limits(conn_limits)
+                        .set_listening_addrs(vec![
+                            format!("/ip4/0.0.0.0/tcp/{ipfs_port}")
+                                .parse()
+                                .expect("valid multiaddr"),
+                            format!("/ip4/0.0.0.0/udp/{ipfs_port}/quic-v1")
+                                .parse()
+                                .expect("valid multiaddr"),
+                        ])
+                        .listen_as_external_addr()
+                        .with_upnp();
+
+                    let ipfs = match builder.start().await {
+                        Ok(ipfs) => ipfs,
+                        Err(e) => {
+                            error!("❌ Failed to initialize IPFS node: {e}");
+                            std::process::exit(1);
+                        }
+                    };
+
+                    if let Err(e) = ipfs.default_bootstrap().await {
+                        error!("❌ Failed to add default bootstrap nodes to IPFS: {e}");
+                        std::process::exit(1);
+                    }
+
+                    if let Err(e) = ipfs.bootstrap().await {
+                        error!("❌ Failed to bootstrap IPFS node: {e}");
+                        std::process::exit(1);
+                    }
+
+                    Console::success("IPFS node initialized and bootstrapped successfully");
+                    Some(ipfs)
+                } else {
+                    None
+                };
+
+                let docker_storage_path = node_config
+                    .compute_specs
+                    .as_ref()
+                    .expect("Hardware check should have populated compute_specs")
+                    .storage_path
+                    .clone();
+                let task_bridge = match TaskBridge::new(
+                    None,
+                    metrics_store,
+                    Some(bridge_contracts),
+                    Some(node_config.clone()),
+                    Some(bridge_wallet),
+                    docker_storage_path.clone(),
+                    state.clone(),
+                    ipfs,
+                ) {
+                    Ok(bridge) => bridge,
+                    Err(e) => {
+                        error!("❌ Failed to create Task Bridge: {e}");
+                        std::process::exit(1);
+                    }
+                };
+
+                let system_memory = node_config
+                    .compute_specs
+                    .as_ref()
+                    .map(|specs| specs.ram_mb.unwrap_or(0));
+
+                let gpu = node_config
+                    .compute_specs
+                    .clone()
+                    .and_then(|specs| specs.gpu.clone());
+                let docker_service = Arc::new(DockerService::new(
+                    cancellation_token.clone(),
+                    gpu,
+                    system_memory,
+                    task_bridge
+                        .get_socket_path()
+                        .to_str()
+                        .expect("path is valid utf-8 string")
+                        .to_string(),
+                    docker_storage_path,
+                    node_wallet_instance
+                        .wallet
+                        .default_signer()
+                        .address()
+                        .to_string(),
+                    *disable_host_network_mode,
+                ));
+
+                let bridge_cancellation_token = cancellation_token.clone();
+                tokio::spawn(async move {
+                    tokio::select! {
+                        _ = bridge_cancellation_token.cancelled() => {}
+                        _ = task_bridge.run() => {}
+                    }
+                });
+                let heartbeat_state = state.clone();
+                let heartbeat_service = HeartbeatService::new(
+                    Duration::from_secs(10),
+                    cancellation_token.clone(),
+                    task_handles.clone(),
+                    node_wallet_instance.clone(),
+                    docker_service.clone(),
+                    heartbeat_metrics_clone.clone(),
+                    heartbeat_state,
+                );
+
+                let gpu_count: u32 = match &node_config.compute_specs {
+                    Some(specs) => specs
+                        .gpu
+                        .as_ref()
+                        .map(|gpu| gpu.count.unwrap_or(0))
+                        .unwrap_or(0),
+                    None => 0,
+                };
+                // 1000 compute units per GPU, with a minimum of one unit for nodes without GPUs
+                let compute_units = U256::from(std::cmp::max(1, gpu_count * 1000));
+
+                Console::section("Syncing with Network");
+
+                // Check if provider exists first
+                let provider_exists = match provider_ops.check_provider_exists().await {
+                    Ok(exists) => exists,
+                    Err(e) => {
+                        error!("❌ Failed to check if provider exists: {e}");
+                        std::process::exit(1);
+                    }
+                };
+
+                let Some(stake_manager) = contracts.stake_manager.as_ref() else {
+                    error!("❌ Stake manager not initialized");
+                    std::process::exit(1);
+                };
+
+                Console::title("Provider Status");
+                let is_whitelisted = match provider_ops.check_provider_whitelisted().await {
+                    Ok(is_whitelisted) => is_whitelisted,
+                    Err(e) => {
+                        error!("Failed to check provider whitelist status: {e}");
+                        std::process::exit(1);
+                    }
+                };
+
+                if provider_exists && is_whitelisted {
+                    Console::success("Provider is registered and whitelisted");
+                } else {
+                    let required_stake = match stake_manager
+                        .calculate_stake(compute_units, U256::from(0))
+                        .await
+                    {
+                        Ok(stake) => stake,
+                        Err(e) => {
+                            error!("❌ Failed to calculate required stake: {e}");
+                            std::process::exit(1);
+                        }
+                    };
+                    Console::info("Required stake", &format_ether(required_stake).to_string());
+
+                    if let Err(e) = provider_ops
+                        .retry_register_provider(
+                            required_stake,
+                            *funding_retry_count,
+                            cancellation_token.clone(),
+                        )
+                        .await
+                    {
+                        error!("❌ Failed to register provider: {e}");
+                        std::process::exit(1);
+                    }
+                }
+
+                let compute_node_exists = match compute_node_ops.check_compute_node_exists().await {
+                    Ok(exists) => exists,
+                    Err(e) => {
+                        error!("❌ Failed to check if compute node exists: {e}");
+                        std::process::exit(1);
+                    }
+                };
+
+                let provider_total_compute = match contracts
+                    .compute_registry
+                    .get_provider_total_compute(
+                        provider_wallet_instance.wallet.default_signer().address(),
+                    )
+                    .await
+                {
+                    Ok(compute) => compute,
+                    Err(e) => {
+                        error!("❌ Failed to get provider total compute: {e}");
+                        std::process::exit(1);
+                    }
+                };
+
+                let provider_stake = stake_manager
+                    .get_stake(provider_wallet_instance.wallet.default_signer().address())
+                    .await
+                    .unwrap_or_default();
+
+                // If we are already registered we do not need additional compute units
+                let compute_units = match compute_node_exists {
+                    true => U256::from(0),
+                    false => compute_units,
+                };
+
+                let required_stake = match stake_manager
+                    .calculate_stake(compute_units, provider_total_compute)
+                    .await
+                {
+                    Ok(stake) => stake,
+                    Err(e) => {
+                        error!("❌ Failed to calculate required stake: {e}");
+                        std::process::exit(1);
+                    }
+                };
+
+                if required_stake > provider_stake {
+                    Console::info(
+                        "Provider stake is less than required stake",
+                        &format!(
+                            "Required: {} tokens, Current: {} tokens",
+                            format_ether(required_stake),
+                            format_ether(provider_stake)
+                        ),
+                    );
+
+                    match provider_ops
+                        .increase_stake(required_stake - provider_stake)
+                        .await
+                    {
+                        Ok(_) => {
+                            Console::success("Successfully increased stake");
+                        }
+                        Err(e) => {
+                            error!("❌ Failed to increase stake: {e}");
+                            std::process::exit(1);
+                        }
+                    }
+                }
+
+                Console::title("Compute Node Status");
+                if compute_node_exists {
+                    // TODO: What if we have two nodes?
+                    Console::success("Compute node is registered");
+                    recover_last_state = true;
+                } else {
+                    match compute_node_ops.add_compute_node(compute_units).await {
+                        Ok(added_node) => {
+                            if added_node {
+                                // If we are adding a new compute node we wait for a proper
+                                // invite and do not recover from previous state
+                                recover_last_state = false;
+                            }
+                        }
+                        Err(e) => {
+                            error!("❌ Failed to add compute node: {e}");
+                            std::process::exit(1);
+                        }
+                    }
+                }
+
+                // Start P2P service
+                Console::title("🔗 Starting P2P Service");
+                let heartbeat = match heartbeat_service.clone() {
+                    Ok(service) => service,
+                    Err(e) => {
+                        error!("❌ Heartbeat service is not available: {e}");
+                        std::process::exit(1);
+                    }
+                };
+
+                let validators = match contracts.prime_network.get_validator_role().await {
+                    Ok(validators) => validators,
+                    Err(e) => {
+                        error!("Failed to get validator role: {e}");
+                        std::process::exit(1);
+                    }
+                };
+
+                if validators.is_empty() {
+                    error!("❌ No validator roles found on contracts - cannot start worker without validators");
+                    error!("This means the smart contract has no registered validators, which is required for signature validation");
+                    error!("Please ensure validators are properly registered on the PrimeNetwork contract before starting the worker");
+                    std::process::exit(1);
+                }
+
+                let mut allowed_addresses = vec![pool_info.creator, pool_info.compute_manager_key];
+                allowed_addresses.extend(validators);
+
+                let validator_addresses = std::collections::HashSet::from_iter(allowed_addresses);
+                let bootnodes: Vec<p2p::Multiaddr> = bootnodes
+                    .split(',')
+                    .filter_map(|addr| match addr.to_string().try_into() {
+                        Ok(multiaddr) => Some(multiaddr),
+                        Err(e) => {
+                            error!("❌ Invalid bootnode address '{addr}': {e}");
+                            None
+                        }
+                    })
+                    .collect();
+                if bootnodes.is_empty() {
+                    error!("❌ No valid bootnodes provided. Please provide at least one valid bootnode address.");
+                    std::process::exit(1);
+                }
+
+                let (p2p_service, kademlia_action_tx) = match crate::p2p::Service::new(
+                    state.get_p2p_keypair().clone(),
+                    *libp2p_port,
+                    bootnodes,
+                    node_wallet_instance.clone(),
+                    validator_addresses,
+                    docker_service.clone(),
+                    heartbeat.clone(),
+                    state.clone(),
+                    contracts.clone(),
+                    provider_wallet_instance.clone(),
+                    cancellation_token.clone(),
+                ) {
+                    Ok(service) => service,
+                    Err(e) => {
+                        error!("❌ Failed to start P2P service: {e}");
+                        std::process::exit(1);
+                    }
+                };
+
+                let peer_id = p2p_service.peer_id();
+                node_config.worker_p2p_id = p2p_service.peer_id().to_string();
+                let external_p2p_address =
+                    format!("/ip4/{}/tcp/{}", node_config.ip_address, *libp2p_port);
+                node_config.worker_p2p_addresses = Some(
+                    p2p_service
+                        .listen_addrs()
+                        .iter()
+                        .map(|addr| addr.to_string())
+                        .chain(vec![external_p2p_address])
+                        .collect(),
+                );
+                tokio::task::spawn(p2p_service.run());
+
+                Console::success(&format!("P2P service started with ID: {peer_id}"));
+
+                // Sleep so that the DHT is bootstrapped before publishing.
+                // TODO: should update p2p service to expose this better (https://github.com/PrimeIntellect-ai/protocol/issues/628)
+                tokio::time::sleep(Duration::from_secs(1)).await;
+
+                let record_key = p2p::worker_dht_key_with_peer_id(&peer_id);
+                let (kad_action, mut result_rx) = KademliaAction::PutRecord {
+                    key: record_key.as_bytes().to_vec(),
+                    value: serde_json::to_vec(&node_config).unwrap(),
+                }
+                .into_kademlia_action_with_channel();
+                if let Err(e) = kademlia_action_tx.send(kad_action).await {
+                    error!("❌ Failed to send Kademlia action: {e}");
+                    std::process::exit(1);
+                }
+
+                while let Some(result) = result_rx.recv().await {
+                    match result {
+                        Ok(res) => {
+                            match res {
+                                p2p::KademliaQueryResult::PutRecord(res) => match res {
+                                    Ok(_) => {
+                                        Console::success("Worker info published to DHT");
+                                    }
+                                    Err(e) => {
+                                        error!("❌ Failed to put record in DHT: {e}");
+                                    }
+                                },
+                                _ => {
+                                    // this case should never happen
+                                    error!(
+                                        "❌ Unexpected result from putting record in DHT: {res:?}"
+                                    );
+                                    std::process::exit(1);
+                                }
+                            }
+                        }
+                        Err(e) => {
+                            error!("❌ Failed to publish worker info to DHT: {e}");
+                            std::process::exit(1);
+                        }
+                    }
+                }
+
+                let (kad_action, mut result_rx) =
+                    KademliaAction::StartProviding(p2p::WORKER_DHT_KEY.as_bytes().to_vec())
+                        .into_kademlia_action_with_channel();
+                if let Err(e) = kademlia_action_tx.send(kad_action).await {
+                    error!("❌ Failed to send Kademlia action: {e}");
+                    std::process::exit(1);
+                }
+
+                while let Some(result) = result_rx.recv().await {
+                    match result {
+                        Ok(res) => {
+                            match res {
+                                p2p::KademliaQueryResult::StartProviding(res) => {
+                                    match res {
+                                        Ok(_) => {
+                                            Console::success(
+                                                "Advertising ourselves as a worker in the DHT",
+                                            );
+                                        }
+                                        Err(e) => {
+                                            error!("❌ Failed to start providing worker info in DHT: {e}");
+                                            std::process::exit(1);
+                                        }
+                                    }
+                                }
+                                _ => {
+                                    // this case should never happen
+                                    error!("❌ Unexpected result from starting providing worker info in DHT: {res:?}");
+                                    std::process::exit(1);
+                                }
+                            }
+                        }
+                        Err(e) => {
+                            error!("❌ Failed to start providing worker info in DHT: {e}");
+                            std::process::exit(1);
+                        }
+                    }
+                }
+
+                Console::section("Starting Worker with Task Bridge");
+
+                // Start monitoring compute node status on chain
+                provider_ops.start_monitoring(provider_ops_cancellation);
+
+                let pool_id = state.get_compute_pool_id();
+                if let Err(err) =
+                    compute_node_ops.start_monitoring(cancellation_token.clone(), pool_id)
+                {
+                    error!("❌ Failed to start node monitoring: {err}");
+                    std::process::exit(1);
+                }
+
+                if recover_last_state {
+                    info!("Recovering from previous state: {recover_last_state}");
+                    heartbeat.activate_heartbeat_if_endpoint_exists().await;
+                }
+
+                // Keep the worker running and listening for P2P connections
+                Console::success("Worker is now running and listening for P2P connections...");
+
+                // Wait for cancellation signal to gracefully shutdown
+                cancellation_token.cancelled().await;
+
+                Console::info(
+                    "Shutdown signal received",
+                    "Gracefully shutting down worker...",
+                );
+
+                Ok(())
+            }
+            Commands::Check {} => {
+                Console::section("🔍 PRIME WORKER SYSTEM CHECK");
+                let issues = Arc::new(RwLock::new(IssueReport::new()));
+
+                // Run checks
+                let mut hardware_checker = HardwareChecker::new(Some(issues.clone()));
+                let software_checker = SoftwareChecker::new(Some(issues.clone()));
+                let node_config = Node {
+                    id: String::new(),
+                    ip_address: String::new(),
+                    port: 0,
+                    compute_specs: None,
+                    provider_address: String::new(),
+                    compute_pool_id: 0,
+                    worker_p2p_id: "empty".to_string(), // TODO: this should be a different type, as peer id is not needed for this code path
+                    worker_p2p_addresses: None,
+                };
+
+                let node_config = match hardware_checker.check_hardware(node_config, None).await {
+                    Ok(node_config) => node_config,
+                    Err(err) => {
+                        Console::user_error(&format!("❌ Hardware check failed: {err}"));
+                        std::process::exit(1);
+                    }
+                };
+
+                if let Err(err) = software_checker.check_software(&node_config).await {
+                    Console::user_error(&format!("❌ Software check failed: {err}"));
+                    std::process::exit(1);
+                }
+
+                let issues = issues.read().await;
+                issues.print_issues();
+
+                if issues.has_critical_issues() {
+                    Console::user_error("❌ Critical issues found. Exiting.");
+                    std::process::exit(1);
+                }
+
+                Ok(())
+            }
+            Commands::GenerateWallets {} => {
+                let provider_signer = PrivateKeySigner::random();
+                let node_signer = PrivateKeySigner::random();
+
+                let provider_key = hex::encode(provider_signer.credential().to_bytes());
+                let node_key = hex::encode(node_signer.credential().to_bytes());
+
+                println!("Provider wallet:");
+                println!(" Address: {}", provider_signer.address());
+                println!(" Private key: {provider_key}");
+                println!("\nNode wallet:");
+                println!(" Address: {}", node_signer.address());
+                println!(" Private key: {node_key}");
+                println!("\nTo set environment variables in your current shell session:");
+                println!("export PRIVATE_KEY_PROVIDER={provider_key}");
+                println!("export PRIVATE_KEY_NODE={node_key}");
+
+                Ok(())
+            }
+
+            Commands::GenerateNodeWallet {} => {
+                let node_signer = PrivateKeySigner::random();
+                let node_key = hex::encode(node_signer.credential().to_bytes());
+
+                println!("Node wallet:");
+                println!(" Address: {}", node_signer.address());
+                println!(" Private key: {node_key}");
+                println!("\nTo set environment variable in your current shell session:");
+                println!("export PRIVATE_KEY_NODE={node_key}");
+
+                Ok(())
+            }
+
+            Commands::Balance {
+                private_key,
+                rpc_url,
+            } => {
+                let private_key = if let Some(key) = private_key {
+                    key.clone()
+                } else {
+                    std::env::var("PRIVATE_KEY_PROVIDER").expect("PRIVATE_KEY_PROVIDER must be set")
+                };
+
+                let provider_wallet =
+                    Wallet::new(&private_key, Url::parse(rpc_url).unwrap()).unwrap();
+
+                let contracts = ContractBuilder::new(provider_wallet.provider())
+                    .with_compute_registry()
+                    .with_ai_token()
+                    .with_prime_network()
+                    .with_compute_pool()
+                    .build()
+                    .unwrap();
+
+                let provider_balance = contracts
+                    .ai_token
+                    .balance_of(provider_wallet.wallet.default_signer().address())
+                    .await
+                    .unwrap();
+
+                let format_balance = format_ether(provider_balance).to_string();
+
+                println!("Provider balance: {format_balance}");
+                Ok(())
+            }
+            Commands::SignMessage {
+                message,
+                private_key_provider,
+                private_key_node,
+            } => {
+                let private_key_provider = if let Some(key) = private_key_provider {
+                    key.clone()
+                } else {
+                    std::env::var("PRIVATE_KEY_PROVIDER").expect("PRIVATE_KEY_PROVIDER must be set")
+                };
+
+                let private_key_node = if let Some(key) = private_key_node {
+                    key.clone()
+                } else {
+                    std::env::var("PRIVATE_KEY_NODE").expect("PRIVATE_KEY_NODE must be set")
+                };
+
+                let provider_wallet = Wallet::new(
+                    &private_key_provider,
+                    Url::parse("http://localhost:8545").unwrap(),
+                )
+                .unwrap();
+                let node_wallet = Wallet::new(
+                    &private_key_node,
+                    Url::parse("http://localhost:8545").unwrap(),
+                )
+                .unwrap();
+
+                let message_hash = provider_wallet.signer.sign_message(message.as_bytes());
+                let node_signature = node_wallet.signer.sign_message(message.as_bytes());
+
+                let provider_signature = message_hash.await?;
+                let node_signature = node_signature.await?;
+                // Concatenate the two 65-byte ECDSA signatures: provider first, then node
+                let combined_signature =
+                    [provider_signature.as_bytes(), node_signature.as_bytes()].concat();
+
+                println!("\nSignature: {}", hex::encode(combined_signature));
+
+                Ok(())
+            }
+            Commands::Deregister {
+                private_key_provider,
+                private_key_node,
+                rpc_url,
+                compute_pool_id,
+            } => {
+                let private_key_provider = if let Some(key) = private_key_provider {
+                    key.clone()
+                } else {
+                    std::env::var("PRIVATE_KEY_PROVIDER").expect("PRIVATE_KEY_PROVIDER must be set")
+                };
+
+                let private_key_node = if let Some(key) = private_key_node {
+                    key.clone()
+                } else {
+                    std::env::var("PRIVATE_KEY_NODE").expect("PRIVATE_KEY_NODE must be set")
+                };
+
+                let provider_wallet_instance =
+                    match Wallet::new(&private_key_provider, Url::parse(rpc_url).unwrap()) {
+                        Ok(wallet) => wallet,
+                        Err(err) => {
+                            Console::user_error(&format!("Failed to create wallet: {err}"));
+                            std::process::exit(1);
+                        }
+                    };
+
+                let node_wallet_instance =
+                    match Wallet::new(&private_key_node, Url::parse(rpc_url).unwrap()) {
+                        Ok(wallet) => wallet,
+                        Err(err) => {
+                            Console::user_error(&format!("❌ Failed to create wallet: {err}"));
+                            std::process::exit(1);
+                        }
+                    };
+
+                /*
+                    Initialize dependencies - services, contracts, operations
+                */
+                let contracts = ContractBuilder::new(provider_wallet_instance.provider())
+                    .with_compute_registry()
+                    .with_ai_token()
+                    .with_prime_network()
+                    .with_compute_pool()
+                    .with_stake_manager()
+                    .build()
+                    .unwrap();
+
+                let provider_address = provider_wallet_instance.wallet.default_signer().address();
+                let node_address = node_wallet_instance.wallet.default_signer().address();
+
+                let provider_ops = ProviderOperations::new(
+                    provider_wallet_instance.clone(),
+                    contracts.clone(),
+                    false,
+                );
+
+                let compute_node_exists = match contracts
+                    .compute_registry
+                    .get_node(provider_address, node_address)
+                    .await
+                {
+                    Ok(_) => true,
+                    Err(e) => {
+                        Console::user_error(&format!(
+                            "❌ Failed to check if compute node exists: {e}"
+                        ));
+                        std::process::exit(1);
+                    }
+                };
+
+                let pool_id = U256::from(*compute_pool_id);
+
+                if compute_node_exists {
+                    match contracts
+                        .compute_pool
+                        .leave_compute_pool(
+                            pool_id,
+                            provider_wallet_instance.wallet.default_signer().address(),
+                            node_wallet_instance.wallet.default_signer().address(),
+                        )
+                        .await
+                    {
+                        Ok(result) => {
+                            Console::success(&format!("Leave compute pool tx: {result:?}"));
+                        }
+                        Err(e) => {
+                            Console::user_error(&format!("❌ Failed to leave compute pool: {e}"));
+                            std::process::exit(1);
+                        }
+                    }
+                    match remove_compute_node(contracts, provider_address, node_address).await {
+                        Ok(_removed_node) => {
+                            Console::success("Compute node removed");
+                            match provider_ops.reclaim_stake(U256::from(0)).await {
+                                Ok(_) => {
+                                    Console::success("Successfully reclaimed stake");
+                                }
+                                Err(e) => {
+                                    Console::user_error(&format!(
+                                        "❌ Failed to reclaim stake: {e}"
+                                    ));
+                                    std::process::exit(1);
+                                }
+                            }
+                        }
+                        Err(e) => {
+                            Console::user_error(&format!("❌ Failed to remove compute node: {e}"));
+                            std::process::exit(1);
+                        }
+                    }
+                } else {
+                    Console::success("Compute node is not registered");
+                }
+
+                Ok(())
+            }
+        }
+    }
+}
+
+async fn remove_compute_node(
+    contracts: Contracts,
+    provider_address: Address,
+    node_address: Address,
+) -> Result<bool, Box<dyn std::error::Error>> {
+    Console::title("🔄 Removing compute node");
+    let remove_node_tx = contracts
+        .prime_network
+        .remove_compute_node(provider_address, node_address)
+        .await?;
+    Console::success(&format!("Remove node tx: {remove_node_tx:?}"));
+    Ok(true)
+}
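One note on the `SignMessage` output above: the printed value is the provider's 65-byte ECDSA signature followed by the node's, hex-encoded. A minimal sketch of how the receiving side could verify it, assuming alloy's `Signature::try_from` and `recover_address_from_msg` APIs (the verifying side is not part of this diff, and `recover_signers` is a hypothetical helper):

use alloy::primitives::{Address, Signature};

/// Splits a combined worker signature and recovers both signer addresses.
/// `combined` is the hex-decoded output of the SignMessage command.
fn recover_signers(message: &str, combined: &[u8]) -> anyhow::Result<(Address, Address)> {
    anyhow::ensure!(combined.len() == 130, "expected two 65-byte signatures");
    let provider_sig = Signature::try_from(&combined[..65])?;
    let node_sig = Signature::try_from(&combined[65..])?;
    // sign_message applies EIP-191 prefixing, so recover with the *_from_msg variant.
    Ok((
        provider_sig.recover_address_from_msg(message.as_bytes())?,
        node_sig.recover_address_from_msg(message.as_bytes())?,
    ))
}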
diff --git a/crates/worker/src/cli/command.rs b/crates/worker/src/cli/command.rs
deleted file mode 100644
index 9215b379..00000000
--- a/crates/worker/src/cli/command.rs
+++ /dev/null
@@ -1,1157 +0,0 @@
[The 1,157 removed lines are omitted here: they are identical to the new crates/worker/src/cli.rs above, apart from the free function `execute_command` being folded into `Cli::run`.]
diff --git a/crates/worker/src/cli/mod.rs b/crates/worker/src/cli/mod.rs
deleted file mode 100644
index e703cd42..00000000
--- a/crates/worker/src/cli/mod.rs
+++ /dev/null
@@ -1,2 +0,0 @@
-pub(crate) mod command;
-pub use command::{execute_command, Cli};
diff --git a/crates/worker/src/lib.rs b/crates/worker/src/lib.rs
index 1752bd8e..99b4bb82 100644
--- a/crates/worker/src/lib.rs
+++ b/crates/worker/src/lib.rs
@@ -8,7 +8,6 @@ mod p2p;
 mod state;
 mod utils;
 
-pub use cli::execute_command;
 pub use cli::Cli;
 pub use utils::logging::setup_logging;
diff --git a/crates/worker/src/main.rs b/crates/worker/src/main.rs
index e8032d5a..7c0b1243 100644
--- a/crates/worker/src/main.rs
+++ b/crates/worker/src/main.rs
@@ -1,16 +1,17 @@
 use clap::Parser;
+use shared::utils::signal::trigger_cancellation_on_signal;
 use std::panic;
 use std::sync::Arc;
-use tokio::signal::unix::{signal, SignalKind};
 use tokio::sync::Mutex;
 use tokio::task::JoinHandle;
 use tokio_util::sync::CancellationToken;
 use worker::TaskHandles;
-use worker::{execute_command, setup_logging, Cli};
+use worker::{setup_logging, Cli};
 
 #[tokio::main]
 async fn main() -> Result<(), Box<dyn std::error::Error>> {
+    // TODO: see if there is a better data structure to handle this
     let task_handles: TaskHandles = Arc::new(Mutex::new(Vec::<JoinHandle<()>>::new()));
 
     let cli = Cli::parse();
@@ -20,6 +21,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     }
 
     // Set up panic hook to log panics
+    // TODO: this could be shared via a util module/crate
     panic::set_hook(Box::new(|panic_info| {
         let location = panic_info
             .location()
@@ -40,37 +42,15 @@
         );
     }));
 
-    let mut sigterm = signal(SignalKind::terminate())?;
-    let mut sigint = signal(SignalKind::interrupt())?;
-    let mut sighup = signal(SignalKind::hangup())?;
-    let mut sigquit = signal(SignalKind::quit())?;
-
     let cancellation_token = CancellationToken::new();
-    let signal_token = cancellation_token.clone();
-    let command_token = cancellation_token.clone();
-    let signal_handle = tokio::spawn(async move {
-        tokio::select! {
-            _ = sigterm.recv() => {
-                log::info!("Received termination signal");
-            }
-            _ = sigint.recv() => {
-                log::info!("Received interrupt signal");
-            }
-            _ = sighup.recv() => {
-                log::info!("Received hangup signal");
-            }
-            _ = sigquit.recv() => {
-                log::info!("Received quit signal");
-            }
-        }
-        signal_token.cancel();
-    });
+    let signal_handle = trigger_cancellation_on_signal(cancellation_token.clone())?;
     task_handles.lock().await.push(signal_handle);
 
     let task_handles_clone = task_handles.clone();
+    let command_token = cancellation_token.clone();
     tokio::select! {
-        cmd_result = execute_command(&cli.command, command_token, task_handles_clone) => {
+        cmd_result = cli.run(command_token, task_handles_clone) => {
             if let Err(e) = cmd_result {
                 log::error!("Command execution error: {e}");
             }
@@ -80,6 +60,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
         }
     }
 
+    // TODO: what happens if lock is held by another task?
     let mut handles = task_handles.lock().await;
     for handle in handles.iter() {
diff --git a/crates/worker/src/utils/logging.rs b/crates/worker/src/utils/logging.rs
index ed50024c..da484bef 100644
--- a/crates/worker/src/utils/logging.rs
+++ b/crates/worker/src/utils/logging.rs
@@ -4,8 +4,7 @@ use tracing_subscriber::fmt;
 use tracing_subscriber::prelude::*;
 use url::Url;
 
-use crate::cli::command::Commands;
-use crate::cli::Cli;
+use crate::cli::{Cli, Commands};
 use anyhow::Result;
 use std::time::{SystemTime, UNIX_EPOCH};
 use time::macros::format_description;
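Note: the `trigger_cancellation_on_signal` helper imported from `shared::utils::signal` is not itself shown in this diff. A minimal sketch of what it presumably looks like, reconstructed from the signal-handling block it replaces (the returned JoinHandle matches how main.rs pushes it onto task_handles):

use tokio::signal::unix::{signal, SignalKind};
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;

/// Spawns a task that cancels `token` when SIGTERM, SIGINT, SIGHUP, or SIGQUIT arrives.
pub fn trigger_cancellation_on_signal(
    token: CancellationToken,
) -> std::io::Result<JoinHandle<()>> {
    let mut sigterm = signal(SignalKind::terminate())?;
    let mut sigint = signal(SignalKind::interrupt())?;
    let mut sighup = signal(SignalKind::hangup())?;
    let mut sigquit = signal(SignalKind::quit())?;
    Ok(tokio::spawn(async move {
        tokio::select! {
            _ = sigterm.recv() => log::info!("Received termination signal"),
            _ = sigint.recv() => log::info!("Received interrupt signal"),
            _ = sighup.recv() => log::info!("Received hangup signal"),
            _ = sigquit.recv() => log::info!("Received quit signal"),
        }
        // Propagate the signal as a cooperative cancellation.
        token.cancel();
    }))
}

Centralizing this in the shared crate removes the duplicated select-over-signals block from both the validator and worker binaries, which is exactly what the removals above show.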