Skip to content

Commit b2c658b

Browse files
build: allow to use rustls instead of native-tls
* This is used in an effort to remove all dependencies to openssl. Which could be interesting in embedded system or on environment which is difficult to know on which openssl version the software will run it and breaks deployments. * It introduces two compiler feature flags which are `tokio-rustls-runtime` and `async-std-rustls-runtime` that have the same meaning as `tokio-runtime` and `async-std-runtime` except that they use rustls. * There is a safe guard, if we enable both runtimes, this is the native-tls ones that are used to keep consistent with the current behaviour. Signed-off-by: Florentin Dubois <[email protected]>
1 parent 3964ed7 commit b2c658b

File tree

8 files changed

+272
-51
lines changed

8 files changed

+272
-51
lines changed

Cargo.toml

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,17 @@ regex = "^1.9.1"
3030
bit-vec = "^0.6.3"
3131
futures = "^0.3.28"
3232
futures-io = "^0.3.28"
33-
native-tls = "^0.2.11"
33+
native-tls = { version = "^0.2.11", optional = true }
34+
rustls = { version = "^0.21.5", optional = true }
35+
webpki-roots = { version = "^0.24.0", optional = true }
3436
pem = "^3.0.0"
3537
tokio = { version = "^1.29.1", features = ["rt", "net", "time"], optional = true }
3638
tokio-util = { version = "^0.7.8", features = ["codec"], optional = true }
39+
tokio-rustls = { version = "^0.24.1", optional = true }
3740
tokio-native-tls = { version = "^0.3.1", optional = true }
38-
async-std = {version = "^1.12.0", features = [ "attributes", "unstable" ], optional = true }
41+
async-std = { version = "^1.12.0", features = [ "attributes", "unstable" ], optional = true }
3942
asynchronous-codec = { version = "^0.6.2", optional = true }
43+
async-rustls = { version = "^0.4.0", optional = true }
4044
async-native-tls = { version = "^0.5.0", optional = true }
4145
lz4 = { version = "^1.24.0", optional = true }
4246
flate2 = { version = "^1.0.26", optional = true }
@@ -49,7 +53,7 @@ serde_json = { version = "^1.0.103", optional = true }
4953
tracing = { version = "^0.1.37", optional = true }
5054
async-trait = "^0.1.72"
5155
data-url = { version = "^0.3.0", optional = true }
52-
uuid = {version = "^1.4.1", features = ["v4", "fast-rng"] }
56+
uuid = { version = "^1.4.1", features = ["v4", "fast-rng"] }
5357

5458
[dev-dependencies]
5559
serde = { version = "^1.0.175", features = ["derive"] }
@@ -64,8 +68,10 @@ protobuf-src = { version = "1.1.0", optional = true }
6468
[features]
6569
default = [ "compression", "tokio-runtime", "async-std-runtime", "auth-oauth2"]
6670
compression = [ "lz4", "flate2", "zstd", "snap" ]
67-
tokio-runtime = [ "tokio", "tokio-util", "tokio-native-tls" ]
68-
async-std-runtime = [ "async-std", "asynchronous-codec", "async-native-tls" ]
71+
tokio-runtime = [ "tokio", "tokio-util", "native-tls", "tokio-native-tls" ]
72+
tokio-rustls-runtime = ["tokio", "tokio-util", "tokio-rustls", "rustls", "webpki-roots" ]
73+
async-std-runtime = [ "async-std", "asynchronous-codec", "native-tls", "async-native-tls" ]
74+
async-std-rustls-runtime = ["async-std", "asynchronous-codec", "async-rustls", "rustls", "webpki-roots" ]
6975
auth-oauth2 = [ "openidconnect", "oauth2", "serde", "serde_json", "data-url" ]
7076
telemetry = ["tracing"]
7177
protobuf-src = ["dep:protobuf-src"]

src/connection.rs

Lines changed: 145 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,18 @@ use futures::{
2020
task::{Context, Poll},
2121
Future, FutureExt, Sink, SinkExt, Stream, StreamExt,
2222
};
23+
#[cfg(any(feature = "tokio-runtime", feature = "async-std-runtime"))]
2324
use native_tls::Certificate;
2425
use proto::MessageIdData;
2526
use rand::{seq::SliceRandom, thread_rng};
27+
#[cfg(all(
28+
any(
29+
feature = "tokio-rustls-runtime",
30+
feature = "async-std--rustls-runtime"
31+
),
32+
not(any(feature = "tokio-runtime", feature = "async-std-runtime"))
33+
))]
34+
use rustls::Certificate;
2635
use url::Url;
2736
use uuid::Uuid;
2837

@@ -934,7 +943,69 @@ impl<Exe: Executor> Connection<Exe> {
934943
.await
935944
}
936945
}
937-
#[cfg(not(feature = "tokio-runtime"))]
946+
#[cfg(all(feature = "tokio-rustls-runtime", not(feature = "tokio-runtime")))]
947+
ExecutorKind::Tokio => {
948+
if tls {
949+
let stream = tokio::net::TcpStream::connect(&address).await?;
950+
let mut root_store = rustls::RootCertStore::empty();
951+
for certificate in certificate_chain {
952+
root_store.add(certificate)?;
953+
}
954+
955+
let trust_anchors = webpki_roots::TLS_SERVER_ROOTS.0.iter().fold(
956+
vec![],
957+
|mut acc, trust_anchor| {
958+
acc.push(
959+
rustls::OwnedTrustAnchor::from_subject_spki_name_constraints(
960+
trust_anchor.subject,
961+
trust_anchor.spki,
962+
trust_anchor.name_constraints,
963+
),
964+
);
965+
acc
966+
},
967+
);
968+
969+
root_store.add_server_trust_anchors(trust_anchors.into_iter());
970+
let config = rustls::ClientConfig::builder()
971+
.with_safe_default_cipher_suites()
972+
.with_safe_default_kx_groups()
973+
.with_safe_default_protocol_versions()?
974+
.with_root_certificates(root_store)
975+
.with_no_client_auth();
976+
977+
let cx = tokio_rustls::TlsConnector::from(Arc::new(config));
978+
let stream = cx
979+
.connect(rustls::ServerName::try_from(hostname.as_str())?, stream)
980+
.await
981+
.map(|stream| tokio_util::codec::Framed::new(stream, Codec))?;
982+
983+
Connection::connect(
984+
connection_id,
985+
stream,
986+
auth,
987+
proxy_to_broker_url,
988+
executor,
989+
operation_timeout,
990+
)
991+
.await
992+
} else {
993+
let stream = tokio::net::TcpStream::connect(&address)
994+
.await
995+
.map(|stream| tokio_util::codec::Framed::new(stream, Codec))?;
996+
997+
Connection::connect(
998+
connection_id,
999+
stream,
1000+
auth,
1001+
proxy_to_broker_url,
1002+
executor,
1003+
operation_timeout,
1004+
)
1005+
.await
1006+
}
1007+
}
1008+
#[cfg(all(not(feature = "tokio-runtime"), not(feature = "tokio-rustls-runtime")))]
9381009
ExecutorKind::Tokio => {
9391010
unimplemented!("the tokio-runtime cargo feature is not active");
9401011
}
@@ -980,7 +1051,75 @@ impl<Exe: Executor> Connection<Exe> {
9801051
.await
9811052
}
9821053
}
983-
#[cfg(not(feature = "async-std-runtime"))]
1054+
#[cfg(all(
1055+
feature = "async-std-rustls-runtime",
1056+
not(feature = "async-std-runtime")
1057+
))]
1058+
ExecutorKind::AsyncStd => {
1059+
if tls {
1060+
let stream = async_std::net::TcpStream::connect(&address).await?;
1061+
let mut root_store = rustls::RootCertStore::empty();
1062+
for certificate in certificate_chain {
1063+
root_store.add(certificate)?;
1064+
}
1065+
1066+
let trust_anchors = webpki_roots::TLS_SERVER_ROOTS.0.iter().fold(
1067+
vec![],
1068+
|mut acc, trust_anchor| {
1069+
acc.push(
1070+
rustls::OwnedTrustAnchor::from_subject_spki_name_constraints(
1071+
trust_anchor.subject,
1072+
trust_anchor.spki,
1073+
trust_anchor.name_constraints,
1074+
),
1075+
);
1076+
acc
1077+
},
1078+
);
1079+
1080+
root_store.add_server_trust_anchors(trust_anchors.into_iter());
1081+
let config = rustls::ClientConfig::builder()
1082+
.with_safe_default_cipher_suites()
1083+
.with_safe_default_kx_groups()
1084+
.with_safe_default_protocol_versions()?
1085+
.with_root_certificates(root_store)
1086+
.with_no_client_auth();
1087+
1088+
let connector = async_rustls::TlsConnector::from(Arc::new(config));
1089+
let stream = connector
1090+
.connect(rustls::ServerName::try_from(hostname.as_str())?, stream)
1091+
.await
1092+
.map(|stream| asynchronous_codec::Framed::new(stream, Codec))?;
1093+
1094+
Connection::connect(
1095+
connection_id,
1096+
stream,
1097+
auth,
1098+
proxy_to_broker_url,
1099+
executor,
1100+
operation_timeout,
1101+
)
1102+
.await
1103+
} else {
1104+
let stream = async_std::net::TcpStream::connect(&address)
1105+
.await
1106+
.map(|stream| asynchronous_codec::Framed::new(stream, Codec))?;
1107+
1108+
Connection::connect(
1109+
connection_id,
1110+
stream,
1111+
auth,
1112+
proxy_to_broker_url,
1113+
executor,
1114+
operation_timeout,
1115+
)
1116+
.await
1117+
}
1118+
}
1119+
#[cfg(all(
1120+
not(feature = "async-std-runtime"),
1121+
not(feature = "async-std-rustls-runtime")
1122+
))]
9841123
ExecutorKind::AsyncStd => {
9851124
unimplemented!("the async-std-runtime cargo feature is not active");
9861125
}
@@ -1623,16 +1762,17 @@ mod tests {
16231762
use uuid::Uuid;
16241763

16251764
use super::{Connection, Receiver};
1765+
#[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))]
1766+
use crate::TokioExecutor;
16261767
use crate::{
16271768
authentication::Authentication,
16281769
error::{AuthenticationError, SharedError},
16291770
message::{BaseCommand, Codec, Message},
16301771
proto::{AuthData, CommandAuthChallenge, CommandAuthResponse, CommandConnected},
1631-
TokioExecutor,
16321772
};
16331773

16341774
#[tokio::test]
1635-
#[cfg(feature = "tokio-runtime")]
1775+
#[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))]
16361776
async fn receiver_auth_challenge_test() {
16371777
let (message_tx, message_rx) = mpsc::unbounded();
16381778
let (tx, _) = mpsc::unbounded();
@@ -1690,7 +1830,7 @@ mod tests {
16901830
}
16911831

16921832
#[tokio::test]
1693-
#[cfg(feature = "tokio-runtime")]
1833+
#[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))]
16941834
async fn connection_auth_challenge_test() {
16951835
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
16961836

src/connection_manager.rs

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,17 @@
11
use std::{collections::HashMap, sync::Arc, time::Duration};
22

33
use futures::{channel::oneshot, lock::Mutex};
4+
#[cfg(any(feature = "tokio-runtime", feature = "async-std-runtime"))]
45
use native_tls::Certificate;
56
use rand::Rng;
7+
#[cfg(all(
8+
any(
9+
feature = "tokio-rustls-runtime",
10+
feature = "async-std--rustls-runtime"
11+
),
12+
not(any(feature = "tokio-runtime", feature = "async-std-runtime"))
13+
))]
14+
use rustls::Certificate;
615
use url::Url;
716

817
use crate::{connection::Connection, error::ConnectionError, executor::Executor};
@@ -153,10 +162,20 @@ impl<Exe: Executor> ConnectionManager<Exe> {
153162
.iter()
154163
.rev()
155164
{
165+
#[cfg(any(feature = "tokio-runtime", feature = "async-std-runtime"))]
156166
v.push(
157-
Certificate::from_der(&cert.contents())
167+
Certificate::from_der(cert.contents())
158168
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?,
159169
);
170+
171+
#[cfg(all(
172+
any(
173+
feature = "tokio-rustls-runtime",
174+
feature = "async-std--rustls-runtime"
175+
),
176+
not(any(feature = "tokio-runtime", feature = "async-std-runtime"))
177+
))]
178+
v.push(Certificate(cert.contents().to_vec()));
160179
}
161180
v
162181
}

src/consumer/mod.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -437,11 +437,11 @@ mod tests {
437437
};
438438
use log::LevelFilter;
439439
use regex::Regex;
440-
#[cfg(feature = "tokio-runtime")]
440+
#[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))]
441441
use tokio::time::timeout;
442442

443443
use super::*;
444-
#[cfg(feature = "tokio-runtime")]
444+
#[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))]
445445
use crate::executor::TokioExecutor;
446446
use crate::{
447447
consumer::initial_position::InitialPosition, producer, proto, tests::TEST_LOGGER,
@@ -476,7 +476,7 @@ mod tests {
476476
tag: "multi_consumer",
477477
};
478478
#[tokio::test]
479-
#[cfg(feature = "tokio-runtime")]
479+
#[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))]
480480
async fn multi_consumer() {
481481
let _result = log::set_logger(&MULTI_LOGGER);
482482
log::set_max_level(LevelFilter::Debug);
@@ -567,7 +567,7 @@ mod tests {
567567
}
568568

569569
#[tokio::test]
570-
#[cfg(feature = "tokio-runtime")]
570+
#[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))]
571571
async fn consumer_dropped_with_lingering_acks() {
572572
use rand::{distributions::Alphanumeric, Rng};
573573
let _result = log::set_logger(&TEST_LOGGER);
@@ -664,7 +664,7 @@ mod tests {
664664
}
665665

666666
#[tokio::test]
667-
#[cfg(feature = "tokio-runtime")]
667+
#[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))]
668668
async fn dead_letter_queue() {
669669
let _result = log::set_logger(&TEST_LOGGER);
670670
log::set_max_level(LevelFilter::Debug);
@@ -738,7 +738,7 @@ mod tests {
738738
}
739739

740740
#[tokio::test]
741-
#[cfg(feature = "tokio-runtime")]
741+
#[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))]
742742
async fn failover() {
743743
let _result = log::set_logger(&MULTI_LOGGER);
744744
log::set_max_level(LevelFilter::Debug);
@@ -798,7 +798,7 @@ mod tests {
798798
}
799799

800800
#[tokio::test]
801-
#[cfg(feature = "tokio-runtime")]
801+
#[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))]
802802
async fn seek_single_consumer() {
803803
let _result = log::set_logger(&MULTI_LOGGER);
804804
log::set_max_level(LevelFilter::Debug);
@@ -917,7 +917,7 @@ mod tests {
917917
}
918918

919919
#[tokio::test]
920-
#[cfg(feature = "tokio-runtime")]
920+
#[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))]
921921
async fn schema_test() {
922922
#[derive(Serialize, Deserialize)]
923923
struct TestData {

0 commit comments

Comments
 (0)