Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions client/http-client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -519,14 +519,16 @@ where
> + Send
+ Sync,
{
type SubscriptionClient = Self;

/// Send a subscription request to the server. Not implemented for HTTP; will always return
/// [`Error::HttpNotImplemented`].
fn subscribe<'a, N, Params>(
&self,
_subscribe_method: &'a str,
_params: Params,
_unsubscribe_method: &'a str,
) -> impl Future<Output = Result<Subscription<N>, Error>>
) -> impl Future<Output = Result<Subscription<Self::SubscriptionClient, N>, Error>>
where
Params: ToRpcParams + Send,
N: DeserializeOwned,
Expand All @@ -535,7 +537,10 @@ where
}

/// Subscribe to a specific method. Not implemented for HTTP; will always return [`Error::HttpNotImplemented`].
fn subscribe_to_method<N>(&self, _method: &str) -> impl Future<Output = Result<Subscription<N>, Error>>
fn subscribe_to_method<N>(
&self,
_method: &str,
) -> impl Future<Output = Result<Subscription<Self::SubscriptionClient, N>, Error>>
where
N: DeserializeOwned,
{
Expand Down
14 changes: 7 additions & 7 deletions client/ws-client/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@
let uri = to_ws_uri_string(server.local_addr());
let client = WsClientBuilder::default().build(&uri).with_default_timeout().await.unwrap().unwrap();
{
let mut sub: Subscription<String> = client
let mut sub: Subscription<_, String> = client
.subscribe("subscribe_hello", rpc_params![], "unsubscribe_hello")
.with_default_timeout()
.await
Expand All @@ -183,7 +183,7 @@
let uri = to_ws_uri_string(server.local_addr());
let client = WsClientBuilder::default().build(&uri).with_default_timeout().await.unwrap().unwrap();
{
let mut nh: Subscription<String> =
let mut nh: Subscription<_, String> =
client.subscribe_to_method("test").with_default_timeout().await.unwrap().unwrap();
let response: String = nh.next().with_default_timeout().await.unwrap().unwrap().unwrap();
assert_eq!("server originated notification works".to_owned(), response);
Expand All @@ -203,7 +203,7 @@
let uri = to_ws_uri_string(server.local_addr());
let client = WsClientBuilder::default().build(&uri).with_default_timeout().await.unwrap().unwrap();
{
let mut nh: Subscription<serde_json::Value> =
let mut nh: Subscription<_, serde_json::Value> =
client.subscribe_to_method("no_params").with_default_timeout().await.unwrap().unwrap();
let response = nh.next().with_default_timeout().await.unwrap().unwrap().unwrap();
assert_eq!(response, serde_json::Value::Null);
Expand Down Expand Up @@ -244,15 +244,15 @@
// Ensure that subscription is returned back to the correct handle
// and is handled separately from ordinary notifications.
{
let mut nh: Subscription<String> =
let mut nh: Subscription<_, String> =
client.subscribe("sub", rpc_params![], "unsub").with_default_timeout().await.unwrap().unwrap();
let response: String = nh.next().with_default_timeout().await.unwrap().unwrap().unwrap();
assert_eq!("sub_notif", response);
}

// Ensure that method notif is returned back to the correct handle.
{
let mut nh: Subscription<String> =
let mut nh: Subscription<_, String> =
client.subscribe_to_method("sub").with_default_timeout().await.unwrap().unwrap();
let response: String = nh.next().with_default_timeout().await.unwrap().unwrap().unwrap();
assert_eq!("method_notif", response);
Expand All @@ -279,7 +279,7 @@
.await
.unwrap()
.unwrap();
let mut nh: Subscription<String> =
let mut nh: Subscription<_, String> =
client.subscribe_to_method("test").with_default_timeout().await.unwrap().unwrap();

// Don't poll the notification stream for 2 seconds, should be full now.
Expand All @@ -297,7 +297,7 @@
assert!(nh.next().with_default_timeout().await.unwrap().is_none());

// The same subscription should be possible to register again.
let mut other_nh: Subscription<String> =
let mut other_nh: Subscription<_, String> =
client.subscribe_to_method("test").with_default_timeout().await.unwrap().unwrap();

// check that the new subscription works.
Expand Down Expand Up @@ -440,7 +440,7 @@
}

async fn run_batch_request_with_response<T: Send + DeserializeOwned + std::fmt::Debug + Clone + 'static>(
batch: BatchRequestBuilder<'_>,

Check warning on line 443 in client/ws-client/src/tests.rs

View workflow job for this annotation

GitHub Actions / Check style

hiding a lifetime that's elided elsewhere is confusing
response: String,
) -> Result<BatchResponse<T>, Error> {
let server = WebSocketTestServer::with_hardcoded_response("127.0.0.1:0".parse().unwrap(), response)
Expand Down
79 changes: 49 additions & 30 deletions core/src/client/async_client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -360,14 +360,16 @@ impl<L> ClientBuilder<L> {

tokio::spawn(wait_for_shutdown(send_receive_task_sync_rx, client_dropped_rx, disconnect_reason.clone()));

Client {
let inner = ClientInner {
to_back: to_back.clone(),
service: self.service_builder.service(RpcService::new(to_back.clone())),
request_timeout: self.request_timeout,
error: ErrorFromBack::new(to_back, disconnect_reason),
id_manager: RequestIdManager::new(self.id_kind),
on_exit: Some(client_dropped_tx),
}
};

Client { inner: Arc::new(inner) }
}

/// Build the client with given transport.
Expand Down Expand Up @@ -419,20 +421,21 @@ impl<L> ClientBuilder<L> {
disconnect_reason.clone(),
));

Client {
let inner = ClientInner {
to_back: to_back.clone(),
service: self.service_builder.service(RpcService::new(to_back.clone())),
request_timeout: self.request_timeout,
error: ErrorFromBack::new(to_back, disconnect_reason),
id_manager: RequestIdManager::new(self.id_kind),
on_exit: Some(client_dropped_tx),
}
};

Client { inner: Arc::new(inner) }
}
}

/// Generic asynchronous client.
#[derive(Debug)]
pub struct Client<L = RpcLogger<RpcService>> {
struct ClientInner<L = RpcLogger<RpcService>> {
/// Channel to send requests to the background task.
to_back: mpsc::Sender<FrontToBack>,
error: ErrorFromBack,
Expand All @@ -445,6 +448,21 @@ pub struct Client<L = RpcLogger<RpcService>> {
service: L,
}

impl<L> Drop for ClientInner<L> {
fn drop(&mut self) {
if let Some(e) = self.on_exit.take() {
let _ = e.send(());
}
}
}

/// Generic asynchronous client.
#[derive(Debug)]
#[repr(transparent)]
pub struct Client<L = RpcLogger<RpcService>> {
inner: Arc<ClientInner<L>>
}

impl Client<Identity> {
/// Create a builder for the client.
pub fn builder() -> ClientBuilder {
Expand All @@ -455,13 +473,13 @@ impl Client<Identity> {
impl<L> Client<L> {
/// Checks if the client is connected to the target.
pub fn is_connected(&self) -> bool {
!self.to_back.is_closed()
!self.inner.to_back.is_closed()
}

async fn run_future_until_timeout<T>(&self, fut: impl Future<Output = Result<T, Error>>) -> Result<T, Error> {
tokio::pin!(fut);

match futures_util::future::select(fut, futures_timer::Delay::new(self.request_timeout)).await {
match futures_util::future::select(fut, futures_timer::Delay::new(self.inner.request_timeout)).await {
Either::Left((Ok(r), _)) => Ok(r),
Either::Left((Err(Error::ServiceDisconnect), _)) => Err(self.on_disconnect().await),
Either::Left((Err(e), _)) => Err(e),
Expand All @@ -476,20 +494,18 @@ impl<L> Client<L> {
///
/// This method is cancel safe.
pub async fn on_disconnect(&self) -> Error {
self.error.read_error().await
self.inner.error.read_error().await
}

/// Returns configured request timeout.
pub fn request_timeout(&self) -> Duration {
self.request_timeout
self.inner.request_timeout
}
}

impl<L> Drop for Client<L> {
fn drop(&mut self) {
if let Some(e) = self.on_exit.take() {
let _ = e.send(());
}
impl<L> Clone for Client<L> {
fn clone(&self) -> Self {
Self { inner: self.inner.clone() }
}
}

Expand All @@ -508,9 +524,9 @@ where
{
async {
// NOTE: we use this to guard against max number of concurrent requests.
let _req_id = self.id_manager.next_request_id();
let _req_id = self.inner.id_manager.next_request_id();
let params = params.to_rpc_params()?.map(StdCow::Owned);
let fut = self.service.notification(jsonrpsee_types::Notification::new(method.into(), params));
let fut = self.inner.service.notification(jsonrpsee_types::Notification::new(method.into(), params));
self.run_future_until_timeout(fut).await?;
Ok(())
}
Expand All @@ -522,9 +538,9 @@ where
Params: ToRpcParams + Send,
{
async {
let id = self.id_manager.next_request_id();
let id = self.inner.id_manager.next_request_id();
let params = params.to_rpc_params()?;
let fut = self.service.call(Request::borrowed(method, params.as_deref(), id.clone()));
let fut = self.inner.service.call(Request::borrowed(method, params.as_deref(), id.clone()));
let rp = self.run_future_until_timeout(fut).await?;
let success = ResponseSuccess::try_from(rp.into_response().into_inner())?;

Expand All @@ -541,15 +557,15 @@ where
{
async {
let batch = batch.build()?;
let id = self.id_manager.next_request_id();
let id = self.inner.id_manager.next_request_id();
let id_range = generate_batch_id_range(id, batch.len() as u64)?;

let mut b = Batch::with_capacity(batch.len());

for ((method, params), id) in batch.into_iter().zip(id_range.clone()) {
b.push(Request {
jsonrpc: TwoPointZero,
id: self.id_manager.as_id_kind().into_id(id),
id: self.inner.id_manager.as_id_kind().into_id(id),
method: method.into(),
params: params.map(StdCow::Owned),
extensions: Extensions::new(),
Expand All @@ -558,7 +574,7 @@ where

b.extensions_mut().insert(IsBatch { id_range });

let fut = self.service.batch(b);
let fut = self.inner.service.batch(b);
let json_values = self.run_future_until_timeout(fut).await?;

let mut responses = Vec::with_capacity(json_values.len());
Expand Down Expand Up @@ -592,6 +608,8 @@ where
> + Send
+ Sync,
{
type SubscriptionClient = Self;

/// Send a subscription request to the server.
///
/// The `subscribe_method` and `params` are used to ask for the subscription towards the
Expand All @@ -601,7 +619,7 @@ where
subscribe_method: &'a str,
params: Params,
unsubscribe_method: &'a str,
) -> impl Future<Output = Result<Subscription<Notif>, Error>> + Send
) -> impl Future<Output = Result<Subscription<Self::SubscriptionClient, Notif>, Error>> + Send
where
Params: ToRpcParams + Send,
Notif: DeserializeOwned,
Expand All @@ -611,8 +629,8 @@ where
return Err(RegisterMethodError::SubscriptionNameConflict(unsubscribe_method.to_owned()).into());
}

let req_id_sub = self.id_manager.next_request_id();
let req_id_unsub = self.id_manager.next_request_id();
let req_id_sub = self.inner.id_manager.next_request_id();
let req_id_unsub = self.inner.id_manager.next_request_id();
let params = params.to_rpc_params()?;

let mut ext = Extensions::new();
Expand All @@ -626,24 +644,25 @@ where
extensions: ext,
};

let fut = self.service.call(req);
let fut = self.inner.service.call(req);
let sub = self
.run_future_until_timeout(fut)
.await?
.into_subscription()
.expect("Extensions set to subscription, must return subscription; qed");
Ok(Subscription::new(self.to_back.clone(), sub.stream, SubscriptionKind::Subscription(sub.sub_id)))
Ok(Subscription::new(self.clone(), self.inner.to_back.clone(), sub.stream, SubscriptionKind::Subscription(sub.sub_id)))
}
}

/// Subscribe to a specific method.
fn subscribe_to_method<N>(&self, method: &str) -> impl Future<Output = Result<Subscription<N>, Error>> + Send
fn subscribe_to_method<N>(&self, method: &str) -> impl Future<Output = Result<Subscription<Self::SubscriptionClient, N>, Error>> + Send
where
N: DeserializeOwned,
{
async {
let (send_back_tx, send_back_rx) = oneshot::channel();
if self
.inner
.to_back
.clone()
.send(FrontToBack::RegisterNotification(RegisterNotificationMessage {
Expand All @@ -656,15 +675,15 @@ where
return Err(self.on_disconnect().await);
}

let res = call_with_timeout(self.request_timeout, send_back_rx).await;
let res = call_with_timeout(self.inner.request_timeout, send_back_rx).await;

let (rx, method) = match res {
Ok(Ok(val)) => val,
Ok(Err(err)) => return Err(err),
Err(_) => return Err(self.on_disconnect().await),
};

Ok(Subscription::new(self.to_back.clone(), rx, SubscriptionKind::Method(method)))
Ok(Subscription::new(self.clone(), self.inner.to_back.clone(), rx, SubscriptionKind::Method(method)))
}
}
}
Expand Down
Loading
Loading