@@ -20,9 +20,15 @@ use futures::{
20
20
task:: { Context , Poll } ,
21
21
Future , FutureExt , Sink , SinkExt , Stream , StreamExt ,
22
22
} ;
23
+ #[ cfg( any( feature = "tokio-runtime" , feature = "async-std-runtime" ) ) ]
23
24
use native_tls:: Certificate ;
24
25
use proto:: MessageIdData ;
25
26
use rand:: { seq:: SliceRandom , thread_rng} ;
27
+ #[ cfg( all(
28
+ any( feature = "tokio-rustls-runtime" , feature = "async-std-rustls-runtime" ) ,
29
+ not( any( feature = "tokio-runtime" , feature = "async-std-runtime" ) )
30
+ ) ) ]
31
+ use rustls:: Certificate ;
26
32
use url:: Url ;
27
33
use uuid:: Uuid ;
28
34
@@ -934,7 +940,69 @@ impl<Exe: Executor> Connection<Exe> {
934
940
. await
935
941
}
936
942
}
937
- #[ cfg( not( feature = "tokio-runtime" ) ) ]
943
+ #[ cfg( all( feature = "tokio-rustls-runtime" , not( feature = "tokio-runtime" ) ) ) ]
944
+ ExecutorKind :: Tokio => {
945
+ if tls {
946
+ let stream = tokio:: net:: TcpStream :: connect ( & address) . await ?;
947
+ let mut root_store = rustls:: RootCertStore :: empty ( ) ;
948
+ for certificate in certificate_chain {
949
+ root_store. add ( certificate) ?;
950
+ }
951
+
952
+ let trust_anchors = webpki_roots:: TLS_SERVER_ROOTS . 0 . iter ( ) . fold (
953
+ vec ! [ ] ,
954
+ |mut acc, trust_anchor| {
955
+ acc. push (
956
+ rustls:: OwnedTrustAnchor :: from_subject_spki_name_constraints (
957
+ trust_anchor. subject ,
958
+ trust_anchor. spki ,
959
+ trust_anchor. name_constraints ,
960
+ ) ,
961
+ ) ;
962
+ acc
963
+ } ,
964
+ ) ;
965
+
966
+ root_store. add_server_trust_anchors ( trust_anchors. into_iter ( ) ) ;
967
+ let config = rustls:: ClientConfig :: builder ( )
968
+ . with_safe_default_cipher_suites ( )
969
+ . with_safe_default_kx_groups ( )
970
+ . with_safe_default_protocol_versions ( ) ?
971
+ . with_root_certificates ( root_store)
972
+ . with_no_client_auth ( ) ;
973
+
974
+ let cx = tokio_rustls:: TlsConnector :: from ( Arc :: new ( config) ) ;
975
+ let stream = cx
976
+ . connect ( rustls:: ServerName :: try_from ( hostname. as_str ( ) ) ?, stream)
977
+ . await
978
+ . map ( |stream| tokio_util:: codec:: Framed :: new ( stream, Codec ) ) ?;
979
+
980
+ Connection :: connect (
981
+ connection_id,
982
+ stream,
983
+ auth,
984
+ proxy_to_broker_url,
985
+ executor,
986
+ operation_timeout,
987
+ )
988
+ . await
989
+ } else {
990
+ let stream = tokio:: net:: TcpStream :: connect ( & address)
991
+ . await
992
+ . map ( |stream| tokio_util:: codec:: Framed :: new ( stream, Codec ) ) ?;
993
+
994
+ Connection :: connect (
995
+ connection_id,
996
+ stream,
997
+ auth,
998
+ proxy_to_broker_url,
999
+ executor,
1000
+ operation_timeout,
1001
+ )
1002
+ . await
1003
+ }
1004
+ }
1005
+ #[ cfg( all( not( feature = "tokio-runtime" ) , not( feature = "tokio-rustls-runtime" ) ) ) ]
938
1006
ExecutorKind :: Tokio => {
939
1007
unimplemented ! ( "the tokio-runtime cargo feature is not active" ) ;
940
1008
}
@@ -980,7 +1048,75 @@ impl<Exe: Executor> Connection<Exe> {
980
1048
. await
981
1049
}
982
1050
}
983
- #[ cfg( not( feature = "async-std-runtime" ) ) ]
1051
+ #[ cfg( all(
1052
+ feature = "async-std-rustls-runtime" ,
1053
+ not( feature = "async-std-runtime" )
1054
+ ) ) ]
1055
+ ExecutorKind :: AsyncStd => {
1056
+ if tls {
1057
+ let stream = async_std:: net:: TcpStream :: connect ( & address) . await ?;
1058
+ let mut root_store = rustls:: RootCertStore :: empty ( ) ;
1059
+ for certificate in certificate_chain {
1060
+ root_store. add ( certificate) ?;
1061
+ }
1062
+
1063
+ let trust_anchors = webpki_roots:: TLS_SERVER_ROOTS . 0 . iter ( ) . fold (
1064
+ vec ! [ ] ,
1065
+ |mut acc, trust_anchor| {
1066
+ acc. push (
1067
+ rustls:: OwnedTrustAnchor :: from_subject_spki_name_constraints (
1068
+ trust_anchor. subject ,
1069
+ trust_anchor. spki ,
1070
+ trust_anchor. name_constraints ,
1071
+ ) ,
1072
+ ) ;
1073
+ acc
1074
+ } ,
1075
+ ) ;
1076
+
1077
+ root_store. add_server_trust_anchors ( trust_anchors. into_iter ( ) ) ;
1078
+ let config = rustls:: ClientConfig :: builder ( )
1079
+ . with_safe_default_cipher_suites ( )
1080
+ . with_safe_default_kx_groups ( )
1081
+ . with_safe_default_protocol_versions ( ) ?
1082
+ . with_root_certificates ( root_store)
1083
+ . with_no_client_auth ( ) ;
1084
+
1085
+ let connector = async_rustls:: TlsConnector :: from ( Arc :: new ( config) ) ;
1086
+ let stream = connector
1087
+ . connect ( rustls:: ServerName :: try_from ( hostname. as_str ( ) ) ?, stream)
1088
+ . await
1089
+ . map ( |stream| asynchronous_codec:: Framed :: new ( stream, Codec ) ) ?;
1090
+
1091
+ Connection :: connect (
1092
+ connection_id,
1093
+ stream,
1094
+ auth,
1095
+ proxy_to_broker_url,
1096
+ executor,
1097
+ operation_timeout,
1098
+ )
1099
+ . await
1100
+ } else {
1101
+ let stream = async_std:: net:: TcpStream :: connect ( & address)
1102
+ . await
1103
+ . map ( |stream| asynchronous_codec:: Framed :: new ( stream, Codec ) ) ?;
1104
+
1105
+ Connection :: connect (
1106
+ connection_id,
1107
+ stream,
1108
+ auth,
1109
+ proxy_to_broker_url,
1110
+ executor,
1111
+ operation_timeout,
1112
+ )
1113
+ . await
1114
+ }
1115
+ }
1116
+ #[ cfg( all(
1117
+ not( feature = "async-std-runtime" ) ,
1118
+ not( feature = "async-std-rustls-runtime" )
1119
+ ) ) ]
984
1120
ExecutorKind :: AsyncStd => {
985
1121
unimplemented ! ( "the async-std-runtime cargo feature is not active" ) ;
986
1122
}
@@ -1623,16 +1759,17 @@ mod tests {
1623
1759
use uuid:: Uuid ;
1624
1760
1625
1761
use super :: { Connection , Receiver } ;
1762
+ #[ cfg( any( feature = "tokio-runtime" , feature = "tokio-rustls-runtime" ) ) ]
1763
+ use crate :: TokioExecutor ;
1626
1764
use crate :: {
1627
1765
authentication:: Authentication ,
1628
1766
error:: { AuthenticationError , SharedError } ,
1629
1767
message:: { BaseCommand , Codec , Message } ,
1630
1768
proto:: { AuthData , CommandAuthChallenge , CommandAuthResponse , CommandConnected } ,
1631
- TokioExecutor ,
1632
1769
} ;
1633
1770
1634
1771
#[ tokio:: test]
1635
- #[ cfg( feature = "tokio-runtime" ) ]
1772
+ #[ cfg( any ( feature = "tokio-runtime" , feature = "tokio-rustls-runtime" ) ) ]
1636
1773
async fn receiver_auth_challenge_test ( ) {
1637
1774
let ( message_tx, message_rx) = mpsc:: unbounded ( ) ;
1638
1775
let ( tx, _) = mpsc:: unbounded ( ) ;
@@ -1690,7 +1827,7 @@ mod tests {
1690
1827
}
1691
1828
1692
1829
#[ tokio:: test]
1693
- #[ cfg( feature = "tokio-runtime" ) ]
1830
+ #[ cfg( any ( feature = "tokio-runtime" , feature = "tokio-rustls-runtime" ) ) ]
1694
1831
async fn connection_auth_challenge_test ( ) {
1695
1832
let listener = TcpListener :: bind ( "127.0.0.1:0" ) . await . unwrap ( ) ;
1696
1833
0 commit comments