@@ -20,9 +20,18 @@ use futures::{
20
20
task:: { Context , Poll } ,
21
21
Future , FutureExt , Sink , SinkExt , Stream , StreamExt ,
22
22
} ;
23
+ #[ cfg( any( feature = "tokio-runtime" , feature = "async-std-runtime" ) ) ]
23
24
use native_tls:: Certificate ;
24
25
use proto:: MessageIdData ;
25
26
use rand:: { seq:: SliceRandom , thread_rng} ;
27
+ #[ cfg( all(
28
+ any(
29
+ feature = "tokio-rustls-runtime" ,
30
+ feature = "async-std--rustls-runtime"
31
+ ) ,
32
+ not( any( feature = "tokio-runtime" , feature = "async-std-runtime" ) )
33
+ ) ) ]
34
+ use rustls:: Certificate ;
26
35
use url:: Url ;
27
36
use uuid:: Uuid ;
28
37
@@ -934,7 +943,69 @@ impl<Exe: Executor> Connection<Exe> {
934
943
. await
935
944
}
936
945
}
937
- #[ cfg( not( feature = "tokio-runtime" ) ) ]
946
+ #[ cfg( all( feature = "tokio-rustls-runtime" , not( feature = "tokio-runtime" ) ) ) ]
947
+ ExecutorKind :: Tokio => {
948
+ if tls {
949
+ let stream = tokio:: net:: TcpStream :: connect ( & address) . await ?;
950
+ let mut root_store = rustls:: RootCertStore :: empty ( ) ;
951
+ for certificate in certificate_chain {
952
+ root_store. add ( certificate) ?;
953
+ }
954
+
955
+ let trust_anchors = webpki_roots:: TLS_SERVER_ROOTS . 0 . iter ( ) . fold (
956
+ vec ! [ ] ,
957
+ |mut acc, trust_anchor| {
958
+ acc. push (
959
+ rustls:: OwnedTrustAnchor :: from_subject_spki_name_constraints (
960
+ trust_anchor. subject ,
961
+ trust_anchor. spki ,
962
+ trust_anchor. name_constraints ,
963
+ ) ,
964
+ ) ;
965
+ acc
966
+ } ,
967
+ ) ;
968
+
969
+ root_store. add_server_trust_anchors ( trust_anchors. into_iter ( ) ) ;
970
+ let config = rustls:: ClientConfig :: builder ( )
971
+ . with_safe_default_cipher_suites ( )
972
+ . with_safe_default_kx_groups ( )
973
+ . with_safe_default_protocol_versions ( ) ?
974
+ . with_root_certificates ( root_store)
975
+ . with_no_client_auth ( ) ;
976
+
977
+ let cx = tokio_rustls:: TlsConnector :: from ( Arc :: new ( config) ) ;
978
+ let stream = cx
979
+ . connect ( rustls:: ServerName :: try_from ( hostname. as_str ( ) ) ?, stream)
980
+ . await
981
+ . map ( |stream| tokio_util:: codec:: Framed :: new ( stream, Codec ) ) ?;
982
+
983
+ Connection :: connect (
984
+ connection_id,
985
+ stream,
986
+ auth,
987
+ proxy_to_broker_url,
988
+ executor,
989
+ operation_timeout,
990
+ )
991
+ . await
992
+ } else {
993
+ let stream = tokio:: net:: TcpStream :: connect ( & address)
994
+ . await
995
+ . map ( |stream| tokio_util:: codec:: Framed :: new ( stream, Codec ) ) ?;
996
+
997
+ Connection :: connect (
998
+ connection_id,
999
+ stream,
1000
+ auth,
1001
+ proxy_to_broker_url,
1002
+ executor,
1003
+ operation_timeout,
1004
+ )
1005
+ . await
1006
+ }
1007
+ }
1008
+ #[ cfg( all( not( feature = "tokio-runtime" ) , not( feature = "tokio-rustls-runtime" ) ) ) ]
938
1009
ExecutorKind :: Tokio => {
939
1010
unimplemented ! ( "the tokio-runtime cargo feature is not active" ) ;
940
1011
}
@@ -980,7 +1051,75 @@ impl<Exe: Executor> Connection<Exe> {
980
1051
. await
981
1052
}
982
1053
}
983
- #[ cfg( not( feature = "async-std-runtime" ) ) ]
1054
+ #[ cfg( all(
1055
+ feature = "async-std-rustls-runtime" ,
1056
+ not( feature = "async-std-runtime" )
1057
+ ) ) ]
1058
+ ExecutorKind :: AsyncStd => {
1059
+ if tls {
1060
+ let stream = async_std:: net:: TcpStream :: connect ( & address) . await ?;
1061
+ let mut root_store = rustls:: RootCertStore :: empty ( ) ;
1062
+ for certificate in certificate_chain {
1063
+ root_store. add ( certificate) ?;
1064
+ }
1065
+
1066
+ let trust_anchors = webpki_roots:: TLS_SERVER_ROOTS . 0 . iter ( ) . fold (
1067
+ vec ! [ ] ,
1068
+ |mut acc, trust_anchor| {
1069
+ acc. push (
1070
+ rustls:: OwnedTrustAnchor :: from_subject_spki_name_constraints (
1071
+ trust_anchor. subject ,
1072
+ trust_anchor. spki ,
1073
+ trust_anchor. name_constraints ,
1074
+ ) ,
1075
+ ) ;
1076
+ acc
1077
+ } ,
1078
+ ) ;
1079
+
1080
+ root_store. add_server_trust_anchors ( trust_anchors. into_iter ( ) ) ;
1081
+ let config = rustls:: ClientConfig :: builder ( )
1082
+ . with_safe_default_cipher_suites ( )
1083
+ . with_safe_default_kx_groups ( )
1084
+ . with_safe_default_protocol_versions ( ) ?
1085
+ . with_root_certificates ( root_store)
1086
+ . with_no_client_auth ( ) ;
1087
+
1088
+ let connector = async_rustls:: TlsConnector :: from ( Arc :: new ( config) ) ;
1089
+ let stream = connector
1090
+ . connect ( rustls:: ServerName :: try_from ( hostname. as_str ( ) ) ?, stream)
1091
+ . await
1092
+ . map ( |stream| asynchronous_codec:: Framed :: new ( stream, Codec ) ) ?;
1093
+
1094
+ Connection :: connect (
1095
+ connection_id,
1096
+ stream,
1097
+ auth,
1098
+ proxy_to_broker_url,
1099
+ executor,
1100
+ operation_timeout,
1101
+ )
1102
+ . await
1103
+ } else {
1104
+ let stream = async_std:: net:: TcpStream :: connect ( & address)
1105
+ . await
1106
+ . map ( |stream| asynchronous_codec:: Framed :: new ( stream, Codec ) ) ?;
1107
+
1108
+ Connection :: connect (
1109
+ connection_id,
1110
+ stream,
1111
+ auth,
1112
+ proxy_to_broker_url,
1113
+ executor,
1114
+ operation_timeout,
1115
+ )
1116
+ . await
1117
+ }
1118
+ }
1119
+ #[ cfg( all(
1120
+ not( feature = "async-std-runtime" ) ,
1121
+ not( feature = "async-std-rustls-runtime" )
1122
+ ) ) ]
984
1123
ExecutorKind :: AsyncStd => {
985
1124
unimplemented ! ( "the async-std-runtime cargo feature is not active" ) ;
986
1125
}
@@ -1623,16 +1762,17 @@ mod tests {
1623
1762
use uuid:: Uuid ;
1624
1763
1625
1764
use super :: { Connection , Receiver } ;
1765
+ #[ cfg( any( feature = "tokio-runtime" , feature = "tokio-rustls-runtime" ) ) ]
1766
+ use crate :: TokioExecutor ;
1626
1767
use crate :: {
1627
1768
authentication:: Authentication ,
1628
1769
error:: { AuthenticationError , SharedError } ,
1629
1770
message:: { BaseCommand , Codec , Message } ,
1630
1771
proto:: { AuthData , CommandAuthChallenge , CommandAuthResponse , CommandConnected } ,
1631
- TokioExecutor ,
1632
1772
} ;
1633
1773
1634
1774
#[ tokio:: test]
1635
- #[ cfg( feature = "tokio-runtime" ) ]
1775
+ #[ cfg( any ( feature = "tokio-runtime" , feature = "tokio-rustls-runtime" ) ) ]
1636
1776
async fn receiver_auth_challenge_test ( ) {
1637
1777
let ( message_tx, message_rx) = mpsc:: unbounded ( ) ;
1638
1778
let ( tx, _) = mpsc:: unbounded ( ) ;
@@ -1690,7 +1830,7 @@ mod tests {
1690
1830
}
1691
1831
1692
1832
#[ tokio:: test]
1693
- #[ cfg( feature = "tokio-runtime" ) ]
1833
+ #[ cfg( any ( feature = "tokio-runtime" , feature = "tokio-rustls-runtime" ) ) ]
1694
1834
async fn connection_auth_challenge_test ( ) {
1695
1835
let listener = TcpListener :: bind ( "127.0.0.1:0" ) . await . unwrap ( ) ;
1696
1836
0 commit comments