diff --git a/Cargo.lock b/Cargo.lock index 38b542dc..4268a291 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1880,7 +1880,7 @@ dependencies = [ "iroh-quinn-proto", "iroh-quinn-udp", "iroh-relay", - "n0-future", + "n0-future 0.1.3", "n0-snafu", "n0-watcher", "nested_enum_utils", @@ -1925,6 +1925,7 @@ dependencies = [ "ed25519-dalek", "n0-snafu", "nested_enum_utils", + "postcard", "rand_core 0.9.3", "serde", "snafu", @@ -1940,7 +1941,7 @@ dependencies = [ "clap", "comfy-table", "data-encoding", - "derive_more 2.0.1", + "derive_more 1.0.0", "ed25519-dalek", "futures-concurrency", "futures-lite", @@ -1953,8 +1954,10 @@ dependencies = [ "iroh-metrics", "iroh-quinn", "irpc", - "n0-future", + "irpc-iroh", + "n0-future 0.2.0", "n0-snafu", + "n0-watcher", "nested_enum_utils", "postcard", "rand 0.9.2", @@ -1963,6 +1966,7 @@ dependencies = [ "serde", "serde_json", "snafu", + "strum", "testresult", "tokio", "tokio-util", @@ -2087,7 +2091,7 @@ dependencies = [ "iroh-quinn", "iroh-quinn-proto", "lru 0.16.1", - "n0-future", + "n0-future 0.1.3", "n0-snafu", "nested_enum_utils", "num_enum", @@ -2137,7 +2141,7 @@ dependencies = [ "futures-util", "iroh-quinn", "irpc-derive", - "n0-future", + "n0-future 0.1.3", "postcard", "rcgen 0.13.2", "rustls", @@ -2160,6 +2164,24 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "irpc-iroh" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "80e06eb3077e16299f86816ca5f5d795abccf6e5c58340d701ecc4166bdcf2ea" +dependencies = [ + "anyhow", + "getrandom 0.3.3", + "iroh", + "iroh-base", + "irpc", + "n0-future 0.1.3", + "postcard", + "serde", + "tokio", + "tracing", +] + [[package]] name = "is_terminal_polyfill" version = "1.70.1" @@ -2383,6 +2405,27 @@ dependencies = [ "web-time", ] +[[package]] +name = "n0-future" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "89d7dd42bd0114c9daa9c4f2255d692a73bba45767ec32cf62892af6fe5d31f6" +dependencies = [ + "cfg_aliases", + "derive_more 1.0.0", + "futures-buffered", + "futures-lite", + "futures-util", + "js-sys", + "pin-project", + "send_wrapper", + "tokio", + "tokio-util", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-time", +] + [[package]] name = "n0-snafu" version = "0.2.2" @@ -2403,7 +2446,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c31462392a10d5ada4b945e840cbec2d5f3fee752b96c4b33eb41414d8f45c2a" dependencies = [ "derive_more 1.0.0", - "n0-future", + "n0-future 0.1.3", "snafu", ] @@ -2546,7 +2589,7 @@ dependencies = [ "iroh-quinn-udp", "js-sys", "libc", - "n0-future", + "n0-future 0.1.3", "n0-watcher", "nested_enum_utils", "netdev 0.37.3", @@ -4297,6 +4340,7 @@ dependencies = [ "futures-util", "hashbrown 0.15.4", "pin-project-lite", + "slab", "tokio", ] diff --git a/Cargo.toml b/Cargo.toml index 60d13f39..551affb8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,21 +35,12 @@ unused-async = "warn" blake3 = "1.8" bytes = { version = "1.7", features = ["serde"] } data-encoding = "2.6.0" -derive_more = { version = "2.0.1", features = [ - "add", - "debug", - "deref", - "display", - "from", - "try_into", - "into", -] } -ed25519-dalek = { version = "3.0.0-pre.1", features = ["serde", "rand_core"] } +derive_more = { version = "1.0.0", features = ["add", "debug", "deref", "display", "from", "try_into", "into", "deref_mut"] } hex = "0.4.3" indexmap = "2.0" iroh-metrics = { version = "0.36", default-features = false } iroh-base = { version = "0.93", default-features = false, features = ["key"] } 
-n0-future = "0.1.2" +n0-future = "0.2" postcard = { version = "1", default-features = false, features = [ "alloc", "use-std", @@ -64,13 +55,16 @@ futures-concurrency = { version = "7.6.1", optional = true } futures-util = { version = "0.3.30", optional = true } iroh = { version = "0.93", default-features = false, optional = true } tokio = { version = "1", optional = true, features = ["io-util", "sync"] } -tokio-util = { version = "0.7.12", optional = true, features = ["codec"] } +tokio-util = { version = "0.7.12", optional = true, features = ["codec", "time"] } tracing = "0.1" irpc = { version = "0.9.0", optional = true, default-features = false, features = [ "derive", "stream", "spans", ] } +irpc-iroh = { version = "0.9.0", optional = true } +n0-watcher = { version = "0.3.0", optional = true } +strum = { version = "0.27.2", features = ["derive"], optional = true } n0-snafu = { version = "0.2.2", optional = true } nested_enum_utils = { version = "0.2.2", optional = true } snafu = { version = "0.8.5", features = ["rust_1_81"], optional = true } @@ -103,6 +97,7 @@ tokio = { version = "1", features = [ "fs", ] } clap = { version = "4", features = ["derive"] } +ed25519-dalek = { version = "3.0.0-pre.1", features = ["serde", "rand_core"] } humantime-serde = { version = "1.1.1" } iroh = { version = "0.93", default-features = false, features = [ "metrics", @@ -117,16 +112,19 @@ url = "2.4.0" [features] default = ["net", "metrics"] net = [ - "dep:irpc", + "dep:futures-concurrency", "dep:futures-lite", - "dep:iroh", - "dep:tokio", - "dep:tokio-util", "dep:futures-util", - "dep:futures-concurrency", - "dep:nested_enum_utils", + "dep:iroh", + "dep:irpc", + "dep:irpc-iroh", "dep:n0-snafu", + "dep:n0-watcher", + "dep:nested_enum_utils", "dep:snafu", + "dep:strum", + "dep:tokio", + "dep:tokio-util", ] rpc = [ "dep:irpc", diff --git a/src/api.rs b/src/api.rs index ce751b87..f24da9a6 100644 --- a/src/api.rs +++ b/src/api.rs @@ -373,7 +373,7 @@ pub struct Message { } /// Command for a gossip topic. -#[derive(Serialize, Deserialize, derive_more::Debug, Clone)] +#[derive(Serialize, Deserialize, derive_more::Debug, Clone, strum::Display)] pub enum Command { /// Broadcasts a message to all nodes in the swarm. Broadcast(#[debug("Bytes({})", _0.len())] Bytes), @@ -383,6 +383,18 @@ pub enum Command { JoinPeers(Vec), } +impl From for crate::proto::Command { + fn from(value: Command) -> Self { + match value { + Command::Broadcast(bytes) => Self::Broadcast(bytes, crate::proto::Scope::Swarm), + Command::BroadcastNeighbors(bytes) => { + Self::Broadcast(bytes, crate::proto::Scope::Neighbors) + } + Command::JoinPeers(peers) => Self::Join(peers), + } + } +} + /// Options for joining a gossip topic. #[derive(Serialize, Deserialize, Debug)] pub struct JoinOptions { @@ -424,7 +436,7 @@ mod tests { use crate::{ api::{Event, GossipApi}, - net::{test::create_endpoint, Gossip}, + net::{tests::create_endpoint, Gossip}, proto::TopicId, ALPN, }; diff --git a/src/metrics.rs b/src/metrics.rs index 44ccaeae..9f7d4a36 100644 --- a/src/metrics.rs +++ b/src/metrics.rs @@ -1,6 +1,13 @@ //! 
Metrics for iroh-gossip use iroh_metrics::{Counter, MetricsGroup}; +use serde::Serialize; + +use crate::proto::{ + self, + state::MessageKind, + topic::{InEvent, OutEvent}, +}; /// Enum of metrics for the module #[derive(Debug, Default, MetricsGroup)] @@ -26,20 +33,79 @@ pub struct Metrics { pub neighbor_up: Counter, /// Number of times we disconnected from a peer pub neighbor_down: Counter, + /// Number of messages we broadcast to all nodes + pub msgs_broadcast_swarm: Counter, + /// Number of messages we broadcast to direct neighbors + pub msgs_broadcast_neighbors: Counter, + /// Number of topics we joined. + pub topics_joined: Counter, + /// Number of topics we left. + pub topics_quit: Counter, + /// Number of times we successfully dialed a remote node. + pub peers_dialed_success: Counter, + /// Number of times we failed to dial a remote node. + pub peers_dialed_failure: Counter, + /// Number of times we accepted a connection from a remote node. + pub peers_accepted: Counter, /// Number of times the main actor loop ticked pub actor_tick_main: Counter, - /// Number of times the actor ticked for a message received - pub actor_tick_rx: Counter, - /// Number of times the actor ticked for an endpoint event - pub actor_tick_endpoint: Counter, - /// Number of times the actor ticked for a dialer event - pub actor_tick_dialer: Counter, - /// Number of times the actor ticked for a successful dialer event - pub actor_tick_dialer_success: Counter, - /// Number of times the actor ticked for a failed dialer event - pub actor_tick_dialer_failure: Counter, - /// Number of times the actor ticked for an incoming event - pub actor_tick_in_event_rx: Counter, - /// Number of times the actor ticked for a timer event - pub actor_tick_timers: Counter, +} + +impl Metrics { + /// Track an [`InEvent`]. + pub fn track_in_event(&self, in_event: &InEvent) { + match in_event { + InEvent::RecvMessage(_, message) => match message.kind() { + MessageKind::Data => { + self.msgs_data_recv.inc(); + self.msgs_data_recv_size + .inc_by(message.size().unwrap_or(0) as u64); + } + MessageKind::Control => { + self.msgs_ctrl_recv.inc(); + self.msgs_ctrl_recv_size + .inc_by(message.size().unwrap_or(0) as u64); + } + }, + InEvent::Command(cmd) => match cmd { + proto::Command::Broadcast(_, scope) => match scope { + proto::Scope::Swarm => inc(&self.msgs_broadcast_swarm), + proto::Scope::Neighbors => inc(&self.msgs_broadcast_neighbors), + }, + proto::Command::Join(_) => {} + proto::Command::Quit => {} + }, + InEvent::TimerExpired(_) => {} + InEvent::PeerDisconnected(_) => {} + InEvent::UpdatePeerData(_) => {} + } + } + + /// Track an [`OutEvent`]. + pub fn track_out_event(&self, out_event: &OutEvent) { + match out_event { + OutEvent::SendMessage(_to, message) => match message.kind() { + MessageKind::Data => { + self.msgs_data_sent.inc(); + self.msgs_data_sent_size + .inc_by(message.size().unwrap_or(0) as u64); + } + MessageKind::Control => { + self.msgs_ctrl_sent.inc(); + self.msgs_ctrl_sent_size + .inc_by(message.size().unwrap_or(0) as u64); + } + }, + OutEvent::EmitEvent(event) => match event { + proto::Event::NeighborUp(_peer) => inc(&self.neighbor_up), + proto::Event::NeighborDown(_peer) => inc(&self.neighbor_down), + _ => {} + }, + _ => {} + } + } +} + +pub(crate) fn inc(counter: &Counter) { + counter.inc(); } diff --git a/src/net.rs b/src/net.rs index 14902c2d..fedcfa6b 100644 --- a/src/net.rs +++ b/src/net.rs @@ -1,74 +1,72 @@ //!
Networking for the `iroh-gossip` protocol +#[cfg(test)] +use std::sync::atomic::AtomicBool; use std::{ - collections::{hash_map::Entry, BTreeSet, HashMap, HashSet, VecDeque}, - net::SocketAddr, - pin::Pin, + collections::{hash_map, BTreeSet, HashMap, HashSet, VecDeque}, + ops::DerefMut, sync::Arc, - task::{Context, Poll}, + time::Duration, }; use bytes::Bytes; -use futures_concurrency::stream::{stream_group, StreamGroup}; -use futures_util::FutureExt as _; use iroh::{ - endpoint::Connection, + endpoint::{ConnectError, Connection}, protocol::{AcceptError, ProtocolHandler}, - Endpoint, NodeAddr, NodeId, PublicKey, RelayUrl, Watcher, + Endpoint, NodeAddr, NodeId, Watcher, +}; +use irpc::{ + channel::{self, mpsc::RecvError}, + WithChannels, }; -use irpc::WithChannels; use n0_future::{ - task::{self, AbortOnDropHandle, JoinSet}, + stream::Boxed as BoxStream, + task::{self, AbortOnDropHandle}, time::Instant, - Stream, StreamExt as _, + MergeUnbounded, Stream, StreamExt, }; -use nested_enum_utils::common_fields; -use rand::{rngs::StdRng, SeedableRng}; -use serde::{Deserialize, Serialize}; +use n0_watcher::{Direct, Watchable}; +use rand::rngs::StdRng; use snafu::Snafu; -use tokio::sync::{broadcast, mpsc, oneshot}; -use tokio_util::sync::CancellationToken; -use tracing::{debug, error, error_span, trace, warn, Instrument}; +use tokio::{ + sync::{broadcast, mpsc}, + task::JoinSet, +}; +use tracing::{debug, error_span, info, instrument, trace, warn, Instrument}; -use self::util::{RecvLoop, SendLoop, Timers}; +use self::{ + dialer::Dialer, + discovery::GossipDiscovery, + net_proto::GossipMessage, + util::{AddrInfo, ConnectionCounter, Guarded, IrohRemoteConnection, Timers}, +}; use crate::{ - api::{self, Command, Event, GossipApi, RpcMessage}, - metrics::Metrics, - proto::{self, HyparviewConfig, PeerData, PlumtreeConfig, Scope, TopicId}, + api::{self, GossipApi}, + metrics::{inc, Metrics}, + net::util::accept_stream, + proto::{self, Config, HyparviewConfig, PeerData, PlumtreeConfig, TopicId}, }; +mod dialer; +mod discovery; mod util; /// ALPN protocol name pub const GOSSIP_ALPN: &[u8] = b"/iroh-gossip/1"; -/// Channel capacity for the send queue (one per connection) -const SEND_QUEUE_CAP: usize = 64; -/// Channel capacity for the ToActor message queue (single) -const TO_ACTOR_CAP: usize = 64; -/// Channel capacity for the InEvent message queue (single) -const IN_EVENT_CAP: usize = 1024; -/// Channel capacity for broadcast subscriber event queue (one per topic) -const TOPIC_EVENT_CAP: usize = 256; -/// Name used for logging when new node addresses are added from gossip. -const SOURCE_NAME: &str = "gossip"; - -/// Events emitted from the gossip protocol -pub type ProtoEvent = proto::Event; -/// Commands for the gossip protocol -pub type ProtoCommand = proto::Command; - -type InEvent = proto::InEvent; -type OutEvent = proto::OutEvent; -type Timer = proto::Timer; -type ProtoMessage = proto::Message; +type InEvent = proto::topic::InEvent; +type OutEvent = proto::topic::OutEvent; +type Timer = proto::topic::Timer; +type ProtoMessage = proto::topic::Message; +type ProtoEvent = proto::topic::Event; +type State = proto::topic::State; +type Command = proto::topic::Command; /// Publish and subscribe on gossiping topics. /// /// Each topic is a separate broadcast tree with separate memberships. -/// /// A topic has to be joined before you can publish or subscribe on the topic. -/// To join the swarm for a topic, you have to know the [`PublicKey`] of at least one peer that also joined the topic. 
+/// To join the swarm for a topic, you have to know the [`NodeId`] of at least one peer that also joined the topic. /// /// Messages published on the swarm will be delivered to all peers that joined the swarm for that /// topic. You will also be relaying (gossiping) messages published by other peers. @@ -82,48 +80,27 @@ type ProtoMessage = proto::Message; /// /// The gossip actor will, however, initiate new connections to other peers by itself. #[derive(Debug, Clone)] -pub struct Gossip { - pub(crate) inner: Arc, -} +pub struct Gossip(Arc); impl std::ops::Deref for Gossip { type Target = GossipApi; fn deref(&self) -> &Self::Target { - &self.inner.api + &self.0.api } } -#[derive(Debug)] +#[derive(derive_more::Debug)] enum LocalActorMessage { - HandleConnection(Connection), - Shutdown { reply: oneshot::Sender<()> }, -} - -#[allow(missing_docs)] -#[common_fields({ - backtrace: Option, - #[snafu(implicit)] - span_trace: n0_snafu::SpanTrace, -})] -#[derive(Debug, Snafu)] -#[non_exhaustive] -pub enum Error { - ActorDropped {}, -} - -impl From> for Error { - fn from(_value: mpsc::error::SendError) -> Self { - ActorDroppedSnafu.build() - } -} -impl From for Error { - fn from(_value: oneshot::error::RecvError) -> Self { - ActorDroppedSnafu.build() - } + #[debug("HandleConnection({})", _0.fmt_short())] + HandleConnection(NodeId, Connection), + #[debug("Connect({}, {})", _0.fmt_short(), _1.fmt_short())] + Connect(NodeId, TopicId), + #[debug("SetPeerData({}, {})", _0.fmt_short(), _1.as_bytes().len())] + SetPeerData(NodeId, PeerData), } #[derive(Debug)] -pub(crate) struct Inner { +struct Inner { api: GossipApi, local_tx: mpsc::Sender, _actor_handle: AbortOnDropHandle<()>, @@ -133,16 +110,15 @@ pub(crate) struct Inner { impl ProtocolHandler for Gossip { async fn accept(&self, connection: Connection) -> Result<(), AcceptError> { - self.handle_connection(connection) + let remote = connection.remote_node_id()?; + self.handle_connection(remote, connection) .await .map_err(AcceptError::from_err)?; Ok(()) } async fn shutdown(&self) { - if let Err(err) = self.shutdown().await { - warn!("error while shutting down gossip: {err:#}"); - } + // TODO: Graceful shutdown? } } @@ -187,26 +163,7 @@ impl Builder { /// Spawn a gossip actor and get a handle for it pub fn spawn(self, endpoint: Endpoint) -> Gossip { - let metrics = Arc::new(Metrics::default()); - let (actor, rpc_tx, local_tx) = - Actor::new(endpoint, self.config, metrics.clone(), self.alpn); - let me = actor.endpoint.node_id().fmt_short(); - let max_message_size = actor.state.max_message_size(); - - let actor_handle = task::spawn(actor.run().instrument(error_span!("gossip", %me))); - - let api = GossipApi::local(rpc_tx); - - Gossip { - inner: Inner { - api, - local_tx, - _actor_handle: AbortOnDropHandle::new(actor_handle), - max_message_size, - metrics, - } - .into(), - } + Gossip::new(endpoint, self.config, self.alpn) } } @@ -222,973 +179,930 @@ impl Gossip { /// Listen on a quinn endpoint for incoming RPC connections. #[cfg(feature = "rpc")] pub async fn listen(self, endpoint: quinn::Endpoint) { - self.inner.api.listen(endpoint).await + self.0.api.listen(endpoint).await } /// Get the maximum message size configured for this gossip actor. pub fn max_message_size(&self) -> usize { - self.inner.max_message_size + self.0.max_message_size } /// Handle an incoming [`Connection`]. /// /// Make sure to check the ALPN protocol yourself before passing the connection. 
- pub async fn handle_connection(&self, conn: Connection) -> Result<(), Error> { - self.inner + pub async fn handle_connection( + &self, + remote: NodeId, + connection: Connection, + ) -> Result<(), ActorStoppedError> { + self.0 .local_tx - .send(LocalActorMessage::HandleConnection(conn)) - .await?; - Ok(()) - } - - /// Shutdown the gossip instance. - /// - /// This leaves all topics, sending `Disconnect` messages to peers, and then - /// stops the gossip actor loop and drops all state and connections. - pub async fn shutdown(&self) -> Result<(), Error> { - let (reply, reply_rx) = oneshot::channel(); - self.inner - .local_tx - .send(LocalActorMessage::Shutdown { reply }) - .await?; - reply_rx.await?; + .send(LocalActorMessage::HandleConnection(remote, connection)) + .await + .map_err(|_| ActorStoppedSnafu.build())?; Ok(()) } /// Returns the metrics tracked for this gossip instance. pub fn metrics(&self) -> &Arc { - &self.inner.metrics + &self.0.metrics + } + + fn new(endpoint: Endpoint, config: Config, alpn: Option) -> Self { + let metrics = Arc::new(Metrics::default()); + let max_message_size = config.max_message_size; + let me = endpoint.node_id(); + let (api_tx, local_tx, actor) = Actor::new(endpoint, config, alpn, metrics.clone()); + let actor_task = task::spawn( + actor + .run() + .instrument(error_span!("gossip", me=%me.fmt_short())), + ); + + Self(Arc::new(Inner { + local_tx, + max_message_size, + api: GossipApi::local(api_tx), + metrics, + _actor_handle: AbortOnDropHandle::new(actor_task), + })) + } + + #[cfg(test)] + fn new_with_actor(endpoint: Endpoint, config: Config, alpn: Option) -> (Self, Actor) { + let metrics = Arc::new(Metrics::default()); + let max_message_size = config.max_message_size; + let (api_tx, local_tx, actor) = Actor::new(endpoint, config, alpn, metrics.clone()); + let handle = Self(Arc::new(Inner { + local_tx, + max_message_size, + api: GossipApi::local(api_tx), + metrics, + _actor_handle: AbortOnDropHandle::new(task::spawn(std::future::pending())), + })); + (handle, actor) + } +} + +mod net_proto { + use irpc::{channel::mpsc, rpc_requests}; + use serde::{Deserialize, Serialize}; + + use crate::proto::TopicId; + + #[derive(Debug, Serialize, Deserialize, Clone)] + #[non_exhaustive] + pub struct JoinRequest { + pub topic_id: TopicId, + } + + #[rpc_requests(message = GossipMessage)] + #[derive(Debug, Serialize, Deserialize)] + pub enum Request { + #[rpc(tx=mpsc::Sender, rx=mpsc::Receiver)] + Join(JoinRequest), } } -/// Actor that sends and handles messages between the connection and main state loops +/// Error emitted when the gossip actor stopped. 
+#[derive(Debug, Snafu)] +pub struct ActorStoppedError; + +#[derive(strum::Display)] +enum ActorToTopic { + Api(ApiJoinRequest), + Connected { + remote: NodeId, + tx: Guarded>, + rx: Guarded>, + }, + ConnectionFailed(NodeId), +} + +type ApiJoinRequest = WithChannels; +type ApiRecvStream = BoxStream>; +type RemoteRecvStream = BoxStream<(NodeId, Result, RecvError>)>; +type AcceptRemoteRequestsStream = + MergeUnbounded>)>>; + struct Actor { - alpn: Bytes, - /// Protocol state - state: proto::State, - /// The endpoint through which we dial peers + me: NodeId, endpoint: Endpoint, - /// Dial machine to connect to peers - dialer: Dialer, - /// Input messages to the actor - rpc_rx: mpsc::Receiver, + alpn: Bytes, + config: Config, local_rx: mpsc::Receiver, - /// Sender for the state input (cloned into the connection loops) - in_event_tx: mpsc::Sender, - /// Input events to the state (emitted from the connection loops) - in_event_rx: mpsc::Receiver, - /// Queued timers - timers: Timers, - /// Map of topics to their state. - topics: HashMap, - /// Map of peers to their state. - peers: HashMap, - /// Stream of commands from topic handles. - command_rx: stream_group::Keyed, - /// Internal queue of topic to close because all handles were dropped. - quit_queue: VecDeque, - /// Tasks for the connection loops, to keep track of panics. - connection_tasks: JoinSet<(NodeId, Connection, Result<(), ConnectionLoopError>)>, + local_tx: mpsc::Sender, + api_rx: mpsc::Receiver, + topics: HashMap, + pending_remotes_with_topics: HashMap>, + topic_tasks: JoinSet, + remotes: HashMap, + close_connections: JoinSet<(NodeId, Connection)>, + dialer: Dialer, + our_peer_data: n0_watcher::Watchable, metrics: Arc, - topic_event_forwarders: JoinSet, + node_addr_updates: BoxStream, + accepting: AcceptRemoteRequestsStream, + discovery: GossipDiscovery, } impl Actor { fn new( endpoint: Endpoint, - config: proto::Config, - metrics: Arc, + config: Config, alpn: Option, + metrics: Arc, ) -> ( - Self, - mpsc::Sender, + mpsc::Sender, mpsc::Sender, + Self, ) { - let peer_id = endpoint.node_id(); - let dialer = Dialer::new(endpoint.clone()); - let state = proto::State::new( - peer_id, - Default::default(), - config, - rand::rngs::StdRng::from_rng(&mut rand::rng()), - ); - let (rpc_tx, rpc_rx) = mpsc::channel(TO_ACTOR_CAP); - let (local_tx, local_rx) = mpsc::channel(16); - let (in_event_tx, in_event_rx) = mpsc::channel(IN_EVENT_CAP); - - let actor = Actor { - alpn: alpn.unwrap_or_else(|| GOSSIP_ALPN.to_vec().into()), - endpoint, - state, - dialer, - rpc_rx, - in_event_rx, - in_event_tx, - timers: Timers::new(), - command_rx: StreamGroup::new().keyed(), - peers: Default::default(), - topics: Default::default(), - quit_queue: Default::default(), - connection_tasks: Default::default(), - metrics, - local_rx, - topic_event_forwarders: Default::default(), - }; - - (actor, rpc_tx, local_tx) + let (api_tx, api_rx) = tokio::sync::mpsc::channel(16); + let (local_tx, local_rx) = tokio::sync::mpsc::channel(16); + + let me = endpoint.node_id(); + let node_addr_updates = endpoint.watch_node_addr().stream(); + let discovery = GossipDiscovery::default(); + endpoint.discovery().add(discovery.clone()); + let initial_peer_data = AddrInfo::from(endpoint.node_addr()).encode(); + // let peer_data = endpoint + // .watch_node_addr() + // .map(|addr| AddrInfo::from(addr).encode()) + // .unwrap(); + ( + api_tx, + local_tx.clone(), + Actor { + endpoint, + me, + config, + api_rx, + local_tx, + local_rx, + node_addr_updates: Box::pin(node_addr_updates), + dialer: 
Dialer::default(), + our_peer_data: Watchable::new(initial_peer_data), + alpn: alpn.unwrap_or_else(|| crate::ALPN.to_vec().into()), + metrics: metrics.clone(), + topics: Default::default(), + pending_remotes_with_topics: Default::default(), + remotes: Default::default(), + close_connections: JoinSet::new(), + topic_tasks: JoinSet::new(), + accepting: Default::default(), + discovery, + }, + ) } - pub async fn run(mut self) { - let mut addr_update_stream = self.setup().await; + async fn run(mut self) { + while self.tick().await {} + } - let mut i = 0; - while self.event_loop(&mut addr_update_stream, i).await { - i += 1; - } + #[cfg(test)] + #[instrument("gossip", skip_all, fields(me=%self.me.fmt_short()))] + pub(crate) async fn finish(self) { + self.run().await } - /// Performs the initial actor setup to run the [`Actor::event_loop`]. - /// - /// This updates our current address and return it. It also returns the home relay stream and - /// direct addr stream. - async fn setup(&mut self) -> impl Stream + Send + Unpin + use<> { - let addr_update_stream = self.endpoint.watch_node_addr().stream(); - // TODO(Frando): Fail if endpoint disconnected? - let initial_addr = self.endpoint.node_addr(); - self.handle_addr_update(initial_addr).await; - addr_update_stream + #[cfg(test)] + #[instrument("gossip", skip_all, fields(me=%self.me.fmt_short()))] + pub(crate) async fn steps(&mut self, n: usize) -> Result<(), ActorStoppedError> { + for _ in 0..n { + if !self.tick().await { + return Err(ActorStoppedError); + } + } + Ok(()) } - /// One event loop processing step. - /// - /// None is returned when no further processing should be performed. - async fn event_loop( - &mut self, - addr_updates: &mut (impl Stream + Send + Unpin), - i: usize, - ) -> bool { + async fn tick(&mut self) -> bool { + trace!("wait for tick"); self.metrics.actor_tick_main.inc(); tokio::select! 
{ - biased; - conn = self.local_rx.recv() => { - match conn { - Some(LocalActorMessage::Shutdown { reply }) => { - debug!("received shutdown message, quit all topics"); - self.quit_queue.extend(self.topics.keys().copied()); - self.process_quit_queue().await; - debug!("all topics quit, stop gossip actor"); - reply.send(()).ok(); - return false; - }, - Some(LocalActorMessage::HandleConnection(conn)) => { - if let Ok(remote_node_id) = conn.remote_node_id() { - self.handle_connection(remote_node_id, ConnOrigin::Accept, conn); - } - } + addr = self.node_addr_updates.next() => { + trace!("tick: node_addr_update"); + match addr { None => { - debug!("all gossip handles dropped, stop gossip actor"); - return false; + warn!("address stream returned None - endpoint has shut down"); + false + } + Some(addr) => { + let data = AddrInfo::from(addr).encode(); + self.our_peer_data.set(data).ok(); + true } } } - msg = self.rpc_rx.recv() => { - trace!(?i, "tick: to_actor_rx"); - self.metrics.actor_tick_rx.inc(); + Some(msg) = self.local_rx.recv() => { + trace!("tick: local_rx {msg:?}"); match msg { - Some(msg) => { - self.handle_rpc_msg(msg, Instant::now()).await; + LocalActorMessage::HandleConnection(node_id, connection) => { + self.handle_remote_connection(node_id, Ok(connection), Direction::Accept).await; } - None => { - debug!("all gossip handles dropped, stop gossip actor"); - return false; + LocalActorMessage::Connect(node_id, topic_id) => { + self.connect(node_id, topic_id); + } + LocalActorMessage::SetPeerData(node_id, data) => { + match AddrInfo::decode(&data) { + Err(err) => warn!(remote=%node_id.fmt_short(), ?err, len=data.inner().len(), "Failed to decode peer data"), + Ok(info) => { + debug!(peer = ?node_id, "add known addrs: {info:?}"); + let node_addr = info.into_node_addr(node_id); + self.discovery.add(node_addr); + } + } } } - }, - Some((key, (topic, command))) = self.command_rx.next(), if !self.command_rx.is_empty() => { - trace!(?i, "tick: command_rx"); - self.handle_command(topic, key, command).await; - }, - Some(new_address) = addr_updates.next() => { - trace!(?i, "tick: new_address"); - self.metrics.actor_tick_endpoint.inc(); - self.handle_addr_update(new_address).await; + true } - (peer_id, res) = self.dialer.next_conn() => { - trace!(?i, "tick: dialer"); - self.metrics.actor_tick_dialer.inc(); + Some((node_id, res)) = self.dialer.next(), if !self.dialer.is_empty() => { + trace!(remote=%node_id.fmt_short(), ok=res.is_ok(), "tick: dialed"); + self.handle_remote_connection(node_id, res, Direction::Dial).await; + true + } + Some((node_id, res)) = self.accepting.next(), if !self.accepting.is_empty() => { + trace!(remote=%node_id.fmt_short(), res=?res.as_ref().map(|_| ()), "tick: accepting"); match res { - Some(Ok(conn)) => { - debug!(peer = %peer_id.fmt_short(), "dial successful"); - self.metrics.actor_tick_dialer_success.inc(); - self.handle_connection(peer_id, ConnOrigin::Dial, conn); + Ok(request) => self.handle_remote_message(node_id, request).await, + Err(reason) => { + debug!(remote=%node_id.fmt_short(), ?reason, "accept loop for remote closed"); } - Some(Err(err)) => { - warn!(peer = %peer_id.fmt_short(), "dial failed: {err}"); - self.metrics.actor_tick_dialer_failure.inc(); - let peer_state = self.peers.get(&peer_id); - let is_active = matches!(peer_state, Some(PeerState::Active { .. 
})); - if !is_active { - self.handle_in_event(InEvent::PeerDisconnected(peer_id), Instant::now()) - .await; - } + } + true + } + msg = self.api_rx.recv() => { + trace!(some=msg.is_some(), "tick: api_rx"); + match msg { + Some(msg) => { + self.handle_api_message(msg).await; + true } None => { - warn!(peer = %peer_id.fmt_short(), "dial disconnected"); - self.metrics.actor_tick_dialer_failure.inc(); + trace!("all api senders dropped, stop actor"); + false } } } - event = self.in_event_rx.recv() => { - trace!(?i, "tick: in_event_rx"); - self.metrics.actor_tick_in_event_rx.inc(); - let event = event.expect("unreachable: in_event_tx is never dropped before receiver"); - self.handle_in_event(event, Instant::now()).await; - } - _ = self.timers.wait_next() => { - trace!(?i, "tick: timers"); - self.metrics.actor_tick_timers.inc(); - let now = Instant::now(); - while let Some((_instant, timer)) = self.timers.pop_before(now) { - self.handle_in_event(InEvent::TimerExpired(timer), now).await; - } - } - Some(res) = self.connection_tasks.join_next(), if !self.connection_tasks.is_empty() => { - trace!(?i, "tick: connection_tasks"); - let (peer_id, conn, result) = res.expect("connection task panicked"); - self.handle_connection_task_finished(peer_id, conn, result).await; - } - Some(res) = self.topic_event_forwarders.join_next(), if !self.topic_event_forwarders.is_empty() => { - let topic_id = res.expect("topic event forwarder panicked"); - if let Some(state) = self.topics.get_mut(&topic_id) { - if !state.still_needed() { - self.quit_queue.push_back(topic_id); - self.process_quit_queue().await; + Some(res) = self.close_connections.join_next(), if !self.close_connections.is_empty() => { + let (node_id, connection) = res.expect("connection task panicked"); + trace!(remote=%node_id.fmt_short(), "tick: connection closed"); + if let Some(state) = self.remotes.get(&node_id) { + if state.same_connection(&connection) { + self.remotes.remove(&node_id); } } + true + } + Some(actor) = self.topic_tasks.join_next(), if !self.topic_tasks.is_empty() => { + let actor = actor.expect("topic actor task panicked"); + trace!(topic=%actor.topic_id.fmt_short(), "tick: topic actor finished"); + self.topics.remove(&actor.topic_id); + true } + else => unreachable!("reached else arm, but all fallible cases should be handled"), } - - true } - async fn handle_addr_update(&mut self, node_addr: NodeAddr) { - // let peer_data = our_peer_data(&self.endpoint, current_addresses); - let peer_data = encode_peer_data(&node_addr.into()); - self.handle_in_event(InEvent::UpdatePeerData(peer_data), Instant::now()) - .await + #[cfg(test)] + fn endpoint(&self) -> &Endpoint { + &self.endpoint } - async fn handle_command( + fn drain_pending_dials( &mut self, - topic: TopicId, - key: stream_group::Key, - command: Option, - ) { - debug!(?topic, ?key, ?command, "handle command"); - let Some(state) = self.topics.get_mut(&topic) else { - // TODO: unreachable? 
- warn!("received command for unknown topic"); - return; - }; - match command { - Some(command) => { - let command = match command { - Command::Broadcast(message) => ProtoCommand::Broadcast(message, Scope::Swarm), - Command::BroadcastNeighbors(message) => { - ProtoCommand::Broadcast(message, Scope::Neighbors) - } - Command::JoinPeers(peers) => ProtoCommand::Join(peers), - }; - self.handle_in_event(proto::InEvent::Command(topic, command), Instant::now()) - .await; - } - None => { - state.command_rx_keys.remove(&key); - if !state.still_needed() { - self.quit_queue.push_back(topic); - self.process_quit_queue().await; - } - } - } + remote: &NodeId, + ) -> impl Iterator { + self.pending_remotes_with_topics + .remove(remote) + .into_iter() + .flatten() + .flat_map(|topic_id| self.topics.get(&topic_id).map(|handle| (topic_id, handle))) } - fn handle_connection(&mut self, peer_id: NodeId, origin: ConnOrigin, conn: Connection) { - let (send_tx, send_rx) = mpsc::channel(SEND_QUEUE_CAP); - let conn_id = conn.stable_id(); - - let queue = match self.peers.entry(peer_id) { - Entry::Occupied(mut entry) => entry.get_mut().accept_conn(send_tx, conn_id), - Entry::Vacant(entry) => { - entry.insert(PeerState::Active { - active_send_tx: send_tx, - active_conn_id: conn_id, - other_conns: Vec::new(), - }); - Vec::new() - } + fn connect(&mut self, remote: NodeId, topic_id: TopicId) { + let Some(handle) = self.topics.get(&topic_id) else { + return; }; - - let max_message_size = self.state.max_message_size(); - let in_event_tx = self.in_event_tx.clone(); - - // Spawn a task for this connection - self.connection_tasks.spawn( - async move { - let res = connection_loop( - peer_id, - conn.clone(), - origin, - send_rx, - in_event_tx, - max_message_size, - queue, - ) - .await; - (peer_id, conn, res) - } - .instrument(error_span!("conn", peer = %peer_id.fmt_short())), - ); + if let Some(state) = self.remotes.get(&remote) { + let tx = handle.tx.clone(); + let state = state.clone(); + // TODO: Track task? + task::spawn(async move { + let msg = state.open_topic(topic_id).await; + tx.send(msg).await.ok(); + }); + } else { + self.dialer + .queue_dial(&self.endpoint, remote, self.alpn.clone()); + self.pending_remotes_with_topics + .entry(remote) + .or_default() + .insert(topic_id); + } } - #[tracing::instrument(name = "conn", skip_all, fields(peer = %peer_id.fmt_short()))] - async fn handle_connection_task_finished( + #[instrument("connection", skip_all, fields(remote=%remote.fmt_short()))] + async fn handle_remote_connection( &mut self, - peer_id: NodeId, - conn: Connection, - task_result: Result<(), ConnectionLoopError>, + remote: NodeId, + res: Result, + direction: Direction, ) { - if conn.close_reason().is_none() { - conn.close(0u32.into(), b"close from disconnect"); + match (res.as_ref(), direction) { + (Ok(_), Direction::Dial) => inc(&self.metrics.peers_dialed_success), + (Err(_), Direction::Dial) => inc(&self.metrics.peers_dialed_failure), + (Ok(_), Direction::Accept) => inc(&self.metrics.peers_accepted), + (Err(_), Direction::Accept) => {} } - let reason = conn.close_reason().expect("just closed"); - let error = task_result.err(); - debug!(%reason, ?error, "connection closed"); - if let Some(PeerState::Active { - active_conn_id, - other_conns, - .. 
- }) = self.peers.get_mut(&peer_id) - { - if conn.stable_id() == *active_conn_id { - debug!("active send connection closed, mark peer as disconnected"); - self.handle_in_event(InEvent::PeerDisconnected(peer_id), Instant::now()) - .await; - } else { - other_conns.retain(|x| *x != conn.stable_id()); - debug!("remaining {} other connections", other_conns.len() + 1); + let connection = match res { + Err(err) => { + debug!(?err, "Connection failed"); + for (_, handle) in self.drain_pending_dials(&remote) { + handle + .send(ActorToTopic::ConnectionFailed(remote)) + .await + .ok(); + } + return; } - } else { - debug!("peer already marked as disconnected"); - } - } + Ok(connection) => connection, + }; - async fn handle_rpc_msg(&mut self, msg: RpcMessage, now: Instant) { - trace!("handle to_actor {msg:?}"); - match msg { - RpcMessage::Join(msg) => { - let WithChannels { - inner, - rx, - tx, - // TODO(frando): make use of span? - span: _, - } = msg; - let api::JoinRequest { - topic_id, - bootstrap, - } = inner; - let TopicState { - neighbors, - event_sender, - command_rx_keys, - } = self.topics.entry(topic_id).or_default(); - let mut sender_dead = false; - if !neighbors.is_empty() { - for neighbor in neighbors.iter() { - if let Err(_err) = tx.try_send(Event::NeighborUp(*neighbor)).await { - sender_dead = true; - break; - } - } + let state = RemoteState::new(remote, connection.clone(), direction); + + // Open requests for pending topics. + for (topic_id, handle) in self.drain_pending_dials(&remote) { + let tx = handle.tx.clone(); + let state = state.clone(); + task::spawn( + async move { + let msg = state.open_topic(topic_id).await; + tx.send(msg).await.ok(); } + .instrument(tracing::Span::current()), + ); + } - if !sender_dead { - let fut = - topic_subscriber_loop(tx, event_sender.subscribe()).map(move |_| topic_id); - self.topic_event_forwarders - .spawn(fut.instrument(tracing::Span::current())); + // Read incoming requests. + let counter = state.counter.clone(); + self.accepting.push(Box::pin( + accept_stream::(connection.clone()) + .map(move |req| (remote, req.map(|r| counter.guard(r)))), + )); + + // Close on idle (if dialed) or await close (if accepted). 
+ let counter = state.counter.clone(); + let fut = async move { + match direction { + Direction::Dial => { + counter.idle_for(Duration::from_millis(500)).await; + info!("close connection (from dial): unused"); + connection.close(1u32.into(), b"idle"); + } + Direction::Accept => { + let reason = connection.closed().await; + info!(?reason, "connection closed (from accept)") } - let command_rx = TopicCommandStream::new(topic_id, Box::pin(rx.into_stream())); - let key = self.command_rx.insert(command_rx); - command_rx_keys.insert(key); - - self.handle_in_event( - InEvent::Command( - topic_id, - ProtoCommand::Join(bootstrap.into_iter().collect()), - ), - now, - ) - .await; } - } - } + (remote, connection) + }; + self.close_connections + .spawn(fut.instrument(error_span!("conn", remote=%remote.fmt_short()))); - async fn handle_in_event(&mut self, event: InEvent, now: Instant) { - self.handle_in_event_inner(event, now).await; - self.process_quit_queue().await; + self.remotes.insert(remote, state); } - async fn process_quit_queue(&mut self) { - while let Some(topic_id) = self.quit_queue.pop_front() { - self.handle_in_event_inner( - InEvent::Command(topic_id, ProtoCommand::Quit), - Instant::now(), - ) - .await; - if self.topics.remove(&topic_id).is_some() { - tracing::debug!(%topic_id, "publishers and subscribers gone; unsubscribing"); + #[instrument("request", skip_all, fields(remote=%remote.fmt_short()))] + async fn handle_remote_message(&mut self, remote: NodeId, request: Guarded) { + let (request, guard) = request.split(); + let (topic_id, request) = match request { + GossipMessage::Join(req) => (req.inner.topic_id, req), + }; + if let Some(topic) = self.topics.get(&topic_id) { + if let Err(_err) = topic + .send(ActorToTopic::Connected { + remote, + tx: Guarded::new(request.tx, guard.clone()), + rx: Guarded::new(request.rx, guard.clone()), + }) + .await + { + warn!(topic=%topic_id.fmt_short(), "Topic actor dead"); } + } else { + debug!(topic=%topic_id.fmt_short(), "ignore request: unknown topic"); } } - async fn handle_in_event_inner(&mut self, event: InEvent, now: Instant) { - if matches!(event, InEvent::TimerExpired(_)) { - trace!(?event, "handle in_event"); - } else { - debug!(?event, "handle in_event"); + async fn handle_api_message(&mut self, msg: api::RpcMessage) { + let (topic_id, msg) = match msg { + api::RpcMessage::Join(msg) => (msg.inner.topic_id, msg), }; - let out = self.state.handle(event, now, Some(&self.metrics)); - for event in out { - if matches!(event, OutEvent::ScheduleTimer(_, _)) { - trace!(?event, "handle out_event"); - } else { - debug!(?event, "handle out_event"); - }; - match event { - OutEvent::SendMessage(peer_id, message) => { - let state = self.peers.entry(peer_id).or_default(); - match state { - PeerState::Active { active_send_tx, .. } => { - if let Err(_err) = active_send_tx.send(message).await { - // Removing the peer is handled by the in_event PeerDisconnected sent - // in [`Self::handle_connection_task_finished`]. - warn!( - peer = %peer_id.fmt_short(), - "failed to send: connection task send loop terminated", - ); - } - } - PeerState::Pending { queue } => { - if queue.is_empty() { - debug!(peer = %peer_id.fmt_short(), "start to dial"); - self.dialer.queue_dial(peer_id, self.alpn.clone()); - } - queue.push(message); - } - } - } - OutEvent::EmitEvent(topic_id, event) => { - let Some(state) = self.topics.get_mut(&topic_id) else { - // TODO: unreachable? 
- warn!(?topic_id, "gossip state emitted event for unknown topic"); - continue; - }; - let TopicState { - neighbors, - event_sender, - .. - } = state; - match &event { - ProtoEvent::NeighborUp(neighbor) => { - neighbors.insert(*neighbor); - } - ProtoEvent::NeighborDown(neighbor) => { - neighbors.remove(neighbor); - } - _ => {} - } - event_sender.send(event).ok(); - if !state.still_needed() { - self.quit_queue.push_back(topic_id); - } - } - OutEvent::ScheduleTimer(delay, timer) => { - self.timers.insert(now + delay, timer); - } - OutEvent::DisconnectPeer(peer_id) => { - // signal disconnection by dropping the senders to the connection - debug!(peer=%peer_id.fmt_short(), "gossip state indicates disconnect: drop peer"); - self.peers.remove(&peer_id); - } - OutEvent::PeerData(node_id, data) => match decode_peer_data(&data) { - Err(err) => warn!("Failed to decode {data:?} from {node_id}: {err}"), - Ok(info) => { - debug!(peer = ?node_id, "add known addrs: {info:?}"); - let node_addr = NodeAddr { - node_id, - relay_url: info.relay_url, - direct_addresses: info.direct_addresses, - }; - if let Err(err) = self - .endpoint - .add_node_addr_with_source(node_addr, SOURCE_NAME) - { - debug!(peer = ?node_id, "add known failed: {err:?}"); - } - } - }, - } + let topic = self.topics.entry(topic_id).or_insert_with(|| { + let (handle, actor) = TopicHandle::new( + self.me, + topic_id, + self.config.clone(), + self.local_tx.clone(), + self.our_peer_data.watch(), + self.metrics.clone(), + ); + self.topic_tasks.spawn( + actor + .run() + .instrument(error_span!("topic", topic=%topic_id.fmt_short())), + ); + handle + }); + if topic.send(ActorToTopic::Api(msg)).await.is_err() { + warn!(topic=%topic_id.fmt_short(), "Topic actor dead"); } } } -type ConnId = usize; - -#[derive(Debug)] -enum PeerState { - Pending { - queue: Vec, - }, - Active { - active_send_tx: mpsc::Sender, - active_conn_id: ConnId, - other_conns: Vec, - }, +#[derive(Clone)] +struct RemoteState { + node_id: NodeId, + conn_id: usize, + client: irpc::Client, + #[allow(dead_code)] + direction: Direction, + counter: ConnectionCounter, } -impl PeerState { - fn accept_conn( - &mut self, - send_tx: mpsc::Sender, - conn_id: ConnId, - ) -> Vec { - match self { - PeerState::Pending { queue } => { - let queue = std::mem::take(queue); - *self = PeerState::Active { - active_send_tx: send_tx, - active_conn_id: conn_id, - other_conns: Vec::new(), - }; - queue - } - PeerState::Active { - active_send_tx, - active_conn_id, - other_conns, - } => { - // We already have an active connection. We keep the old connection intact, - // but only use the new connection for sending from now on. - // By dropping the `send_tx` of the old connection, the send loop part of - // the `connection_loop` of the old connection will terminate, which will also - // notify the peer that the old connection may be dropped. 
- other_conns.push(*active_conn_id); - *active_send_tx = send_tx; - *active_conn_id = conn_id; - Vec::new() +impl RemoteState { + fn new(node_id: NodeId, connection: Connection, direction: Direction) -> Self { + let conn_id = connection.stable_id(); + let irpc_conn = IrohRemoteConnection::new(connection); + let client = irpc::Client::boxed(irpc_conn); + let counter = ConnectionCounter::new(); + RemoteState { + client, + direction, + conn_id, + counter, + node_id, + } + } + + fn same_connection(&self, conn: &Connection) -> bool { + self.conn_id == conn.stable_id() + } + + async fn open_topic(&self, topic_id: TopicId) -> ActorToTopic { + let guard = self.counter.get_one(); + let req = net_proto::JoinRequest { topic_id }; + match self.client.bidi_streaming(req.clone(), 64, 64).await { + Ok((tx, rx)) => ActorToTopic::Connected { + remote: self.node_id, + tx: Guarded::new(tx, guard.clone()), + rx: Guarded::new(rx, guard), + }, + Err(err) => { + warn!(?topic_id, ?err, "failed to open stream with remote"); + ActorToTopic::ConnectionFailed(self.node_id) } } } } -impl Default for PeerState { - fn default() -> Self { - PeerState::Pending { queue: Vec::new() } - } +#[derive(Debug, Copy, Clone)] +enum Direction { + Dial, + Accept, } -#[derive(Debug)] -struct TopicState { - neighbors: BTreeSet, - event_sender: broadcast::Sender, - /// Keys identifying command receivers in [`Actor::command_rx`]. - /// - /// This represents the receiver side of gossip's publish public API. - command_rx_keys: HashSet, +struct TopicHandle { + tx: mpsc::Sender, + #[cfg(test)] + joined: Arc, } -impl Default for TopicState { - fn default() -> Self { - let (event_sender, _) = broadcast::channel(TOPIC_EVENT_CAP); - Self { +impl TopicHandle { + fn new( + me: NodeId, + topic_id: TopicId, + config: proto::Config, + to_actor_tx: mpsc::Sender, + peer_data: Direct, + metrics: Arc, + ) -> (Self, TopicActor) { + let (tx, rx) = mpsc::channel(16); + // TODO: peer_data + let state = State::new(me, None, config); + #[cfg(test)] + let joined = Arc::new(AtomicBool::new(false)); + let (forward_event_tx, _) = broadcast::channel(512); + let actor = TopicActor { + topic_id, + state, + actor_rx: rx, + to_actor_tx, + peer_data, + forward_event_tx, + metrics, + init: false, + #[cfg(test)] + joined: joined.clone(), + timers: Default::default(), neighbors: Default::default(), - command_rx_keys: Default::default(), - event_sender, - } + out_events: Default::default(), + api_receivers: Default::default(), + remote_senders: Default::default(), + remote_receivers: Default::default(), + drop_peers_queue: Default::default(), + forward_event_tasks: Default::default(), + }; + let handle = Self { + tx, + #[cfg(test)] + joined, + }; + (handle, actor) } -} -impl TopicState { - /// Check if the topic still has any publisher or subscriber. 
- fn still_needed(&self) -> bool { - !self.command_rx_keys.is_empty() && self.event_sender.receiver_count() > 0 + async fn send(&self, msg: ActorToTopic) -> Result<(), mpsc::error::SendError> { + self.tx.send(msg).await } #[cfg(test)] fn joined(&self) -> bool { - !self.neighbors.is_empty() + self.joined.load(std::sync::atomic::Ordering::Relaxed) } } -/// Whether a connection is initiated by us (Dial) or by the remote peer (Accept) -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -enum ConnOrigin { - Accept, - Dial, -} - -#[allow(missing_docs)] -#[common_fields({ - backtrace: Option, -})] -#[derive(Debug, Snafu)] -#[snafu(module)] -#[non_exhaustive] -enum ConnectionLoopError { - #[snafu(transparent)] - Write { - source: self::util::WriteError, - }, - #[snafu(transparent)] - Read { - source: self::util::ReadError, - }, - #[snafu(transparent)] - Connection { - source: iroh::endpoint::ConnectionError, - }, - ActorDropped {}, -} - -impl From> for ConnectionLoopError { - fn from(_value: mpsc::error::SendError) -> Self { - self::connection_loop_error::ActorDroppedSnafu.build() - } +struct TopicActor { + topic_id: TopicId, + to_actor_tx: mpsc::Sender, + state: State, + actor_rx: mpsc::Receiver, + timers: Timers, + neighbors: BTreeSet, + peer_data: Direct, + out_events: VecDeque, + api_receivers: MergeUnbounded, + remote_senders: HashMap, + remote_receivers: MergeUnbounded, + forward_event_tx: broadcast::Sender, + forward_event_tasks: JoinSet<()>, + #[cfg(test)] + joined: Arc, + init: bool, + drop_peers_queue: HashSet, + metrics: Arc, } -async fn connection_loop( - from: PublicKey, - conn: Connection, - origin: ConnOrigin, - send_rx: mpsc::Receiver, - in_event_tx: mpsc::Sender, - max_message_size: usize, - queue: Vec, -) -> Result<(), ConnectionLoopError> { - debug!(?origin, "connection established"); - - let mut send_loop = SendLoop::new(conn.clone(), send_rx, max_message_size); - let mut recv_loop = RecvLoop::new(from, conn, in_event_tx, max_message_size); - - let send_fut = send_loop.run(queue).instrument(error_span!("send")); - let recv_fut = recv_loop.run().instrument(error_span!("recv")); +impl TopicActor { + pub async fn run(mut self) -> Self { + self.metrics.topics_joined.inc(); + let peer_data = self.peer_data.clone().stream(); + tokio::pin!(peer_data); + loop { + tokio::select! 
{ + Some(msg) = self.actor_rx.recv() => { + trace!("tick: actor_rx {msg}"); + self.handle_actor_message(msg).await; + }, + Some(cmd) = self.api_receivers.next(), if !self.api_receivers.is_empty() => { + self.handle_api_command(cmd).await; + } + Some((remote, message)) = self.remote_receivers.next(), if !self.remote_receivers.is_empty() => { + trace!(remote=%remote.fmt_short(), msg=?message, "tick: remote_rx"); + self.handle_remote_message(remote, message).await; + } + Some(data) = peer_data.next() => { + self.handle_in_event(InEvent::UpdatePeerData(data)).await; + } + _ = self.timers.wait_next() => { + let now = Instant::now(); + while let Some((_instant, timer)) = self.timers.pop_before(now) { + self.handle_in_event(InEvent::TimerExpired(timer)).await; + } + } + _ = self.forward_event_tasks.join_next(), if !self.forward_event_tasks.is_empty() => {} + else => break, + } - let (send_res, recv_res) = tokio::join!(send_fut, recv_fut); - send_res?; - recv_res?; - Ok(()) -} + if !self.drop_peers_queue.is_empty() { + let now = Instant::now(); + for peer in self.drop_peers_queue.drain() { + self.out_events + .extend(self.state.handle(InEvent::PeerDisconnected(peer), now)); + } + self.process_out_events(now).await; + } -#[derive(Default, Debug, Clone, Serialize, Deserialize)] -struct AddrInfo { - relay_url: Option, - direct_addresses: BTreeSet, -} + if self.to_actor_tx.is_closed() { + warn!("Channel to main actor closed: abort topic loop"); + break; + } + if self.init && self.api_receivers.is_empty() && self.forward_event_tasks.is_empty() { + debug!("Closing topic: All API subscribers dropped"); + break; + } + } + self.metrics.topics_quit.inc(); + self + } -impl From for AddrInfo { - fn from( - NodeAddr { - relay_url, - direct_addresses, - .. - }: NodeAddr, - ) -> Self { - Self { - relay_url, - direct_addresses, + async fn handle_actor_message(&mut self, msg: ActorToTopic) { + match msg { + ActorToTopic::Connected { remote, rx, tx } => { + self.remote_receivers + .push(Box::pin(into_stream(rx).map(move |msg| (remote, msg)))); + let sender = self.remote_senders.entry(remote).or_default(); + if let Err(err) = sender.init(tx).await { + warn!("Remote failed while pushing queued messages: {err:?}"); + } + } + ActorToTopic::Api(req) => { + self.init = true; + let WithChannels { inner, tx, rx, .. 
} = req; + let initial_neighbors = self.neighbors.clone().into_iter(); + self.forward_event_tasks.spawn( + forward_events(tx, self.forward_event_tx.subscribe(), initial_neighbors) + .instrument(tracing::Span::current()), + ); + self.api_receivers.push(Box::pin(into_stream2(rx))); + self.handle_in_event(InEvent::Command(Command::Join( + inner.bootstrap.into_iter().collect(), + ))) + .await; + } + ActorToTopic::ConnectionFailed(node_id) => { + self.handle_in_event(InEvent::PeerDisconnected(node_id)) + .await + } } } -} -fn encode_peer_data(info: &AddrInfo) -> PeerData { - let bytes = postcard::to_stdvec(info).expect("serializing AddrInfo may not fail"); - PeerData::new(bytes) -} + async fn handle_remote_message( + &mut self, + remote: NodeId, + message: Result, RecvError>, + ) { + let event = match message { + Ok(Some(message)) => InEvent::RecvMessage(remote, message), + Ok(None) => { + debug!(remote=%remote.fmt_short(), "Recv stream from remote closed"); + InEvent::PeerDisconnected(remote) + } + Err(err) => { + warn!(remote=%remote.fmt_short(), ?err, "Recv stream from remote failed"); + InEvent::PeerDisconnected(remote) + } + }; + self.handle_in_event(event).await; + } -fn decode_peer_data(peer_data: &PeerData) -> Result { - let bytes = peer_data.as_bytes(); - if bytes.is_empty() { - return Ok(AddrInfo::default()); + async fn handle_api_command(&mut self, command: Result) { + let Ok(command) = command else { + return; + }; + trace!("tick: api command {command}"); + self.handle_in_event(InEvent::Command(command.into())).await; } - let info = postcard::from_bytes(bytes)?; - Ok(info) -} -async fn topic_subscriber_loop( - sender: irpc::channel::mpsc::Sender, - mut topic_events: broadcast::Receiver, -) { - loop { - tokio::select! { - biased; - msg = topic_events.recv() => { - let event = match msg { - Err(broadcast::error::RecvError::Closed) => break, - Err(broadcast::error::RecvError::Lagged(_)) => Event::Lagged, - Ok(event) => event.into(), - }; - if sender.send(event).await.is_err() { - break; - } - } - _ = sender.closed() => break, - } + async fn handle_in_event(&mut self, event: InEvent) { + trace!("tick: in event {event:?}"); + let now = Instant::now(); + self.metrics.track_in_event(&event); + self.out_events.extend(self.state.handle(event, now)); + self.process_out_events(now).await; } -} -/// A stream of commands for a gossip subscription. 
-type BoxedCommandReceiver = - n0_future::stream::Boxed>; + async fn process_out_events(&mut self, now: Instant) { + while let Some(event) = self.out_events.pop_front() { + trace!("tick: out event {event:?}"); + self.metrics.track_out_event(&event); + match event { + OutEvent::SendMessage(node_id, message) => { + self.send(node_id, message).await; + } + OutEvent::EmitEvent(event) => { + self.handle_event(event); + } + OutEvent::ScheduleTimer(delay, timer) => { + self.timers.insert(now + delay, timer); + } + OutEvent::DisconnectPeer(node_id) => { + self.remote_senders.remove(&node_id); + } + OutEvent::PeerData(node_id, peer_data) => { + self.to_actor_tx + .send(LocalActorMessage::SetPeerData(node_id, peer_data)) + .await + .ok(); + } + } + } + } -#[derive(derive_more::Debug)] -struct TopicCommandStream { - topic_id: TopicId, - #[debug("CommandStream")] - stream: BoxedCommandReceiver, - closed: bool, -} + #[instrument(skip_all, fields(remote=%remote.fmt_short()))] + async fn send(&mut self, remote: NodeId, message: ProtoMessage) { + let sender = match self.remote_senders.entry(remote) { + hash_map::Entry::Occupied(entry) => entry.into_mut(), + hash_map::Entry::Vacant(entry) => { + debug!("requesting new connection"); + self.to_actor_tx + .send(LocalActorMessage::Connect(remote, self.topic_id)) + .await + .ok(); + entry.insert(Default::default()) + } + }; + if let Err(err) = sender.send(message).await { + warn!(?err, remote=%remote.fmt_short(), "failed to send message"); + self.drop_peers_queue.insert(remote); + } + } -impl TopicCommandStream { - fn new(topic_id: TopicId, stream: BoxedCommandReceiver) -> Self { - Self { - topic_id, - stream, - closed: false, + fn handle_event(&mut self, event: ProtoEvent) { + match &event { + ProtoEvent::NeighborUp(n) => { + #[cfg(test)] + self.joined + .store(true, std::sync::atomic::Ordering::Relaxed); + self.neighbors.insert(*n); + } + ProtoEvent::NeighborDown(n) => { + self.neighbors.remove(n); + } + ProtoEvent::Received(_) => {} } + self.forward_event_tx.send(event).ok(); } } -impl Stream for TopicCommandStream { - type Item = (TopicId, Option); - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - if self.closed { - return Poll::Ready(None); +async fn forward_events( + tx: channel::mpsc::Sender, + mut sub: broadcast::Receiver, + initial_neighbors: impl Iterator, +) { + for neighbor in initial_neighbors { + if let Err(_err) = tx.send(api::Event::NeighborUp(neighbor)).await { + break; } - match Pin::new(&mut self.stream).poll_next(cx) { - Poll::Ready(Some(Ok(item))) => Poll::Ready(Some((self.topic_id, Some(item)))), - Poll::Ready(None) | Poll::Ready(Some(Err(_))) => { - self.closed = true; - Poll::Ready(Some((self.topic_id, None))) - } - Poll::Pending => Poll::Pending, + } + loop { + let event = tokio::select! 
{ + biased; + event = sub.recv() => event, + _ = tx.closed() => break + }; + let event: api::Event = match event { + Ok(event) => event.into(), + Err(broadcast::error::RecvError::Lagged(_)) => api::Event::Lagged, + Err(broadcast::error::RecvError::Closed) => break, + }; + if let Err(_err) = tx.send(event).await { + break; } } } #[derive(Debug)] -struct Dialer { - endpoint: Endpoint, - pending: JoinSet<( - NodeId, - Option>, - )>, - pending_dials: HashMap, +enum MaybeSender { + Active(Guarded>), + Pending(Vec), } -impl Dialer { - /// Create a new dialer for a [`Endpoint`] - fn new(endpoint: Endpoint) -> Self { - Self { - endpoint, - pending: Default::default(), - pending_dials: Default::default(), +impl MaybeSender { + async fn send(&mut self, message: ProtoMessage) -> Result<(), channel::SendError> { + match self { + Self::Active(sender) => sender.send(message).await, + Self::Pending(messages) => { + messages.push(message); + Ok(()) + } } } - /// Starts to dial a node by [`NodeId`]. - fn queue_dial(&mut self, node_id: NodeId, alpn: Bytes) { - if self.is_pending(node_id) { - return; - } - let cancel = CancellationToken::new(); - self.pending_dials.insert(node_id, cancel.clone()); - let endpoint = self.endpoint.clone(); - self.pending.spawn( - async move { - let res = tokio::select! { - biased; - _ = cancel.cancelled() => None, - res = endpoint.connect(node_id, &alpn) => Some(res), - }; - (node_id, res) + async fn init( + &mut self, + sender: Guarded>, + ) -> Result<(), channel::SendError> { + debug!("Initializing new sender"); + *self = match self { + Self::Active(_old) => { + debug!("Dropping old sender"); + Self::Active(sender) } - .instrument(tracing::Span::current()), - ); + Self::Pending(queue) => { + debug!("Sending {} queued messages", queue.len()); + for msg in queue.drain(..) { + sender.send(msg).await?; + } + Self::Active(sender) + } + }; + Ok(()) } +} - /// Checks if a node is currently being dialed. - fn is_pending(&self, node: NodeId) -> bool { - self.pending_dials.contains_key(&node) +impl Default for MaybeSender { + fn default() -> Self { + Self::Pending(Vec::new()) } +} - /// Waits for the next dial operation to complete. 
- /// `None` means disconnected - async fn next_conn( - &mut self, - ) -> ( - NodeId, - Option>, - ) { - match self.pending_dials.is_empty() { - false => { - let (node_id, res) = loop { - match self.pending.join_next().await { - Some(Ok((node_id, res))) => { - self.pending_dials.remove(&node_id); - break (node_id, res); - } - Some(Err(e)) => { - error!("next conn error: {:?}", e); - } - None => { - error!("no more pending conns available"); - std::future::pending().await - } - } - }; +// TODO: Upstream to irpc: This differs from Receiver::into_stream: it returns +// None after the first error, whereas upstream would loop on the error +fn into_stream( + receiver: impl DerefMut> + Send + Sync + 'static, +) -> impl Stream, RecvError>> + Send + Sync + 'static { + n0_future::stream::unfold(Some(receiver), |recv| async move { + let mut recv = recv?; + let res = recv.recv().await; + match res { + Err(err) => Some((Err(err), None)), + Ok(Some(res)) => Some((Ok(Some(res)), Some(recv))), + Ok(None) => Some((Ok(None), None)), + } + }) +} - (node_id, res) - } - true => std::future::pending().await, +fn into_stream2( + receiver: channel::mpsc::Receiver, +) -> impl Stream> + Send + Sync + 'static { + n0_future::stream::unfold(Some(receiver), |recv| async move { + let mut recv = recv?; + match recv.recv().await { + Err(err) => Some((Err(err), None)), + Ok(Some(res)) => Some((Ok(res), Some(recv))), + Ok(None) => None, } - } + }) } #[cfg(test)] -pub(crate) mod test { - use std::time::Duration; +pub(crate) mod tests { + use std::{future::Future, time::Duration}; use bytes::Bytes; use futures_concurrency::future::TryJoin; use iroh::{ discovery::static_provider::StaticProvider, endpoint::BindError, protocol::Router, - RelayMap, RelayMode, SecretKey, + NodeAddr, RelayMap, RelayMode, SecretKey, }; use n0_snafu::{Result, ResultExt}; - use rand::Rng; + use rand::{CryptoRng, Rng, SeedableRng}; use tokio::{spawn, time::timeout}; use tokio_util::sync::CancellationToken; - use tracing::{info, instrument}; + use tracing::info; use tracing_test::traced_test; use super::*; - use crate::api::ApiError; - - struct ManualActorLoop { - actor: Actor, - step: usize, - } - - impl std::ops::Deref for ManualActorLoop { - type Target = Actor; - - fn deref(&self) -> &Self::Target { - &self.actor - } - } - - impl std::ops::DerefMut for ManualActorLoop { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.actor - } - } - - type EndpointHandle = tokio::task::JoinHandle>; - - impl ManualActorLoop { - #[instrument(skip_all, fields(me = %actor.endpoint.node_id().fmt_short()))] - async fn new(mut actor: Actor) -> Self { - let _ = actor.setup().await; - Self { actor, step: 0 } - } - - #[instrument(skip_all, fields(me = %self.endpoint.node_id().fmt_short()))] - async fn step(&mut self) -> bool { - let ManualActorLoop { actor, step } = self; - *step += 1; - // ignore updates that change our published address. This gives us better control over - // events since the endpoint it no longer emitting changes - let addr_update_stream = &mut futures_lite::stream::pending(); - actor.event_loop(addr_update_stream, *step).await - } - - async fn steps(&mut self, n: usize) { - for _ in 0..n { - self.step().await; - } - } - - async fn finish(mut self) { - while self.step().await {} - } - } + use crate::{ + api::{ApiError, Event, GossipReceiver, GossipSender}, + ALPN, + }; impl Gossip { - /// Creates a testing gossip instance and its actor without spawning it. - /// - /// This creates the endpoint and spawns the endpoint loop as well. 
The handle for the - /// endpoing task is returned along the gossip instance and actor. Since the actor is not - /// actually spawned as [`Builder::spawn`] would, the gossip instance will have a - /// handle to a dummy task instead. - async fn t_new_with_actor( + pub(super) async fn t_new( rng: &mut rand_chacha::ChaCha12Rng, config: proto::Config, relay_map: RelayMap, cancel: &CancellationToken, - ) -> Result<(Self, Actor, EndpointHandle), BindError> { - let endpoint = create_endpoint(rng, relay_map, None).await?; - let metrics = Arc::new(Metrics::default()); - - let (actor, to_actor_tx, conn_tx) = Actor::new(endpoint, config, metrics.clone(), None); - let max_message_size = actor.state.max_message_size(); - - let _actor_handle = - AbortOnDropHandle::new(task::spawn(futures_lite::future::pending())); - let gossip = Self { - inner: Inner { - api: GossipApi::local(to_actor_tx), - local_tx: conn_tx, - _actor_handle, - max_message_size, - metrics, - } - .into(), - }; - - let endpoint_task = task::spawn(endpoint_loop( - actor.endpoint.clone(), - gossip.clone(), - cancel.child_token(), - )); - - Ok((gossip, actor, endpoint_task)) + ) -> n0_snafu::Result<(Self, Endpoint, impl Future, impl Drop)> { + let (gossip, actor, ep_handle) = + Gossip::t_new_with_actor(rng, config, relay_map, cancel).await?; + let ep = actor.endpoint().clone(); + let me = ep.node_id().fmt_short(); + let actor_handle = + task::spawn(actor.run().instrument(tracing::error_span!("gossip", %me))); + Ok((gossip, ep, ep_handle, AbortOnDropHandle::new(actor_handle))) } - - /// Crates a new testing gossip instance with the normal actor loop. - async fn t_new( + pub(super) async fn t_new_with_actor( rng: &mut rand_chacha::ChaCha12Rng, config: proto::Config, relay_map: RelayMap, cancel: &CancellationToken, - ) -> Result<(Self, Endpoint, EndpointHandle, impl Drop + use<>), BindError> { - let (g, actor, ep_handle) = - Gossip::t_new_with_actor(rng, config, relay_map, cancel).await?; - let ep = actor.endpoint.clone(); - let me = ep.node_id().fmt_short(); - let actor_handle = - task::spawn(actor.run().instrument(tracing::error_span!("gossip", %me))); - Ok((g, ep, ep_handle, AbortOnDropHandle::new(actor_handle))) + ) -> n0_snafu::Result<(Self, Actor, impl Future)> { + let endpoint = Endpoint::builder() + .secret_key(SecretKey::generate(rng)) + .relay_mode(RelayMode::Custom(relay_map)) + .insecure_skip_relay_cert_verify(true) + .bind() + .await?; + + endpoint.online().await; + let (gossip, mut actor) = Gossip::new_with_actor(endpoint.clone(), config, None); + actor.node_addr_updates = Box::pin(n0_future::stream::pending()); + let router = Router::builder(endpoint) + .accept(GOSSIP_ALPN, gossip.clone()) + .spawn(); + let cancel = cancel.clone(); + let router_task = tokio::task::spawn(async move { + cancel.cancelled().await; + router.shutdown().await.ok(); + drop(router); + }); + let router_fut = async move { + router_task.await.expect("router task panicked"); + }; + Ok((gossip, actor, router_fut)) } } @@ -1199,7 +1113,7 @@ pub(crate) mod test { ) -> Result { let ep = Endpoint::builder() .secret_key(SecretKey::generate(rng)) - .alpns(vec![GOSSIP_ALPN.to_vec()]) + .alpns(vec![ALPN.to_vec()]) .relay_mode(RelayMode::Custom(relay_map)) .insecure_skip_relay_cert_verify(true) .bind() @@ -1233,7 +1147,9 @@ pub(crate) mod test { continue; } }; - gossip.handle_connection(connecting.await.e()?).await? 
+ let connection = connecting.await.e()?; + let remote_node_id = connection.remote_node_id()?; + gossip.handle_connection(remote_node_id, connection).await? } } } @@ -1374,7 +1290,7 @@ pub(crate) mod test { /// - Subscribe both nodes to the same topic. The first node will subscribe twice and connect /// to the second node. The second node will subscribe without bootstrap. /// - Ensure that the first node removes the subscription iff all topic handles have been - /// dropped + /// dropped. // NOTE: this is a regression test. #[tokio::test] #[traced_test] @@ -1384,15 +1300,14 @@ pub(crate) mod test { let (relay_map, relay_url, _guard) = iroh::test_utils::run_relay_server().await.unwrap(); // create the first node with a manual actor loop - let (go1, actor, ep1_handle) = + let (go1, mut actor, ep1_handle) = Gossip::t_new_with_actor(rng, Default::default(), relay_map.clone(), &ct).await?; - let mut actor = ManualActorLoop::new(actor).await; // create the second node with the usual actor loop let (go2, ep2, ep2_handle, _test_actor_handle) = Gossip::t_new(rng, Default::default(), relay_map, &ct).await?; - let node_id1 = actor.endpoint.node_id(); + let node_id1 = actor.endpoint().node_id(); let node_id2 = ep2.node_id(); tracing::info!( node_1 = %node_id1.fmt_short(), @@ -1423,7 +1338,7 @@ pub(crate) mod test { } tracing::debug!("subscribe stream ended"); - Result::<_, n0_snafu::Error>::Ok(()) + Ok::<_, n0_snafu::Error>(()) }; tokio::select! { @@ -1431,14 +1346,14 @@ pub(crate) mod test { res = subscribe_fut => res, } } - .instrument(tracing::debug_span!("node_2", %node_id2)); + .instrument(tracing::debug_span!("node_2", node_id2=%node_id2.fmt_short())); let go2_handle = task::spawn(go2_task); // first node let addr2 = NodeAddr::new(node_id2).with_relay_url(relay_url); let static_provider = StaticProvider::new(); static_provider.add_node_info(addr2); - actor.endpoint.discovery().add(static_provider); + actor.endpoint().discovery().add(static_provider); // we use a channel to signal advancing steps to the task let (tx, mut rx) = mpsc::channel::<()>(1); let ct1 = ct.clone(); @@ -1462,34 +1377,37 @@ pub(crate) mod test { ct1.cancelled().await; drop(go1); - Result::<_, n0_snafu::Error>::Ok(()) + Ok::<_, n0_snafu::Error>(()) } - .instrument(tracing::debug_span!("node_1", %node_id1)); + .instrument(tracing::debug_span!("node_1", node_id1 = %node_id1.fmt_short())); let go1_handle = task::spawn(go1_task); // advance and check that the topic is now subscribed - actor.steps(3).await; // handle our subscribe; - // get peer connection; - // receive the other peer's information for a NeighborUp + actor.steps(4).await?; // api_rx subscribe; + // internal_rx connection request (from topic actor); + // dialer connected; + // internal_rx update peer data (from topic actor); + tracing::info!("subscribe and join done, should be joined"); let state = actor.topics.get(&topic).expect("get registered topic"); assert!(state.joined()); // signal the second subscribe, we should remain subscribed tx.send(()).await.e()?; - actor.steps(3).await; // subscribe; first receiver gone; first sender gone + actor.steps(1).await?; // api_rx subscribe; let state = actor.topics.get(&topic).expect("get registered topic"); assert!(state.joined()); // signal to drop the second handle, the topic should no longer be subscribed tx.send(()).await.e()?; - actor.steps(2).await; // second receiver gone; second sender gone + actor.steps(1).await?; // topic task finished + assert!(!actor.topics.contains_key(&topic)); // cleanup and ensure 
everything went as expected ct.cancel(); - let wait = Duration::from_secs(2); - timeout(wait, ep1_handle).await.e()?.e()??; - timeout(wait, ep2_handle).await.e()?.e()??; + let wait = Duration::from_secs(5); + timeout(wait, ep1_handle).await.e()?; + timeout(wait, ep2_handle).await.e()?; timeout(wait, go1_handle).await.e()?.e()??; timeout(wait, go2_handle).await.e()?.e()??; timeout(wait, actor.finish()).await.e()?; @@ -1503,7 +1421,7 @@ pub(crate) mod test { /// unsubscribe and then resubscribe and connection between the nodes should succeed both /// times. // NOTE: This is a regression test - #[tokio::test] + #[tokio::test(flavor = "multi_thread")] #[traced_test] async fn can_reconnect() -> Result { let rng = &mut rand_chacha::ChaCha12Rng::seed_from_u64(1); @@ -1527,7 +1445,6 @@ pub(crate) mod test { let topic: TopicId = blake3::hash(b"can_reconnect").into(); tracing::info!(%topic, "joining"); - let ct2 = ct.child_token(); // channel used to signal the second gossip instance to advance the test let (tx, mut rx) = mpsc::channel::<()>(1); let addr1 = NodeAddr::new(node_id1).with_relay_url(relay_url.clone()); @@ -1547,13 +1464,10 @@ pub(crate) mod test { let mut sub = go2.subscribe(topic, vec![node_id1]).await?; sub.joined().await?; - tracing::info!("subscription successful!"); - - ct2.cancelled().await; Result::<_, ApiError>::Ok(()) } - .instrument(tracing::debug_span!("node_2", %node_id2)); + .instrument(tracing::debug_span!("node_2", node_id2=%node_id2.fmt_short())); let go2_handle = task::spawn(go2_task); let addr2 = NodeAddr::new(node_id2).with_relay_url(relay_url); @@ -1563,12 +1477,14 @@ pub(crate) mod test { let mut sub = go1.subscribe(topic, vec![node_id2]).await?; // wait for subscribed notification sub.joined().await?; + info!("go1 joined"); // signal node_2 to unsubscribe tx.send(()).await.e()?; + info!("wait for neighbor down"); // we should receive a Neighbor down event - let conn_timeout = Duration::from_millis(500); + let conn_timeout = Duration::from_millis(1000); let ev = timeout(conn_timeout, sub.try_next()).await.e()??; assert_eq!(ev, Some(Event::NeighborDown(node_id2))); tracing::info!("node 2 left"); @@ -1576,19 +1492,21 @@ pub(crate) mod test { // signal node_2 to subscribe again tx.send(()).await.e()?; - let conn_timeout = Duration::from_millis(500); + info!("wait for neighbor up"); + let conn_timeout = Duration::from_millis(1000); let ev = timeout(conn_timeout, sub.try_next()).await.e()??; assert_eq!(ev, Some(Event::NeighborUp(node_id2))); tracing::info!("node 2 rejoined!"); - // cleanup and ensure everything went as expected - ct.cancel(); + // wait for go2 to also be rejoined, then the task terminates let wait = Duration::from_secs(2); - timeout(wait, ep1_handle).await.e()?.e()??; - timeout(wait, ep2_handle).await.e()?.e()??; timeout(wait, go2_handle).await.e()?.e()??; + ct.cancel(); + // cleanup and ensure everything went as expected + timeout(wait, ep1_handle).await.e()?; + timeout(wait, ep2_handle).await.e()?; - Result::Ok(()) + Ok(()) } #[tokio::test] @@ -1621,9 +1539,7 @@ pub(crate) mod test { .bind() .await?; let gossip = Gossip::builder().spawn(ep.clone()); - let router = Router::builder(ep) - .accept(GOSSIP_ALPN, gossip.clone()) - .spawn(); + let router = Router::builder(ep).accept(ALPN, gossip.clone()).spawn(); Ok((router, gossip)) } @@ -1761,4 +1677,77 @@ pub(crate) mod test { router2.shutdown().await.e()?; Ok(()) } + + #[tokio::test] + #[traced_test] + async fn gossip_rely_on_gossip_discovery() -> n0_snafu::Result<()> { + let rng = &mut 
rand_chacha::ChaCha12Rng::seed_from_u64(1);
+
+        async fn spawn(
+            rng: &mut impl CryptoRng,
+        ) -> n0_snafu::Result<(NodeId, Router, Gossip, GossipSender, GossipReceiver)> {
+            let topic_id = TopicId::from([0u8; 32]);
+            let ep = Endpoint::builder()
+                .secret_key(SecretKey::generate(rng))
+                .relay_mode(RelayMode::Disabled)
+                .bind()
+                .await?;
+            let node_id = ep.node_id();
+            let gossip = Gossip::builder().spawn(ep.clone());
+            let router = Router::builder(ep)
+                .accept(GOSSIP_ALPN, gossip.clone())
+                .spawn();
+            let topic = gossip.subscribe(topic_id, vec![]).await?;
+            let (sender, receiver) = topic.split();
+            Ok((node_id, router, gossip, sender, receiver))
+        }
+
+        // spawn 3 nodes without relay or discovery
+        let (n1, r1, _g1, _tx1, mut rx1) = spawn(rng).await?;
+        let (n2, r2, _g2, tx2, mut rx2) = spawn(rng).await?;
+        let (n3, r3, _g3, tx3, mut rx3) = spawn(rng).await?;
+
+        println!("nodes {:?}", [n1, n2, n3]);
+
+        // create a static discovery that has only node 1 addr info set
+        let addr1 = r1.endpoint().node_addr();
+        let disco = StaticProvider::new();
+        disco.add_node_info(addr1);
+
+        // add addr info of node1 to node2 and join node1
+        r2.endpoint().discovery().add(disco.clone());
+        tx2.join_peers(vec![n1]).await?;
+
+        // await join node2 -> node1
+        timeout(Duration::from_secs(3), rx1.joined()).await.e()??;
+        timeout(Duration::from_secs(3), rx2.joined()).await.e()??;
+
+        // add addr info of node1 to node3 and join node1
+        r3.endpoint().discovery().add(disco.clone());
+        tx3.join_peers(vec![n1]).await?;
+
+        // await join at node3: n1 and n2
+        // n2 only works because we use gossip discovery!
+        let ev = timeout(Duration::from_secs(3), rx3.next()).await.e()?;
+        assert!(matches!(ev, Some(Ok(Event::NeighborUp(_)))));
+        let ev = timeout(Duration::from_secs(3), rx3.next()).await.e()?;
+        assert!(matches!(ev, Some(Ok(Event::NeighborUp(_)))));
+
+        assert_eq!(sorted(rx3.neighbors()), sorted([n1, n2]));
+
+        let ev = timeout(Duration::from_secs(3), rx2.next()).await.e()?;
+        assert!(matches!(ev, Some(Ok(Event::NeighborUp(n))) if n == n3));
+
+        let ev = timeout(Duration::from_secs(3), rx1.next()).await.e()?;
+        assert!(matches!(ev, Some(Ok(Event::NeighborUp(n))) if n == n3));
+
+        tokio::try_join!(r1.shutdown(), r2.shutdown(), r3.shutdown()).e()?;
+        Ok(())
+    }
+
+    fn sorted<T: Ord>(input: impl IntoIterator<Item = T>) -> Vec<T> {
+        let mut out: Vec<_> = input.into_iter().collect();
+        out.sort();
+        out
+    }
 }
diff --git a/src/net/dialer.rs b/src/net/dialer.rs
new file mode 100644
index 00000000..f2b9cee1
--- /dev/null
+++ b/src/net/dialer.rs
@@ -0,0 +1,46 @@
+use std::collections::HashSet;
+
+use bytes::Bytes;
+use iroh::{
+    endpoint::{ConnectError, Connection},
+    Endpoint, NodeId,
+};
+use tokio::task::JoinSet;
+use tracing::Instrument;
+
+#[derive(Debug, Default)]
+pub(crate) struct Dialer {
+    pending_tasks: JoinSet<(NodeId, Result<Connection, ConnectError>)>,
+    pending_nodes: HashSet<NodeId>,
+}
+
+impl Dialer {
+    /// Starts to dial a node by [`NodeId`].
+    pub(crate) fn queue_dial(&mut self, endpoint: &Endpoint, node_id: NodeId, alpn: Bytes) {
+        if self.pending_nodes.insert(node_id) {
+            let endpoint = endpoint.clone();
+            let fut = async move { (node_id, endpoint.connect(node_id, &alpn).await) }
+                .instrument(tracing::Span::current());
+            self.pending_tasks.spawn(fut);
+        }
+    }
+
+    pub(crate) fn is_empty(&self) -> bool {
+        self.pending_tasks.is_empty()
+    }
+
+    /// Waits for the next dial operation to complete.
+    ///
+    /// Returns `None` if no dials are in progress; otherwise resolves with the
+    /// result of the next dial to finish.
+ pub(crate) async fn next(&mut self) -> Option<(NodeId, Result)> { + match self.pending_tasks.join_next().await { + Some(res) => { + let (node_id, res) = res.expect("connect task panicked"); + self.pending_nodes.remove(&node_id); + Some((node_id, res)) + } + None => None, + } + } +} diff --git a/src/net/discovery.rs b/src/net/discovery.rs new file mode 100644 index 00000000..0f7dcc2b --- /dev/null +++ b/src/net/discovery.rs @@ -0,0 +1,170 @@ +//! A discovery service to gather addressing info collected from gossip Join and ForwardJoin messages. + +use std::{ + collections::{btree_map::Entry, BTreeMap}, + sync::{Arc, RwLock}, + time::Duration, +}; + +use iroh::discovery::{Discovery, DiscoveryError, DiscoveryItem, NodeData, NodeInfo}; +use iroh_base::NodeId; +use n0_future::{ + boxed::BoxStream, + stream::{self, StreamExt}, + task::AbortOnDropHandle, + time::SystemTime, +}; + +pub(crate) struct RetentionOpts { + retention: Duration, + check_interval: Duration, +} + +impl Default for RetentionOpts { + fn default() -> Self { + Self { + retention: Duration::from_secs(60 * 5), + check_interval: Duration::from_secs(30), + } + } +} + +/// A static node discovery that expires nodes after some time. +/// +/// It is added to the endpoint when constructing a gossip instance, and the gossip actor +/// then adds node addresses as received with Join or ForwardJoin messages. +#[derive(Debug, Clone)] +pub(crate) struct GossipDiscovery { + nodes: NodeMap, + _task_handle: Arc>, +} + +type NodeMap = Arc>>; + +#[derive(Debug)] +struct StoredNodeInfo { + data: NodeData, + last_updated: SystemTime, +} + +impl Default for GossipDiscovery { + fn default() -> Self { + Self::new() + } +} + +impl GossipDiscovery { + const PROVENANCE: &'static str = "gossip"; + + /// Creates a new gossip discovery instance. + pub(crate) fn new() -> Self { + Self::with_opts(Default::default()) + } + + pub(crate) fn with_opts(opts: RetentionOpts) -> Self { + let nodes: NodeMap = Default::default(); + let task = { + let nodes = Arc::downgrade(&nodes); + n0_future::task::spawn(async move { + loop { + n0_future::time::sleep(opts.check_interval).await; + let Some(nodes) = nodes.upgrade() else { + break; + }; + let now = SystemTime::now(); + nodes.write().expect("poisoned").retain(|_k, v| { + let age = now.duration_since(v.last_updated).unwrap_or(Duration::MAX); + age <= opts.retention + }); + } + }) + }; + Self { + nodes, + _task_handle: Arc::new(AbortOnDropHandle::new(task)), + } + } + + /// Augments node addressing information for the given node ID. + /// + /// The provided addressing information is combined with the existing info in the static + /// provider. Any new direct addresses are added to those already present while the + /// relay URL is overwritten. 
+ pub(crate) fn add(&self, node_info: impl Into) { + let last_updated = SystemTime::now(); + let NodeInfo { node_id, data } = node_info.into(); + let mut guard = self.nodes.write().expect("poisoned"); + match guard.entry(node_id) { + Entry::Occupied(mut entry) => { + let existing = entry.get_mut(); + existing + .data + .add_direct_addresses(data.direct_addresses().iter().copied()); + existing.data.set_relay_url(data.relay_url().cloned()); + existing.data.set_user_data(data.user_data().cloned()); + existing.last_updated = last_updated; + } + Entry::Vacant(entry) => { + entry.insert(StoredNodeInfo { data, last_updated }); + } + } + } +} + +impl Discovery for GossipDiscovery { + fn resolve(&self, node_id: NodeId) -> Option>> { + let guard = self.nodes.read().expect("poisoned"); + let info = guard.get(&node_id)?; + let last_updated = info + .last_updated + .duration_since(SystemTime::UNIX_EPOCH) + .expect("time drift") + .as_micros() as u64; + let item = DiscoveryItem::new( + NodeInfo::from_parts(node_id, info.data.clone()), + Self::PROVENANCE, + Some(last_updated), + ); + Some(stream::iter(Some(Ok(item))).boxed()) + } +} + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use iroh::{discovery::Discovery, NodeAddr, SecretKey}; + use n0_future::StreamExt; + + use super::{GossipDiscovery, RetentionOpts}; + + #[tokio::test] + async fn test_retention() { + let opts = RetentionOpts { + check_interval: Duration::from_millis(100), + retention: Duration::from_millis(500), + }; + let disco = GossipDiscovery::with_opts(opts); + + let k1 = SecretKey::generate(&mut rand::rng()); + let a1 = NodeAddr::new(k1.public()); + + disco.add(a1); + + assert!(matches!( + disco.resolve(k1.public()).unwrap().next().await, + Some(Ok(_)) + )); + + tokio::time::sleep(Duration::from_millis(200)).await; + + assert!(matches!( + disco.resolve(k1.public()).unwrap().next().await, + Some(Ok(_)) + )); + + tokio::time::sleep(Duration::from_millis(700)).await; + + assert!(disco.resolve(k1.public()).is_none()); + } +} diff --git a/src/net/util.rs b/src/net/util.rs index 4a02fd75..1d1af92c 100644 --- a/src/net/util.rs +++ b/src/net/util.rs @@ -1,393 +1,113 @@ //! 
Utilities for iroh-gossip networking use std::{ - collections::{hash_map, HashMap}, - io, + collections::BTreeSet, + net::SocketAddr, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, time::Duration, }; -use bytes::{Bytes, BytesMut}; -use iroh::{ - endpoint::{Connection, RecvStream, SendStream}, - NodeId, -}; +use iroh::{endpoint::Connection, NodeAddr, NodeId, RelayUrl}; +use irpc::rpc::RemoteService; use n0_future::{ + future::Boxed as BoxFuture, time::{sleep_until, Instant}, - FuturesUnordered, StreamExt, -}; -use nested_enum_utils::common_fields; -use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use snafu::Snafu; -use tokio::{ - io::{AsyncReadExt, AsyncWriteExt}, - sync::mpsc, - task::JoinSet, + Stream, }; -use tracing::{debug, trace, Instrument}; - -use super::{InEvent, ProtoMessage}; -use crate::proto::{util::TimerMap, TopicId}; - -/// Errors related to message writing -#[allow(missing_docs)] -#[common_fields({ - backtrace: Option, -})] -#[derive(Debug, Snafu)] -#[snafu(module)] -#[non_exhaustive] -pub(crate) enum WriteError { - /// Connection error - #[snafu(transparent)] - Connection { - source: iroh::endpoint::ConnectionError, - }, - /// Serialization failed - #[snafu(transparent)] - Ser { source: postcard::Error }, - /// IO error - #[snafu(transparent)] - Io { source: std::io::Error }, - /// Message was larger than the configured maximum message size - #[snafu(display("message too large"))] - TooLarge {}, -} +use serde::{Deserialize, Serialize}; +use tokio::sync::Notify; -#[derive(Debug, Serialize, Deserialize)] -pub(crate) struct StreamHeader { - pub(crate) topic_id: TopicId, -} +use crate::proto::{util::TimerMap, PeerData}; -impl StreamHeader { - pub(crate) async fn read( - stream: &mut RecvStream, - buffer: &mut BytesMut, - max_message_size: usize, - ) -> Result { - let header: Self = read_frame(stream, buffer, max_message_size) - .await? - .ok_or_else(|| { - ReadError::from(io::Error::new( - io::ErrorKind::UnexpectedEof, - "stream ended before header", - )) - })?; - Ok(header) - } +/// A connection to a remote service. +#[derive(Debug, Clone)] +pub struct IrohRemoteConnection(Connection); - pub(crate) async fn write( - self, - stream: &mut SendStream, - buffer: &mut Vec, - max_message_size: usize, - ) -> Result<(), WriteError> { - write_frame(stream, &self, buffer, max_message_size).await?; - Ok(()) +impl IrohRemoteConnection { + pub fn new(connection: Connection) -> Self { + Self(connection) } } -pub(crate) struct RecvLoop { - remote_node_id: NodeId, - conn: Connection, - max_message_size: usize, - in_event_tx: mpsc::Sender, -} - -impl RecvLoop { - pub(crate) fn new( - remote_node_id: NodeId, - conn: Connection, - in_event_tx: mpsc::Sender, - max_message_size: usize, - ) -> Self { - Self { - remote_node_id, - conn, - max_message_size, - in_event_tx, - } +impl irpc::rpc::RemoteConnection for IrohRemoteConnection { + fn clone_boxed(&self) -> Box { + Box::new(self.clone()) } - pub(crate) async fn run(&mut self) -> Result<(), ReadError> { - let mut read_futures = FuturesUnordered::new(); - let mut conn_is_closed = false; - let closed = self.conn.closed(); - tokio::pin!(closed); - while !conn_is_closed || !read_futures.is_empty() { - tokio::select! 
{ - _ = &mut closed, if !conn_is_closed => { - conn_is_closed = true; - } - stream = self.conn.accept_uni(), if !conn_is_closed => { - let stream = match stream { - Ok(stream) => stream, - Err(_) => { - conn_is_closed = true; - continue; - } - }; - let state = RecvStreamState::new(stream, self.max_message_size).await?; - debug!(topic=%state.header.topic_id.fmt_short(), "stream opened"); - read_futures.push(state.next()); - } - Some(res) = read_futures.next(), if !read_futures.is_empty() => { - let (state, msg) = match res { - Ok((state, msg)) => (state, msg), - Err(err) => { - debug!("recv stream closed with error: {err:#}"); - continue; - } - }; - match msg { - None => debug!(topic=%state.header.topic_id.fmt_short(), "stream closed"), - Some(msg) => { - if self.in_event_tx.send(InEvent::RecvMessage(self.remote_node_id, msg)).await.is_err() { - debug!("stop recv loop: actor closed"); - break; - } - read_futures.push(state.next()); - } - } - } - } - } - debug!("recv loop closed"); - Ok(()) - } -} - -#[derive(Debug)] -struct RecvStreamState { - stream: RecvStream, - header: StreamHeader, - buffer: BytesMut, - max_message_size: usize, -} - -impl RecvStreamState { - async fn new(mut stream: RecvStream, max_message_size: usize) -> Result { - let mut buffer = BytesMut::new(); - let header = StreamHeader::read(&mut stream, &mut buffer, max_message_size).await?; - Ok(Self { - buffer: BytesMut::new(), - max_message_size, - stream, - header, + fn open_bi( + &self, + ) -> BoxFuture< + Result<(iroh::endpoint::SendStream, iroh::endpoint::RecvStream), irpc::RequestError>, + > { + let this = self.0.clone(); + Box::pin(async move { + let pair = this.open_bi().await?; + Ok(pair) }) } +} - /// Reads the next message from the stream. - /// - /// Returns `self` and the next message, or `None` if the stream ended gracefully. - /// - /// ## Cancellation safety - /// - /// This function is not cancellation-safe. - async fn next(mut self) -> Result<(Self, Option), ReadError> { - let msg = read_frame(&mut self.stream, &mut self.buffer, self.max_message_size).await?; - let msg = msg.map(|msg| ProtoMessage { - topic: self.header.topic_id, - message: msg, - }); - Ok((self, msg)) - } +pub(crate) fn accept_stream( + connection: Connection, +) -> impl Stream> { + n0_future::stream::unfold(Some(connection), async |conn| { + let conn = conn?; + match irpc_iroh::read_request::(&conn).await { + Err(err) => Some((Err(err), None)), + Ok(None) => None, + Ok(Some(request)) => Some((Ok(request), Some(conn))), + } + }) } -pub(crate) struct SendLoop { - conn: Connection, - streams: HashMap, - buffer: Vec, - max_message_size: usize, - finishing: JoinSet<()>, - send_rx: mpsc::Receiver, +#[derive(Default, Debug, Clone, Serialize, Deserialize)] +pub(crate) struct AddrInfo { + pub(crate) relay_url: Option, + pub(crate) direct_addresses: BTreeSet, } -impl SendLoop { - pub(crate) fn new( - conn: Connection, - send_rx: mpsc::Receiver, - max_message_size: usize, +impl From for AddrInfo { + fn from( + NodeAddr { + relay_url, + direct_addresses, + .. + }: NodeAddr, ) -> Self { Self { - conn, - max_message_size, - buffer: Default::default(), - streams: Default::default(), - finishing: Default::default(), - send_rx, + relay_url, + direct_addresses, } } +} - pub(crate) async fn run(&mut self, queue: Vec) -> Result<(), WriteError> { - for msg in queue { - self.write_message(&msg).await?; - } - let conn_clone = self.conn.clone(); - let closed = conn_clone.closed(); - tokio::pin!(closed); - loop { - tokio::select! 
{ - biased; - _ = &mut closed => break, - Some(msg) = self.send_rx.recv() => self.write_message(&msg).await?, - _ = self.finishing.join_next(), if !self.finishing.is_empty() => {} - else => break, - } - } - - // Close remaining streams. - for (topic_id, mut stream) in self.streams.drain() { - stream.finish().ok(); - self.finishing.spawn( - async move { - stream.stopped().await.ok(); - debug!(topic=%topic_id.fmt_short(), "stream closed"); - } - .instrument(tracing::Span::current()), - ); - } - if !self.finishing.is_empty() { - trace!( - "send loop closing, waiting for {} send streams to finish", - self.finishing.len() - ); - // Wait for the remote to acknowledge all streams are finished. - if let Err(_elapsed) = n0_future::time::timeout(Duration::from_secs(5), async move { - while self.finishing.join_next().await.is_some() {} - }) - .await - { - debug!("not all send streams finished within timeout, abort") - } - } - debug!("send loop closed"); - Ok(()) +impl AddrInfo { + pub(crate) fn encode(&self) -> PeerData { + let bytes = postcard::to_stdvec(self).expect("serializing AddrInfo may not fail"); + PeerData::new(bytes) } - /// Write a [`ProtoMessage`] as a length-prefixed, postcard-encoded message on its stream. - /// - /// If no stream is opened yet, this opens a new stream for the topic and writes the topic header. - /// - /// This function is not cancellation-safe. - pub async fn write_message(&mut self, message: &ProtoMessage) -> Result<(), WriteError> { - let ProtoMessage { topic, message } = message; - let topic_id = *topic; - let is_last = message.is_disconnect(); - - let mut entry = match self.streams.entry(topic_id) { - hash_map::Entry::Occupied(entry) => entry, - hash_map::Entry::Vacant(entry) => { - let mut stream = self.conn.open_uni().await?; - let header = StreamHeader { topic_id }; - header - .write(&mut stream, &mut self.buffer, self.max_message_size) - .await?; - debug!(topic=%topic_id.fmt_short(), "stream opened"); - entry.insert_entry(stream) - } - }; - let stream = entry.get_mut(); - - write_frame(stream, message, &mut self.buffer, self.max_message_size).await?; - - if is_last { - trace!(topic=%topic_id.fmt_short(), "stream closing"); - let mut stream = entry.remove(); - if stream.finish().is_ok() { - self.finishing.spawn( - async move { - stream.stopped().await.ok(); - debug!(topic=%topic_id.fmt_short(), "stream closed"); - } - .instrument(tracing::Span::current()), - ); - } + pub(crate) fn decode(peer_data: &PeerData) -> Result { + let bytes = peer_data.as_bytes(); + if bytes.is_empty() { + return Ok(AddrInfo::default()); } - - Ok(()) + let info = postcard::from_bytes(bytes)?; + Ok(info) } -} - -/// Errors related to message reading -#[allow(missing_docs)] -#[common_fields({ - backtrace: Option, -})] -#[derive(Debug, Snafu)] -#[snafu(module)] -#[non_exhaustive] -pub(crate) enum ReadError { - /// Deserialization failed - #[snafu(transparent)] - De { source: postcard::Error }, - /// IO error - #[snafu(transparent)] - Io { source: std::io::Error }, - /// Message was larger than the configured maximum message size - #[snafu(display("message too large"))] - TooLarge {}, -} -/// Read a length-prefixed frame and decode with postcard. -pub async fn read_frame( - reader: &mut RecvStream, - buffer: &mut BytesMut, - max_message_size: usize, -) -> Result, ReadError> { - match read_lp(reader, buffer, max_message_size).await? 
{ - None => Ok(None), - Some(data) => { - let message = postcard::from_bytes(&data)?; - Ok(Some(message)) + pub(crate) fn into_node_addr(self, node_id: NodeId) -> NodeAddr { + NodeAddr { + node_id, + relay_url: self.relay_url, + direct_addresses: self.direct_addresses, } } } -/// Reads a length prefixed buffer. -/// -/// Returns the frame as raw bytes. If the end of the stream is reached before -/// the frame length starts, `None` is returned. -pub async fn read_lp( - reader: &mut RecvStream, - buffer: &mut BytesMut, - max_message_size: usize, -) -> Result, ReadError> { - let size = match reader.read_u32().await { - Ok(size) => size, - Err(err) if err.kind() == io::ErrorKind::UnexpectedEof => return Ok(None), - Err(err) => return Err(err.into()), - }; - let size = usize::try_from(size).map_err(|_| read_error::TooLargeSnafu.build())?; - if size > max_message_size { - return Err(read_error::TooLargeSnafu.build()); - } - buffer.resize(size, 0u8); - reader - .read_exact(&mut buffer[..]) - .await - .map_err(io::Error::other)?; - Ok(Some(buffer.split_to(size).freeze())) -} - -/// Writes a length-prefixed frame. -pub async fn write_frame( - stream: &mut SendStream, - message: &T, - buffer: &mut Vec, - max_message_size: usize, -) -> Result<(), WriteError> { - let len = postcard::experimental::serialized_size(&message)?; - if len >= max_message_size { - return Err(write_error::TooLargeSnafu.build()); - } - buffer.clear(); - buffer.resize(len, 0u8); - let slice = postcard::to_slice(&message, buffer)?; - stream.write_u32(len as u32).await?; - stream.write_all(slice).await.map_err(io::Error::other)?; - Ok(()) -} - /// A [`TimerMap`] with an async method to wait for the next timer expiration. #[derive(Debug)] pub struct Timers { @@ -403,11 +123,6 @@ impl Default for Timers { } impl Timers { - /// Creates a new timer map. 
- pub fn new() -> Self { - Self::default() - } - /// Inserts a new entry at the specified instant pub fn insert(&mut self, instant: Instant, item: T) { self.map.insert(instant, item); @@ -429,3 +144,100 @@ impl Timers { self.map.pop_before(now) } } + +#[derive(Debug)] +struct ConnectionCounterInner { + count: AtomicUsize, + notify: Notify, +} + +#[derive(Debug, Clone)] +pub(crate) struct ConnectionCounter { + inner: Arc, +} + +impl ConnectionCounter { + pub(crate) fn new() -> Self { + Self { + inner: Arc::new(ConnectionCounterInner { + count: Default::default(), + notify: Notify::new(), + }), + } + } + + /// Increase the connection count and return a guard for the new connection + pub(crate) fn get_one(&self) -> OneConnection { + self.inner.count.fetch_add(1, Ordering::SeqCst); + OneConnection { + inner: self.inner.clone(), + } + } + + pub(crate) fn guard(&self, item: T) -> Guarded { + Guarded::new(item, self.get_one()) + } + + pub(crate) fn is_idle(&self) -> bool { + self.inner.count.load(Ordering::SeqCst) == 0 + } + + pub(crate) async fn idle(&self) { + self.inner.notify.notified().await + } + + pub(crate) async fn idle_for(&self, duration: Duration) { + let fut = self.idle(); + tokio::pin!(fut); + loop { + (&mut fut).await; + fut.set(self.idle()); + tokio::time::sleep(duration).await; + if self.is_idle() { + break; + } + } + } +} + +/// Guard for one connection +#[derive(Debug)] +pub(crate) struct OneConnection { + inner: Arc, +} + +impl Clone for OneConnection { + fn clone(&self) -> Self { + self.inner.count.fetch_add(1, Ordering::SeqCst); + Self { + inner: self.inner.clone(), + } + } +} + +impl Drop for OneConnection { + fn drop(&mut self) { + let prev = self.inner.count.fetch_sub(1, Ordering::SeqCst); + if prev == 1 { + self.inner.notify.notify_waiters(); + } + } +} + +#[derive(derive_more::Deref, derive_more::DerefMut, Debug)] +pub(crate) struct Guarded { + #[deref] + #[deref_mut] + inner: T, + guard: OneConnection, +} + +impl Guarded { + pub(crate) fn new(inner: T, guard: OneConnection) -> Self { + Self { inner, guard } + } + + pub(crate) fn split(self) -> (T, OneConnection) { + (self.inner, self.guard) + } +} diff --git a/src/proto/topic.rs b/src/proto/topic.rs index a5af0efa..0c1158bd 100644 --- a/src/proto/topic.rs +++ b/src/proto/topic.rs @@ -5,7 +5,7 @@ use std::collections::VecDeque; use bytes::Bytes; use derive_more::From; use n0_future::time::{Duration, Instant}; -use rand::Rng; +use rand::{Rng, SeedableRng}; use serde::{Deserialize, Serialize}; use super::{ @@ -114,6 +114,13 @@ impl Message { } } +impl Message { + /// Get the encoded size of this message + pub fn size(&self) -> postcard::Result { + postcard::experimental::serialized_size(&self) + } +} + /// An event to be emitted to the application for a particular topic. #[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Serialize, Deserialize)] pub enum Event { @@ -212,14 +219,14 @@ pub struct State { stats: Stats, } -impl State { +impl State { /// Initialize the local state with the default random number generator. /// /// ## Panics /// /// Panics if [`Config::max_message_size`] is below [`MIN_MAX_MESSAGE_SIZE`]. pub fn new(me: PI, me_data: Option, config: Config) -> Self { - Self::with_rng(me, me_data, config, rand::rng()) + Self::with_rng(me, me_data, config, rand::rngs::StdRng::from_os_rng()) } }
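
Side note on the RNG change above (not part of the patch): `State::new` now seeds a `StdRng` from the OS instead of using the thread-local RNG, so callers that need reproducible protocol behaviour should keep going through `State::with_rng` with a seeded generator. A minimal sketch, assuming `State`'s type parameters are the peer-identity type and the RNG (the helper name and parameter choices below are illustrative only):

    use rand::SeedableRng;
    use rand_chacha::ChaCha12Rng;

    // Hypothetical test helper: builds protocol state whose peer-sampling
    // decisions are reproducible across runs, unlike `State::new`, which
    // now draws its seed from the OS.
    fn deterministic_state(me: NodeId, config: Config) -> State<NodeId, ChaCha12Rng> {
        State::with_rng(me, None, config, ChaCha12Rng::seed_from_u64(42))
    }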