@@ -20,7 +20,10 @@ use futures::{
20
20
task:: { Context , Poll } ,
21
21
Future , FutureExt , Sink , SinkExt , Stream , StreamExt ,
22
22
} ;
23
+ #[ cfg( any( feature = "tokio-runtime" , feature = "async-std-runtime" ) ) ]
23
24
use native_tls:: Certificate ;
25
+ #[ cfg( all( any( feature = "tokio-rustls-runtime" , feature = "async-std--rustls-runtime" ) , not( any( feature = "tokio-runtime" , feature = "async-std-runtime" ) ) ) ) ]
26
+ use rustls:: Certificate ;
24
27
use proto:: MessageIdData ;
25
28
use rand:: { seq:: SliceRandom , thread_rng} ;
26
29
use url:: Url ;
@@ -934,11 +937,64 @@ impl<Exe: Executor> Connection<Exe> {
934
937
. await
935
938
}
936
939
}
937
- #[ cfg( not( feature = "tokio-runtime" ) ) ]
940
+ #[ cfg( all( feature = "tokio-rustls-runtime" , not( feature = "tokio-runtime" ) ) ) ]
941
+ ExecutorKind :: Tokio => {
942
+ if tls {
943
+ let stream = tokio:: net:: TcpStream :: connect ( & address) . await ?;
944
+ let mut root_store = rustls:: RootCertStore :: empty ( ) ;
945
+ for certificate in certificate_chain {
946
+ root_store. add ( certificate) ?;
947
+ }
948
+
949
+ let trust_anchors = webpki_roots:: TLS_SERVER_ROOTS . 0 . iter ( ) . fold ( vec ! [ ] , |mut acc, trust_anchor| {
950
+ acc. push ( rustls:: OwnedTrustAnchor :: from_subject_spki_name_constraints ( trust_anchor. subject , trust_anchor. spki , trust_anchor. name_constraints ) ) ;
951
+ acc
952
+ } ) ;
953
+
954
+ root_store. add_server_trust_anchors ( trust_anchors. into_iter ( ) ) ;
955
+ let config = rustls:: ClientConfig :: builder ( )
956
+ . with_safe_default_cipher_suites ( )
957
+ . with_safe_default_kx_groups ( )
958
+ . with_safe_default_protocol_versions ( ) ?
959
+ . with_root_certificates ( root_store)
960
+ . with_no_client_auth ( ) ;
961
+
962
+ let cx = tokio_rustls:: TlsConnector :: from ( Arc :: new ( config) ) ;
963
+ let stream = cx
964
+ . connect ( rustls:: ServerName :: try_from ( hostname. as_str ( ) ) ?, stream)
965
+ . await
966
+ . map ( |stream| tokio_util:: codec:: Framed :: new ( stream, Codec ) ) ?;
967
+
968
+ Connection :: connect (
969
+ connection_id,
970
+ stream,
971
+ auth,
972
+ proxy_to_broker_url,
973
+ executor,
974
+ operation_timeout,
975
+ )
976
+ . await
977
+ } else {
978
+ let stream = tokio:: net:: TcpStream :: connect ( & address)
979
+ . await
980
+ . map ( |stream| tokio_util:: codec:: Framed :: new ( stream, Codec ) ) ?;
981
+
982
+ Connection :: connect (
983
+ connection_id,
984
+ stream,
985
+ auth,
986
+ proxy_to_broker_url,
987
+ executor,
988
+ operation_timeout,
989
+ )
990
+ . await
991
+ }
992
+ }
993
+ #[ cfg( all( not( feature = "tokio-runtime" ) , not( feature = "tokio-rustls-runtime" ) ) ) ]
938
994
ExecutorKind :: Tokio => {
939
995
unimplemented ! ( "the tokio-runtime cargo feature is not active" ) ;
940
996
}
941
- #[ cfg( feature = "async-std-runtime" ) ]
997
+ #[ cfg( feature = "async-std-runtime" ) ]
942
998
ExecutorKind :: AsyncStd => {
943
999
if tls {
944
1000
let stream = async_std:: net:: TcpStream :: connect ( & address) . await ?;
@@ -980,7 +1036,60 @@ impl<Exe: Executor> Connection<Exe> {
980
1036
. await
981
1037
}
982
1038
}
983
- #[ cfg( not( feature = "async-std-runtime" ) ) ]
1039
+ #[ cfg( all( feature = "async-std-rustls-runtime" , not( feature = "async-std-runtime" ) ) ) ]
1040
+ ExecutorKind :: AsyncStd => {
1041
+ if tls {
1042
+ let stream = async_std:: net:: TcpStream :: connect ( & address) . await ?;
1043
+ let mut root_store = rustls:: RootCertStore :: empty ( ) ;
1044
+ for certificate in certificate_chain {
1045
+ root_store. add ( certificate) ?;
1046
+ }
1047
+
1048
+ let trust_anchors = webpki_roots:: TLS_SERVER_ROOTS . 0 . iter ( ) . fold ( vec ! [ ] , |mut acc, trust_anchor| {
1049
+ acc. push ( rustls:: OwnedTrustAnchor :: from_subject_spki_name_constraints ( trust_anchor. subject , trust_anchor. spki , trust_anchor. name_constraints ) ) ;
1050
+ acc
1051
+ } ) ;
1052
+
1053
+ root_store. add_server_trust_anchors ( trust_anchors. into_iter ( ) ) ;
1054
+ let config = rustls:: ClientConfig :: builder ( )
1055
+ . with_safe_default_cipher_suites ( )
1056
+ . with_safe_default_kx_groups ( )
1057
+ . with_safe_default_protocol_versions ( ) ?
1058
+ . with_root_certificates ( root_store)
1059
+ . with_no_client_auth ( ) ;
1060
+
1061
+ let connector = async_rustls:: TlsConnector :: from ( Arc :: new ( config) ) ;
1062
+ let stream = connector
1063
+ . connect ( rustls:: ServerName :: try_from ( hostname. as_str ( ) ) ?, stream)
1064
+ . await
1065
+ . map ( |stream| asynchronous_codec:: Framed :: new ( stream, Codec ) ) ?;
1066
+
1067
+ Connection :: connect (
1068
+ connection_id,
1069
+ stream,
1070
+ auth,
1071
+ proxy_to_broker_url,
1072
+ executor,
1073
+ operation_timeout,
1074
+ )
1075
+ . await
1076
+ } else {
1077
+ let stream = async_std:: net:: TcpStream :: connect ( & address)
1078
+ . await
1079
+ . map ( |stream| asynchronous_codec:: Framed :: new ( stream, Codec ) ) ?;
1080
+
1081
+ Connection :: connect (
1082
+ connection_id,
1083
+ stream,
1084
+ auth,
1085
+ proxy_to_broker_url,
1086
+ executor,
1087
+ operation_timeout,
1088
+ )
1089
+ . await
1090
+ }
1091
+ }
1092
+ #[ cfg( all( not( feature = "async-std-runtime" ) , not( feature = "async-std-rustls-runtime" ) ) ) ]
984
1093
ExecutorKind :: AsyncStd => {
985
1094
unimplemented ! ( "the async-std-runtime cargo feature is not active" ) ;
986
1095
}
@@ -1628,11 +1737,12 @@ mod tests {
1628
1737
error:: { AuthenticationError , SharedError } ,
1629
1738
message:: { BaseCommand , Codec , Message } ,
1630
1739
proto:: { AuthData , CommandAuthChallenge , CommandAuthResponse , CommandConnected } ,
1631
- TokioExecutor ,
1632
1740
} ;
1741
+ #[ cfg( any( feature = "tokio-runtime" , feature = "tokio-rustls-runtime" ) ) ]
1742
+ use crate :: TokioExecutor ;
1633
1743
1634
1744
#[ tokio:: test]
1635
- #[ cfg( feature = "tokio-runtime" ) ]
1745
+ #[ cfg( any ( feature = "tokio-runtime" , feature = "tokio-rustls-runtime" ) ) ]
1636
1746
async fn receiver_auth_challenge_test ( ) {
1637
1747
let ( message_tx, message_rx) = mpsc:: unbounded ( ) ;
1638
1748
let ( tx, _) = mpsc:: unbounded ( ) ;
@@ -1690,7 +1800,7 @@ mod tests {
1690
1800
}
1691
1801
1692
1802
#[ tokio:: test]
1693
- #[ cfg( feature = "tokio-runtime" ) ]
1803
+ #[ cfg( any ( feature = "tokio-runtime" , feature = "tokio-rustls-runtime" ) ) ]
1694
1804
async fn connection_auth_challenge_test ( ) {
1695
1805
let listener = TcpListener :: bind ( "127.0.0.1:0" ) . await . unwrap ( ) ;
1696
1806
0 commit comments