|
1 | | -use std::collections::HashMap; |
| 1 | +use std::collections::HashSet; |
2 | 2 |
|
3 | 3 | use bytes::Bytes; |
4 | 4 | use iroh::{ |
5 | 5 | endpoint::{ConnectError, Connection}, |
6 | 6 | Endpoint, NodeId, |
7 | 7 | }; |
8 | 8 | use tokio::task::JoinSet; |
9 | | -use tracing::{error, Instrument}; |
| 9 | +use tracing::Instrument; |
10 | 10 |
|
11 | | -#[derive(Debug)] |
| 11 | +#[derive(Debug, Default)] |
12 | 12 | pub(crate) struct Dialer { |
13 | | - endpoint: Endpoint, |
14 | | - pending: JoinSet<(NodeId, Result<Connection, ConnectError>)>, |
15 | | - pending_dials: HashMap<NodeId, ()>, |
| 13 | + pending_tasks: JoinSet<(NodeId, Result<Connection, ConnectError>)>, |
| 14 | + pending_nodes: HashSet<NodeId>, |
16 | 15 | } |
17 | 16 |
|
18 | 17 | impl Dialer { |
19 | | - /// Create a new dialer for a [`Endpoint`] |
20 | | - pub(crate) fn new(endpoint: Endpoint) -> Self { |
21 | | - Self { |
22 | | - endpoint, |
23 | | - pending: Default::default(), |
24 | | - pending_dials: Default::default(), |
25 | | - } |
26 | | - } |
27 | | - |
28 | | - #[cfg(test)] |
29 | | - pub(crate) fn endpoint(&self) -> &Endpoint { |
30 | | - &self.endpoint |
31 | | - } |
32 | | - |
33 | 18 | /// Starts to dial a node by [`NodeId`]. |
34 | | - pub(crate) fn queue_dial(&mut self, node_id: NodeId, alpn: Bytes) { |
35 | | - if self.is_pending(node_id) { |
36 | | - return; |
| 19 | + pub(crate) fn queue_dial(&mut self, endpoint: &Endpoint, node_id: NodeId, alpn: Bytes) { |
| 20 | + if self.pending_nodes.insert(node_id) { |
| 21 | + let endpoint = endpoint.clone(); |
| 22 | + let fut = async move { (node_id, endpoint.connect(node_id, &alpn).await) } |
| 23 | + .instrument(tracing::Span::current()); |
| 24 | + self.pending_tasks.spawn(fut); |
37 | 25 | } |
38 | | - self.pending_dials.insert(node_id, ()); |
39 | | - let endpoint = self.endpoint.clone(); |
40 | | - self.pending.spawn( |
41 | | - async move { |
42 | | - let res = endpoint.connect(node_id, &alpn).await; |
43 | | - (node_id, res) |
44 | | - } |
45 | | - .instrument(tracing::Span::current()), |
46 | | - ); |
47 | 26 | } |
48 | 27 |
|
49 | | - /// Checks if a node is currently being dialed. |
50 | | - pub(crate) fn is_pending(&self, node: NodeId) -> bool { |
51 | | - self.pending_dials.contains_key(&node) |
| 28 | + pub(crate) fn is_empty(&self) -> bool { |
| 29 | + self.pending_tasks.is_empty() |
52 | 30 | } |
53 | 31 |
|
54 | 32 | /// Waits for the next dial operation to complete. |
55 | 33 | /// `None` means disconnected |
56 | | - pub(crate) async fn next_conn(&mut self) -> (NodeId, Result<Connection, ConnectError>) { |
57 | | - match self.pending_dials.is_empty() { |
58 | | - false => { |
59 | | - let (node_id, res) = loop { |
60 | | - match self.pending.join_next().await { |
61 | | - Some(Ok((node_id, res))) => { |
62 | | - self.pending_dials.remove(&node_id); |
63 | | - break (node_id, res); |
64 | | - } |
65 | | - Some(Err(e)) => { |
66 | | - error!("next conn error: {:?}", e); |
67 | | - } |
68 | | - None => { |
69 | | - error!("no more pending conns available"); |
70 | | - std::future::pending().await |
71 | | - } |
72 | | - } |
73 | | - }; |
74 | | - |
75 | | - (node_id, res) |
| 34 | + /// |
| 35 | + /// Will be pending forever if no connections are in progress. |
| 36 | + pub(crate) async fn next(&mut self) -> Option<(NodeId, Result<Connection, ConnectError>)> { |
| 37 | + match self.pending_tasks.join_next().await { |
| 38 | + Some(res) => { |
| 39 | + let (node_id, res) = res.expect("connect task panicked"); |
| 40 | + self.pending_nodes.remove(&node_id); |
| 41 | + Some((node_id, res)) |
76 | 42 | } |
77 | | - true => std::future::pending().await, |
| 43 | + None => None, |
78 | 44 | } |
79 | 45 | } |
80 | 46 | } |
0 commit comments