Skip to content

Commit d36b8c3

Browse files
committed
Subscriptions keep clients alive until dropped
1 parent ffd5020 commit d36b8c3

File tree

10 files changed

+108
-79
lines changed

10 files changed

+108
-79
lines changed

client/http-client/src/client.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -483,14 +483,16 @@ where
483483
> + Send
484484
+ Sync,
485485
{
486+
type SubscriptionClient = Self;
487+
486488
/// Send a subscription request to the server. Not implemented for HTTP; will always return
487489
/// [`Error::HttpNotImplemented`].
488490
fn subscribe<'a, N, Params>(
489491
&self,
490492
_subscribe_method: &'a str,
491493
_params: Params,
492494
_unsubscribe_method: &'a str,
493-
) -> impl Future<Output = Result<Subscription<N>, Error>>
495+
) -> impl Future<Output = Result<Subscription<Self::SubscriptionClient, N>, Error>>
494496
where
495497
Params: ToRpcParams + Send,
496498
N: DeserializeOwned,
@@ -499,7 +501,10 @@ where
499501
}
500502

501503
/// Subscribe to a specific method. Not implemented for HTTP; will always return [`Error::HttpNotImplemented`].
502-
fn subscribe_to_method<N>(&self, _method: &str) -> impl Future<Output = Result<Subscription<N>, Error>>
504+
fn subscribe_to_method<N>(
505+
&self,
506+
_method: &str,
507+
) -> impl Future<Output = Result<Subscription<Self::SubscriptionClient, N>, Error>>
503508
where
504509
N: DeserializeOwned,
505510
{

client/ws-client/src/tests.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ async fn subscription_works() {
159159
let uri = to_ws_uri_string(server.local_addr());
160160
let client = WsClientBuilder::default().build(&uri).with_default_timeout().await.unwrap().unwrap();
161161
{
162-
let mut sub: Subscription<String> = client
162+
let mut sub: Subscription<_, String> = client
163163
.subscribe("subscribe_hello", rpc_params![], "unsubscribe_hello")
164164
.with_default_timeout()
165165
.await
@@ -183,7 +183,7 @@ async fn notification_handler_works() {
183183
let uri = to_ws_uri_string(server.local_addr());
184184
let client = WsClientBuilder::default().build(&uri).with_default_timeout().await.unwrap().unwrap();
185185
{
186-
let mut nh: Subscription<String> =
186+
let mut nh: Subscription<_, String> =
187187
client.subscribe_to_method("test").with_default_timeout().await.unwrap().unwrap();
188188
let response: String = nh.next().with_default_timeout().await.unwrap().unwrap().unwrap();
189189
assert_eq!("server originated notification works".to_owned(), response);
@@ -203,7 +203,7 @@ async fn notification_no_params() {
203203
let uri = to_ws_uri_string(server.local_addr());
204204
let client = WsClientBuilder::default().build(&uri).with_default_timeout().await.unwrap().unwrap();
205205
{
206-
let mut nh: Subscription<serde_json::Value> =
206+
let mut nh: Subscription<_, serde_json::Value> =
207207
client.subscribe_to_method("no_params").with_default_timeout().await.unwrap().unwrap();
208208
let response = nh.next().with_default_timeout().await.unwrap().unwrap().unwrap();
209209
assert_eq!(response, serde_json::Value::Null);
@@ -244,15 +244,15 @@ async fn batched_notifs_works() {
244244
// Ensure that subscription is returned back to the correct handle
245245
// and is handled separately from ordinary notifications.
246246
{
247-
let mut nh: Subscription<String> =
247+
let mut nh: Subscription<_, String> =
248248
client.subscribe("sub", rpc_params![], "unsub").with_default_timeout().await.unwrap().unwrap();
249249
let response: String = nh.next().with_default_timeout().await.unwrap().unwrap().unwrap();
250250
assert_eq!("sub_notif", response);
251251
}
252252

253253
// Ensure that method notif is returned back to the correct handle.
254254
{
255-
let mut nh: Subscription<String> =
255+
let mut nh: Subscription<_, String> =
256256
client.subscribe_to_method("sub").with_default_timeout().await.unwrap().unwrap();
257257
let response: String = nh.next().with_default_timeout().await.unwrap().unwrap().unwrap();
258258
assert_eq!("method_notif", response);
@@ -279,7 +279,7 @@ async fn notification_close_on_lagging() {
279279
.await
280280
.unwrap()
281281
.unwrap();
282-
let mut nh: Subscription<String> =
282+
let mut nh: Subscription<_, String> =
283283
client.subscribe_to_method("test").with_default_timeout().await.unwrap().unwrap();
284284

285285
// Don't poll the notification stream for 2 seconds, should be full now.
@@ -297,7 +297,7 @@ async fn notification_close_on_lagging() {
297297
assert!(nh.next().with_default_timeout().await.unwrap().is_none());
298298

299299
// The same subscription should be possible to register again.
300-
let mut other_nh: Subscription<String> =
300+
let mut other_nh: Subscription<_, String> =
301301
client.subscribe_to_method("test").with_default_timeout().await.unwrap().unwrap();
302302

303303
// check that the new subscription works.

core/src/client/async_client/mod.rs

Lines changed: 45 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -360,14 +360,16 @@ impl<L> ClientBuilder<L> {
360360

361361
tokio::spawn(wait_for_shutdown(send_receive_task_sync_rx, client_dropped_rx, disconnect_reason.clone()));
362362

363-
Client {
363+
let inner = ClientInner {
364364
to_back: to_back.clone(),
365365
service: self.service_builder.service(RpcService::new(to_back.clone())),
366366
request_timeout: self.request_timeout,
367367
error: ErrorFromBack::new(to_back, disconnect_reason),
368368
id_manager: RequestIdManager::new(self.id_kind),
369369
on_exit: Some(client_dropped_tx),
370-
}
370+
};
371+
372+
Client { inner: Arc::new(inner) }
371373
}
372374

373375
/// Build the client with given transport.
@@ -430,9 +432,8 @@ impl<L> ClientBuilder<L> {
430432
}
431433
}
432434

433-
/// Generic asynchronous client.
434435
#[derive(Debug)]
435-
pub struct Client<L = RpcLogger<RpcService>> {
436+
struct ClientInner<L = RpcLogger<RpcService>> {
436437
/// Channel to send requests to the background task.
437438
to_back: mpsc::Sender<FrontToBack>,
438439
error: ErrorFromBack,
@@ -445,6 +446,21 @@ pub struct Client<L = RpcLogger<RpcService>> {
445446
service: L,
446447
}
447448

449+
impl<L> Drop for ClientInner<L> {
450+
fn drop(&mut self) {
451+
if let Some(e) = self.on_exit.take() {
452+
let _ = e.send(());
453+
}
454+
}
455+
}
456+
457+
/// Generic asynchronous client.
458+
#[derive(Debug)]
459+
#[repr(transparent)]
460+
pub struct Client<L = RpcLogger<RpcService>> {
461+
inner: Arc<ClientInner<L>>
462+
}
463+
448464
impl Client<Identity> {
449465
/// Create a builder for the client.
450466
pub fn builder() -> ClientBuilder {
@@ -455,13 +471,13 @@ impl Client<Identity> {
455471
impl<L> Client<L> {
456472
/// Checks if the client is connected to the target.
457473
pub fn is_connected(&self) -> bool {
458-
!self.to_back.is_closed()
474+
!self.inner.to_back.is_closed()
459475
}
460476

461477
async fn run_future_until_timeout<T>(&self, fut: impl Future<Output = Result<T, Error>>) -> Result<T, Error> {
462478
tokio::pin!(fut);
463479

464-
match futures_util::future::select(fut, futures_timer::Delay::new(self.request_timeout)).await {
480+
match futures_util::future::select(fut, futures_timer::Delay::new(self.inner.request_timeout)).await {
465481
Either::Left((Ok(r), _)) => Ok(r),
466482
Either::Left((Err(Error::ServiceDisconnect), _)) => Err(self.on_disconnect().await),
467483
Either::Left((Err(e), _)) => Err(e),
@@ -476,20 +492,18 @@ impl<L> Client<L> {
476492
///
477493
/// This method is cancel safe.
478494
pub async fn on_disconnect(&self) -> Error {
479-
self.error.read_error().await
495+
self.inner.error.read_error().await
480496
}
481497

482498
/// Returns configured request timeout.
483499
pub fn request_timeout(&self) -> Duration {
484-
self.request_timeout
500+
self.inner.request_timeout
485501
}
486502
}
487503

488-
impl<L> Drop for Client<L> {
489-
fn drop(&mut self) {
490-
if let Some(e) = self.on_exit.take() {
491-
let _ = e.send(());
492-
}
504+
impl<L> Clone for Client<L> {
505+
fn clone(&self) -> Self {
506+
Self { inner: self.inner.clone() }
493507
}
494508
}
495509

@@ -508,9 +522,9 @@ where
508522
{
509523
async {
510524
// NOTE: we use this to guard against max number of concurrent requests.
511-
let _req_id = self.id_manager.next_request_id();
525+
let _req_id = self.inner.id_manager.next_request_id();
512526
let params = params.to_rpc_params()?.map(StdCow::Owned);
513-
let fut = self.service.notification(jsonrpsee_types::Notification::new(method.into(), params));
527+
let fut = self.inner.service.notification(jsonrpsee_types::Notification::new(method.into(), params));
514528
self.run_future_until_timeout(fut).await?;
515529
Ok(())
516530
}
@@ -522,9 +536,9 @@ where
522536
Params: ToRpcParams + Send,
523537
{
524538
async {
525-
let id = self.id_manager.next_request_id();
539+
let id = self.inner.id_manager.next_request_id();
526540
let params = params.to_rpc_params()?;
527-
let fut = self.service.call(Request::borrowed(method, params.as_deref(), id.clone()));
541+
let fut = self.inner.service.call(Request::borrowed(method, params.as_deref(), id.clone()));
528542
let rp = self.run_future_until_timeout(fut).await?;
529543
let success = ResponseSuccess::try_from(rp.into_response().into_inner())?;
530544

@@ -541,15 +555,15 @@ where
541555
{
542556
async {
543557
let batch = batch.build()?;
544-
let id = self.id_manager.next_request_id();
558+
let id = self.inner.id_manager.next_request_id();
545559
let id_range = generate_batch_id_range(id, batch.len() as u64)?;
546560

547561
let mut b = Batch::with_capacity(batch.len());
548562

549563
for ((method, params), id) in batch.into_iter().zip(id_range.clone()) {
550564
b.push(Request {
551565
jsonrpc: TwoPointZero,
552-
id: self.id_manager.as_id_kind().into_id(id),
566+
id: self.inner.id_manager.as_id_kind().into_id(id),
553567
method: method.into(),
554568
params: params.map(StdCow::Owned),
555569
extensions: Extensions::new(),
@@ -558,7 +572,7 @@ where
558572

559573
b.extensions_mut().insert(IsBatch { id_range });
560574

561-
let fut = self.service.batch(b);
575+
let fut = self.inner.service.batch(b);
562576
let json_values = self.run_future_until_timeout(fut).await?;
563577

564578
let mut responses = Vec::with_capacity(json_values.len());
@@ -592,6 +606,8 @@ where
592606
> + Send
593607
+ Sync,
594608
{
609+
type SubscriptionClient = Self;
610+
595611
/// Send a subscription request to the server.
596612
///
597613
/// The `subscribe_method` and `params` are used to ask for the subscription towards the
@@ -601,7 +617,7 @@ where
601617
subscribe_method: &'a str,
602618
params: Params,
603619
unsubscribe_method: &'a str,
604-
) -> impl Future<Output = Result<Subscription<Notif>, Error>> + Send
620+
) -> impl Future<Output = Result<Subscription<Self::SubscriptionClient, Notif>, Error>> + Send
605621
where
606622
Params: ToRpcParams + Send,
607623
Notif: DeserializeOwned,
@@ -611,8 +627,8 @@ where
611627
return Err(RegisterMethodError::SubscriptionNameConflict(unsubscribe_method.to_owned()).into());
612628
}
613629

614-
let req_id_sub = self.id_manager.next_request_id();
615-
let req_id_unsub = self.id_manager.next_request_id();
630+
let req_id_sub = self.inner.id_manager.next_request_id();
631+
let req_id_unsub = self.inner.id_manager.next_request_id();
616632
let params = params.to_rpc_params()?;
617633

618634
let mut ext = Extensions::new();
@@ -626,24 +642,25 @@ where
626642
extensions: ext,
627643
};
628644

629-
let fut = self.service.call(req);
645+
let fut = self.inner.service.call(req);
630646
let sub = self
631647
.run_future_until_timeout(fut)
632648
.await?
633649
.into_subscription()
634650
.expect("Extensions set to subscription, must return subscription; qed");
635-
Ok(Subscription::new(self.to_back.clone(), sub.stream, SubscriptionKind::Subscription(sub.sub_id)))
651+
Ok(Subscription::new(self.clone(), self.inner.to_back.clone(), sub.stream, SubscriptionKind::Subscription(sub.sub_id)))
636652
}
637653
}
638654

639655
/// Subscribe to a specific method.
640-
fn subscribe_to_method<N>(&self, method: &str) -> impl Future<Output = Result<Subscription<N>, Error>> + Send
656+
fn subscribe_to_method<N>(&self, method: &str) -> impl Future<Output = Result<Subscription<Self::SubscriptionClient, N>, Error>> + Send
641657
where
642658
N: DeserializeOwned,
643659
{
644660
async {
645661
let (send_back_tx, send_back_rx) = oneshot::channel();
646662
if self
663+
.inner
647664
.to_back
648665
.clone()
649666
.send(FrontToBack::RegisterNotification(RegisterNotificationMessage {
@@ -656,15 +673,15 @@ where
656673
return Err(self.on_disconnect().await);
657674
}
658675

659-
let res = call_with_timeout(self.request_timeout, send_back_rx).await;
676+
let res = call_with_timeout(self.inner.request_timeout, send_back_rx).await;
660677

661678
let (rx, method) = match res {
662679
Ok(Ok(val)) => val,
663680
Ok(Err(err)) => return Err(err),
664681
Err(_) => return Err(self.on_disconnect().await),
665682
};
666683

667-
Ok(Subscription::new(self.to_back.clone(), rx, SubscriptionKind::Method(method)))
684+
Ok(Subscription::new(self.clone(), self.inner.to_back.clone(), rx, SubscriptionKind::Method(method)))
668685
}
669686
}
670687
}

core/src/client/mod.rs

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,8 @@ pub trait ClientT {
119119

120120
/// [JSON-RPC](https://www.jsonrpc.org/specification) client interface that can make requests, notifications and subscriptions.
121121
pub trait SubscriptionClientT: ClientT {
122+
type SubscriptionClient;
123+
122124
/// Initiate a subscription by performing a JSON-RPC method call where the server responds with
123125
/// a `Subscription ID` that is used to fetch messages on that subscription,
124126
///
@@ -136,7 +138,7 @@ pub trait SubscriptionClientT: ClientT {
136138
subscribe_method: &'a str,
137139
params: Params,
138140
unsubscribe_method: &'a str,
139-
) -> impl Future<Output = Result<Subscription<Notif>, Error>> + Send
141+
) -> impl Future<Output = Result<Subscription<Self::SubscriptionClient, Notif>, Error>> + Send
140142
where
141143
Params: ToRpcParams + Send,
142144
Notif: DeserializeOwned;
@@ -148,7 +150,7 @@ pub trait SubscriptionClientT: ClientT {
148150
fn subscribe_to_method<Notif>(
149151
&self,
150152
method: &str,
151-
) -> impl Future<Output = Result<Subscription<Notif>, Error>> + Send
153+
) -> impl Future<Output = Result<Subscription<Self::SubscriptionClient, Notif>, Error>> + Send
152154
where
153155
Notif: DeserializeOwned;
154156
}
@@ -272,7 +274,7 @@ pub enum SubscriptionCloseReason {
272274
/// You can call [`Subscription::close_reason`] to determine why
273275
/// the subscription was closed.
274276
#[derive(Debug)]
275-
pub struct Subscription<Notif> {
277+
pub struct Subscription<Client, Notif> {
276278
is_closed: bool,
277279
/// Channel to send requests to the background task.
278280
to_back: mpsc::Sender<FrontToBack>,
@@ -282,16 +284,18 @@ pub struct Subscription<Notif> {
282284
kind: Option<SubscriptionKind>,
283285
/// Marker in order to pin the `Notif` parameter.
284286
marker: PhantomData<Notif>,
287+
/// Keep client alive at least until subscription is dropped
288+
_client: Client,
285289
}
286290

287291
// `Subscription` does not automatically implement this due to `PhantomData<Notif>`,
288292
// but type type has no need to be pinned.
289-
impl<Notif> std::marker::Unpin for Subscription<Notif> {}
293+
impl<Client, Notif> std::marker::Unpin for Subscription<Client, Notif> {}
290294

291-
impl<Notif> Subscription<Notif> {
295+
impl<Client, Notif> Subscription<Client, Notif> {
292296
/// Create a new subscription.
293-
fn new(to_back: mpsc::Sender<FrontToBack>, rx: SubscriptionReceiver, kind: SubscriptionKind) -> Self {
294-
Self { to_back, rx, kind: Some(kind), marker: PhantomData, is_closed: false }
297+
fn new(client: Client, to_back: mpsc::Sender<FrontToBack>, rx: SubscriptionReceiver, kind: SubscriptionKind) -> Self {
298+
Self { _client: client, to_back, rx, kind: Some(kind), marker: PhantomData, is_closed: false }
295299
}
296300

297301
/// Return the subscription type and, if applicable, ID.
@@ -404,7 +408,7 @@ enum FrontToBack {
404408
SubscriptionClosed(SubscriptionId<'static>),
405409
}
406410

407-
impl<Notif> Subscription<Notif>
411+
impl<Client, Notif> Subscription<Client, Notif>
408412
where
409413
Notif: DeserializeOwned,
410414
{
@@ -421,7 +425,7 @@ where
421425
}
422426
}
423427

424-
impl<Notif> Stream for Subscription<Notif>
428+
impl<Client, Notif> Stream for Subscription<Client, Notif>
425429
where
426430
Notif: DeserializeOwned,
427431
{
@@ -439,7 +443,7 @@ where
439443
}
440444
}
441445

442-
impl<Notif> Drop for Subscription<Notif> {
446+
impl<Client, Notif> Drop for Subscription<Client, Notif> {
443447
fn drop(&mut self) {
444448
// We can't actually guarantee that this goes through. If the background task is busy, then
445449
// the channel's buffer will be full.

0 commit comments

Comments
 (0)