@@ -20,7 +20,10 @@ use futures::{
20
20
task:: { Context , Poll } ,
21
21
Future , FutureExt , Sink , SinkExt , Stream , StreamExt ,
22
22
} ;
23
+ #[ cfg( any( feature = "tokio-runtime" , feature = "async-std-runtime" ) ) ]
23
24
use native_tls:: Certificate ;
25
+ #[ cfg( all( any( feature = "tokio-rustls-runtime" , feature = "async-std--rustls-runtime" ) , not( any( feature = "tokio-runtime" , feature = "async-std-runtime" ) ) ) ) ]
26
+ use rustls:: Certificate ;
24
27
use proto:: MessageIdData ;
25
28
use rand:: { seq:: SliceRandom , thread_rng} ;
26
29
use url:: Url ;
@@ -934,11 +937,58 @@ impl<Exe: Executor> Connection<Exe> {
934
937
. await
935
938
}
936
939
}
937
- #[ cfg( not( feature = "tokio-runtime" ) ) ]
940
+ #[ cfg( all( feature = "tokio-rustls-runtime" , not( feature = "tokio-runtime" ) ) ) ]
941
+ ExecutorKind :: Tokio => {
942
+ if tls {
943
+ let stream = tokio:: net:: TcpStream :: connect ( & address) . await ?;
944
+ let mut root_store = rustls:: RootCertStore :: empty ( ) ;
945
+ for certificate in certificate_chain {
946
+ root_store. add ( certificate) ?;
947
+ }
948
+
949
+ let config = rustls:: ClientConfig :: builder ( )
950
+ . with_safe_default_cipher_suites ( )
951
+ . with_safe_default_kx_groups ( )
952
+ . with_safe_default_protocol_versions ( ) ?
953
+ . with_root_certificates ( root_store)
954
+ . with_no_client_auth ( ) ;
955
+
956
+ let cx = tokio_rustls:: TlsConnector :: from ( Arc :: new ( config) ) ;
957
+ let stream = cx
958
+ . connect ( rustls:: ServerName :: try_from ( hostname. as_str ( ) ) ?, stream)
959
+ . await
960
+ . map ( |stream| tokio_util:: codec:: Framed :: new ( stream, Codec ) ) ?;
961
+
962
+ Connection :: connect (
963
+ connection_id,
964
+ stream,
965
+ auth,
966
+ proxy_to_broker_url,
967
+ executor,
968
+ operation_timeout,
969
+ )
970
+ . await
971
+ } else {
972
+ let stream = tokio:: net:: TcpStream :: connect ( & address)
973
+ . await
974
+ . map ( |stream| tokio_util:: codec:: Framed :: new ( stream, Codec ) ) ?;
975
+
976
+ Connection :: connect (
977
+ connection_id,
978
+ stream,
979
+ auth,
980
+ proxy_to_broker_url,
981
+ executor,
982
+ operation_timeout,
983
+ )
984
+ . await
985
+ }
986
+ }
987
+ #[ cfg( all( not( feature = "tokio-runtime" ) , not( feature = "tokio-rustls-runtime" ) ) ) ]
938
988
ExecutorKind :: Tokio => {
939
989
unimplemented ! ( "the tokio-runtime cargo feature is not active" ) ;
940
990
}
941
- #[ cfg( feature = "async-std-runtime" ) ]
991
+ #[ cfg( feature = "async-std-runtime" ) ]
942
992
ExecutorKind :: AsyncStd => {
943
993
if tls {
944
994
let stream = async_std:: net:: TcpStream :: connect ( & address) . await ?;
@@ -980,7 +1030,54 @@ impl<Exe: Executor> Connection<Exe> {
980
1030
. await
981
1031
}
982
1032
}
983
- #[ cfg( not( feature = "async-std-runtime" ) ) ]
1033
+ #[ cfg( all( feature = "async-std-rustls-runtime" , not( feature = "async-std-runtime" ) ) ) ]
1034
+ ExecutorKind :: AsyncStd => {
1035
+ if tls {
1036
+ let stream = async_std:: net:: TcpStream :: connect ( & address) . await ?;
1037
+ let mut root_store = rustls:: RootCertStore :: empty ( ) ;
1038
+ for certificate in certificate_chain {
1039
+ root_store. add ( certificate) ?;
1040
+ }
1041
+
1042
+ let config = rustls:: ClientConfig :: builder ( )
1043
+ . with_safe_default_cipher_suites ( )
1044
+ . with_safe_default_kx_groups ( )
1045
+ . with_safe_default_protocol_versions ( ) ?
1046
+ . with_root_certificates ( root_store)
1047
+ . with_no_client_auth ( ) ;
1048
+
1049
+ let connector = async_rustls:: TlsConnector :: from ( Arc :: new ( config) ) ;
1050
+ let stream = connector
1051
+ . connect ( rustls:: ServerName :: try_from ( hostname. as_str ( ) ) ?, stream)
1052
+ . await
1053
+ . map ( |stream| asynchronous_codec:: Framed :: new ( stream, Codec ) ) ?;
1054
+
1055
+ Connection :: connect (
1056
+ connection_id,
1057
+ stream,
1058
+ auth,
1059
+ proxy_to_broker_url,
1060
+ executor,
1061
+ operation_timeout,
1062
+ )
1063
+ . await
1064
+ } else {
1065
+ let stream = async_std:: net:: TcpStream :: connect ( & address)
1066
+ . await
1067
+ . map ( |stream| asynchronous_codec:: Framed :: new ( stream, Codec ) ) ?;
1068
+
1069
+ Connection :: connect (
1070
+ connection_id,
1071
+ stream,
1072
+ auth,
1073
+ proxy_to_broker_url,
1074
+ executor,
1075
+ operation_timeout,
1076
+ )
1077
+ . await
1078
+ }
1079
+ }
1080
+ #[ cfg( all( not( feature = "async-std-runtime" ) , not( feature = "async-std-rustls-runtime" ) ) ) ]
984
1081
ExecutorKind :: AsyncStd => {
985
1082
unimplemented ! ( "the async-std-runtime cargo feature is not active" ) ;
986
1083
}
@@ -1628,11 +1725,12 @@ mod tests {
1628
1725
error:: { AuthenticationError , SharedError } ,
1629
1726
message:: { BaseCommand , Codec , Message } ,
1630
1727
proto:: { AuthData , CommandAuthChallenge , CommandAuthResponse , CommandConnected } ,
1631
- TokioExecutor ,
1632
1728
} ;
1729
+ #[ cfg( any( feature = "tokio-runtime" , feature = "tokio-rustls-runtime" ) ) ]
1730
+ use crate :: TokioExecutor ;
1633
1731
1634
1732
#[ tokio:: test]
1635
- #[ cfg( feature = "tokio-runtime" ) ]
1733
+ #[ cfg( any ( feature = "tokio-runtime" , feature = "tokio-rustls-runtime" ) ) ]
1636
1734
async fn receiver_auth_challenge_test ( ) {
1637
1735
let ( message_tx, message_rx) = mpsc:: unbounded ( ) ;
1638
1736
let ( tx, _) = mpsc:: unbounded ( ) ;
@@ -1690,7 +1788,7 @@ mod tests {
1690
1788
}
1691
1789
1692
1790
#[ tokio:: test]
1693
- #[ cfg( feature = "tokio-runtime" ) ]
1791
+ #[ cfg( any ( feature = "tokio-runtime" , feature = "tokio-rustls-runtime" ) ) ]
1694
1792
async fn connection_auth_challenge_test ( ) {
1695
1793
let listener = TcpListener :: bind ( "127.0.0.1:0" ) . await . unwrap ( ) ;
1696
1794
0 commit comments