Skip to content

Commit 6c21f9c

Browse files
committed
cleanup
1 parent 739bb0b commit 6c21f9c

File tree

2 files changed

+34
-64
lines changed

2 files changed

+34
-64
lines changed

src/net.rs

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -287,13 +287,14 @@ type AcceptRemoteRequestsStream =
287287

288288
struct Actor {
289289
me: NodeId,
290+
endpoint: Endpoint,
290291
alpn: Bytes,
291292
config: Config,
292293
local_rx: mpsc::Receiver<LocalActorMessage>,
293294
local_tx: mpsc::Sender<LocalActorMessage>,
294295
api_rx: mpsc::Receiver<api::RpcMessage>,
295296
topics: HashMap<TopicId, TopicHandle>,
296-
topic_pending_remotes: HashMap<NodeId, HashSet<TopicId>>,
297+
pending_remotes_with_topics: HashMap<NodeId, HashSet<TopicId>>,
297298
topic_tasks: JoinSet<TopicActor>,
298299
remotes: HashMap<NodeId, RemoteState>,
299300
close_connections: JoinSet<(NodeId, Connection)>,
@@ -332,18 +333,19 @@ impl Actor {
332333
api_tx,
333334
local_tx.clone(),
334335
Actor {
336+
endpoint,
335337
me,
336338
config,
337339
api_rx,
338340
local_tx,
339341
local_rx,
340342
node_addr_updates: Box::pin(node_addr_updates),
341-
dialer: Dialer::new(endpoint),
343+
dialer: Dialer::default(),
342344
our_peer_data: Watchable::new(initial_peer_data),
343345
alpn: alpn.unwrap_or_else(|| crate::ALPN.to_vec().into()),
344346
metrics: metrics.clone(),
345347
topics: Default::default(),
346-
topic_pending_remotes: Default::default(),
348+
pending_remotes_with_topics: Default::default(),
347349
remotes: Default::default(),
348350
close_connections: JoinSet::new(),
349351
topic_tasks: JoinSet::new(),
@@ -414,7 +416,7 @@ impl Actor {
414416
}
415417
true
416418
}
417-
(node_id, res) = self.dialer.next_conn() => {
419+
Some((node_id, res)) = self.dialer.next(), if !self.dialer.is_empty() => {
418420
trace!(remote=%node_id.fmt_short(), ok=res.is_ok(), "tick: dialed");
419421
self.handle_remote_connection(node_id, res, Direction::Dial).await;
420422
true
@@ -464,14 +466,14 @@ impl Actor {
464466

465467
#[cfg(test)]
466468
fn endpoint(&self) -> &Endpoint {
467-
self.dialer.endpoint()
469+
&self.endpoint
468470
}
469471

470472
fn drain_pending_dials(
471473
&mut self,
472474
remote: &NodeId,
473475
) -> impl Iterator<Item = (TopicId, &TopicHandle)> {
474-
self.topic_pending_remotes
476+
self.pending_remotes_with_topics
475477
.remove(remote)
476478
.into_iter()
477479
.flatten()
@@ -485,13 +487,15 @@ impl Actor {
485487
if let Some(state) = self.remotes.get(&remote) {
486488
let tx = handle.tx.clone();
487489
let state = state.clone();
490+
// TODO: Track task?
488491
task::spawn(async move {
489492
let msg = state.open_topic(topic_id).await;
490493
tx.send(msg).await.ok();
491494
});
492495
} else {
493-
self.dialer.queue_dial(remote, self.alpn.clone());
494-
self.topic_pending_remotes
496+
self.dialer
497+
.queue_dial(&self.endpoint, remote, self.alpn.clone());
498+
self.pending_remotes_with_topics
495499
.entry(remote)
496500
.or_default()
497501
.insert(topic_id);

src/net/dialer.rs

Lines changed: 22 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,80 +1,46 @@
1-
use std::collections::HashMap;
1+
use std::collections::HashSet;
22

33
use bytes::Bytes;
44
use iroh::{
55
endpoint::{ConnectError, Connection},
66
Endpoint, NodeId,
77
};
88
use tokio::task::JoinSet;
9-
use tracing::{error, Instrument};
9+
use tracing::Instrument;
1010

11-
#[derive(Debug)]
11+
#[derive(Debug, Default)]
1212
pub(crate) struct Dialer {
13-
endpoint: Endpoint,
14-
pending: JoinSet<(NodeId, Result<Connection, ConnectError>)>,
15-
pending_dials: HashMap<NodeId, ()>,
13+
pending_tasks: JoinSet<(NodeId, Result<Connection, ConnectError>)>,
14+
pending_nodes: HashSet<NodeId>,
1615
}
1716

1817
impl Dialer {
19-
/// Create a new dialer for a [`Endpoint`]
20-
pub(crate) fn new(endpoint: Endpoint) -> Self {
21-
Self {
22-
endpoint,
23-
pending: Default::default(),
24-
pending_dials: Default::default(),
25-
}
26-
}
27-
28-
#[cfg(test)]
29-
pub(crate) fn endpoint(&self) -> &Endpoint {
30-
&self.endpoint
31-
}
32-
3318
/// Starts to dial a node by [`NodeId`].
34-
pub(crate) fn queue_dial(&mut self, node_id: NodeId, alpn: Bytes) {
35-
if self.is_pending(node_id) {
36-
return;
19+
pub(crate) fn queue_dial(&mut self, endpoint: &Endpoint, node_id: NodeId, alpn: Bytes) {
20+
if self.pending_nodes.insert(node_id) {
21+
let endpoint = endpoint.clone();
22+
let fut = async move { (node_id, endpoint.connect(node_id, &alpn).await) }
23+
.instrument(tracing::Span::current());
24+
self.pending_tasks.spawn(fut);
3725
}
38-
self.pending_dials.insert(node_id, ());
39-
let endpoint = self.endpoint.clone();
40-
self.pending.spawn(
41-
async move {
42-
let res = endpoint.connect(node_id, &alpn).await;
43-
(node_id, res)
44-
}
45-
.instrument(tracing::Span::current()),
46-
);
4726
}
4827

49-
/// Checks if a node is currently being dialed.
50-
pub(crate) fn is_pending(&self, node: NodeId) -> bool {
51-
self.pending_dials.contains_key(&node)
28+
pub(crate) fn is_empty(&self) -> bool {
29+
self.pending_tasks.is_empty()
5230
}
5331

5432
/// Waits for the next dial operation to complete.
5533
/// `None` means disconnected
56-
pub(crate) async fn next_conn(&mut self) -> (NodeId, Result<Connection, ConnectError>) {
57-
match self.pending_dials.is_empty() {
58-
false => {
59-
let (node_id, res) = loop {
60-
match self.pending.join_next().await {
61-
Some(Ok((node_id, res))) => {
62-
self.pending_dials.remove(&node_id);
63-
break (node_id, res);
64-
}
65-
Some(Err(e)) => {
66-
error!("next conn error: {:?}", e);
67-
}
68-
None => {
69-
error!("no more pending conns available");
70-
std::future::pending().await
71-
}
72-
}
73-
};
74-
75-
(node_id, res)
34+
///
35+
/// Will be pending forever if no connections are in progress.
36+
pub(crate) async fn next(&mut self) -> Option<(NodeId, Result<Connection, ConnectError>)> {
37+
match self.pending_tasks.join_next().await {
38+
Some(res) => {
39+
let (node_id, res) = res.expect("connect task panicked");
40+
self.pending_nodes.remove(&node_id);
41+
Some((node_id, res))
7642
}
77-
true => std::future::pending().await,
43+
None => None,
7844
}
7945
}
8046
}

0 commit comments

Comments
 (0)