Skip to content

Commit f9f5e82

Browse files
build: allow to use rustls instead of native-tls
* This is used in an effort to remove all dependencies to openssl. Which could be interesting in embedded system or on environment which is difficult to know on which openssl version the software will run it and breaks deployments. * It introduces two compiler feature flags which are `tokio-rustls-runtime` and `async-std-rustls-runtime` that have the same meaning as `tokio-runtime` and `async-std-runtime` except that they use rustls. * There is a safe guard, if we enable both runtimes, this is the native-tls ones that are used to keep consistent with the current behaviour. Signed-off-by: Florentin Dubois <[email protected]>
1 parent 3964ed7 commit f9f5e82

File tree

8 files changed

+202
-53
lines changed

8 files changed

+202
-53
lines changed

Cargo.toml

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,17 @@ regex = "^1.9.1"
3030
bit-vec = "^0.6.3"
3131
futures = "^0.3.28"
3232
futures-io = "^0.3.28"
33-
native-tls = "^0.2.11"
33+
native-tls = { version = "^0.2.11", optional = true }
34+
rustls = { version = "^0.21.5", optional = true }
35+
webpki-roots = { version = "^0.24.0", optional = true }
3436
pem = "^3.0.0"
3537
tokio = { version = "^1.29.1", features = ["rt", "net", "time"], optional = true }
3638
tokio-util = { version = "^0.7.8", features = ["codec"], optional = true }
39+
tokio-rustls = { version = "^0.24.1", optional = true }
3740
tokio-native-tls = { version = "^0.3.1", optional = true }
38-
async-std = {version = "^1.12.0", features = [ "attributes", "unstable" ], optional = true }
41+
async-std = { version = "^1.12.0", features = [ "attributes", "unstable" ], optional = true }
3942
asynchronous-codec = { version = "^0.6.2", optional = true }
43+
async-rustls = { version = "^0.4.0", optional = true }
4044
async-native-tls = { version = "^0.5.0", optional = true }
4145
lz4 = { version = "^1.24.0", optional = true }
4246
flate2 = { version = "^1.0.26", optional = true }
@@ -49,7 +53,7 @@ serde_json = { version = "^1.0.103", optional = true }
4953
tracing = { version = "^0.1.37", optional = true }
5054
async-trait = "^0.1.72"
5155
data-url = { version = "^0.3.0", optional = true }
52-
uuid = {version = "^1.4.1", features = ["v4", "fast-rng"] }
56+
uuid = { version = "^1.4.1", features = ["v4", "fast-rng"] }
5357

5458
[dev-dependencies]
5559
serde = { version = "^1.0.175", features = ["derive"] }
@@ -62,10 +66,12 @@ prost-build = "^0.11.9"
6266
protobuf-src = { version = "1.1.0", optional = true }
6367

6468
[features]
65-
default = [ "compression", "tokio-runtime", "async-std-runtime", "auth-oauth2"]
69+
default = [ "compression", "tokio-rustls-runtime", "async-std-rustls-runtime", "auth-oauth2"]
6670
compression = [ "lz4", "flate2", "zstd", "snap" ]
67-
tokio-runtime = [ "tokio", "tokio-util", "tokio-native-tls" ]
68-
async-std-runtime = [ "async-std", "asynchronous-codec", "async-native-tls" ]
71+
tokio-runtime = [ "tokio", "tokio-util", "native-tls", "tokio-native-tls" ]
72+
tokio-rustls-runtime = ["tokio", "tokio-util", "tokio-rustls", "rustls", "webpki-roots" ]
73+
async-std-runtime = [ "async-std", "asynchronous-codec", "native-tls", "async-native-tls" ]
74+
async-std-rustls-runtime = ["async-std", "asynchronous-codec", "async-rustls", "rustls", "webpki-roots" ]
6975
auth-oauth2 = [ "openidconnect", "oauth2", "serde", "serde_json", "data-url" ]
7076
telemetry = ["tracing"]
7177
protobuf-src = ["dep:protobuf-src"]

src/connection.rs

Lines changed: 116 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@ use futures::{
2020
task::{Context, Poll},
2121
Future, FutureExt, Sink, SinkExt, Stream, StreamExt,
2222
};
23+
#[cfg(any(feature = "tokio-runtime", feature = "async-std-runtime"))]
2324
use native_tls::Certificate;
25+
#[cfg(all(any(feature = "tokio-rustls-runtime", feature = "async-std--rustls-runtime"), not(any(feature = "tokio-runtime", feature = "async-std-runtime"))))]
26+
use rustls::Certificate;
2427
use proto::MessageIdData;
2528
use rand::{seq::SliceRandom, thread_rng};
2629
use url::Url;
@@ -934,11 +937,64 @@ impl<Exe: Executor> Connection<Exe> {
934937
.await
935938
}
936939
}
937-
#[cfg(not(feature = "tokio-runtime"))]
940+
#[cfg(all(feature = "tokio-rustls-runtime", not(feature = "tokio-runtime")))]
941+
ExecutorKind::Tokio => {
942+
if tls {
943+
let stream = tokio::net::TcpStream::connect(&address).await?;
944+
let mut root_store = rustls::RootCertStore::empty();
945+
for certificate in certificate_chain {
946+
root_store.add(certificate)?;
947+
}
948+
949+
let trust_anchors = webpki_roots::TLS_SERVER_ROOTS.0.iter().fold(vec![], |mut acc, trust_anchor| {
950+
acc.push(rustls::OwnedTrustAnchor::from_subject_spki_name_constraints(trust_anchor.subject, trust_anchor.spki, trust_anchor.name_constraints));
951+
acc
952+
});
953+
954+
root_store.add_server_trust_anchors(trust_anchors.into_iter());
955+
let config = rustls::ClientConfig::builder()
956+
.with_safe_default_cipher_suites()
957+
.with_safe_default_kx_groups()
958+
.with_safe_default_protocol_versions()?
959+
.with_root_certificates(root_store)
960+
.with_no_client_auth();
961+
962+
let cx = tokio_rustls::TlsConnector::from(Arc::new(config));
963+
let stream = cx
964+
.connect(rustls::ServerName::try_from(hostname.as_str())?, stream)
965+
.await
966+
.map(|stream| tokio_util::codec::Framed::new(stream, Codec))?;
967+
968+
Connection::connect(
969+
connection_id,
970+
stream,
971+
auth,
972+
proxy_to_broker_url,
973+
executor,
974+
operation_timeout,
975+
)
976+
.await
977+
} else {
978+
let stream = tokio::net::TcpStream::connect(&address)
979+
.await
980+
.map(|stream| tokio_util::codec::Framed::new(stream, Codec))?;
981+
982+
Connection::connect(
983+
connection_id,
984+
stream,
985+
auth,
986+
proxy_to_broker_url,
987+
executor,
988+
operation_timeout,
989+
)
990+
.await
991+
}
992+
}
993+
#[cfg(all(not(feature = "tokio-runtime"), not(feature = "tokio-rustls-runtime")))]
938994
ExecutorKind::Tokio => {
939995
unimplemented!("the tokio-runtime cargo feature is not active");
940996
}
941-
#[cfg(feature = "async-std-runtime")]
997+
#[cfg(feature = "async-std-runtime")]
942998
ExecutorKind::AsyncStd => {
943999
if tls {
9441000
let stream = async_std::net::TcpStream::connect(&address).await?;
@@ -980,7 +1036,60 @@ impl<Exe: Executor> Connection<Exe> {
9801036
.await
9811037
}
9821038
}
983-
#[cfg(not(feature = "async-std-runtime"))]
1039+
#[cfg(all(feature = "async-std-rustls-runtime", not(feature = "async-std-runtime")))]
1040+
ExecutorKind::AsyncStd => {
1041+
if tls {
1042+
let stream = async_std::net::TcpStream::connect(&address).await?;
1043+
let mut root_store = rustls::RootCertStore::empty();
1044+
for certificate in certificate_chain {
1045+
root_store.add(certificate)?;
1046+
}
1047+
1048+
let trust_anchors = webpki_roots::TLS_SERVER_ROOTS.0.iter().fold(vec![], |mut acc, trust_anchor| {
1049+
acc.push(rustls::OwnedTrustAnchor::from_subject_spki_name_constraints(trust_anchor.subject, trust_anchor.spki, trust_anchor.name_constraints));
1050+
acc
1051+
});
1052+
1053+
root_store.add_server_trust_anchors(trust_anchors.into_iter());
1054+
let config = rustls::ClientConfig::builder()
1055+
.with_safe_default_cipher_suites()
1056+
.with_safe_default_kx_groups()
1057+
.with_safe_default_protocol_versions()?
1058+
.with_root_certificates(root_store)
1059+
.with_no_client_auth();
1060+
1061+
let connector = async_rustls::TlsConnector::from(Arc::new(config));
1062+
let stream = connector
1063+
.connect(rustls::ServerName::try_from(hostname.as_str())?, stream)
1064+
.await
1065+
.map(|stream| asynchronous_codec::Framed::new(stream, Codec))?;
1066+
1067+
Connection::connect(
1068+
connection_id,
1069+
stream,
1070+
auth,
1071+
proxy_to_broker_url,
1072+
executor,
1073+
operation_timeout,
1074+
)
1075+
.await
1076+
} else {
1077+
let stream = async_std::net::TcpStream::connect(&address)
1078+
.await
1079+
.map(|stream| asynchronous_codec::Framed::new(stream, Codec))?;
1080+
1081+
Connection::connect(
1082+
connection_id,
1083+
stream,
1084+
auth,
1085+
proxy_to_broker_url,
1086+
executor,
1087+
operation_timeout,
1088+
)
1089+
.await
1090+
}
1091+
}
1092+
#[cfg(all(not(feature = "async-std-runtime"), not(feature = "async-std-rustls-runtime")))]
9841093
ExecutorKind::AsyncStd => {
9851094
unimplemented!("the async-std-runtime cargo feature is not active");
9861095
}
@@ -1628,11 +1737,12 @@ mod tests {
16281737
error::{AuthenticationError, SharedError},
16291738
message::{BaseCommand, Codec, Message},
16301739
proto::{AuthData, CommandAuthChallenge, CommandAuthResponse, CommandConnected},
1631-
TokioExecutor,
16321740
};
1741+
#[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))]
1742+
use crate::TokioExecutor;
16331743

16341744
#[tokio::test]
1635-
#[cfg(feature = "tokio-runtime")]
1745+
#[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))]
16361746
async fn receiver_auth_challenge_test() {
16371747
let (message_tx, message_rx) = mpsc::unbounded();
16381748
let (tx, _) = mpsc::unbounded();
@@ -1690,7 +1800,7 @@ mod tests {
16901800
}
16911801

16921802
#[tokio::test]
1693-
#[cfg(feature = "tokio-runtime")]
1803+
#[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))]
16941804
async fn connection_auth_challenge_test() {
16951805
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
16961806

src/connection_manager.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
use std::{collections::HashMap, sync::Arc, time::Duration};
22

33
use futures::{channel::oneshot, lock::Mutex};
4+
#[cfg(any(feature = "tokio-runtime", feature = "async-std-runtime"))]
45
use native_tls::Certificate;
6+
#[cfg(all(any(feature = "tokio-rustls-runtime", feature = "async-std--rustls-runtime"), not(any(feature = "tokio-runtime", feature = "async-std-runtime"))))]
7+
use rustls::Certificate;
58
use rand::Rng;
69
use url::Url;
710

@@ -153,10 +156,16 @@ impl<Exe: Executor> ConnectionManager<Exe> {
153156
.iter()
154157
.rev()
155158
{
159+
#[cfg(any(feature = "tokio-runtime", feature = "async-std-runtime"))]
156160
v.push(
157-
Certificate::from_der(&cert.contents())
161+
Certificate::from_der(cert.contents())
158162
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?,
159163
);
164+
165+
#[cfg(all(any(feature = "tokio-rustls-runtime", feature = "async-std--rustls-runtime"), not(any(feature = "tokio-runtime", feature = "async-std-runtime"))))]
166+
v.push(
167+
Certificate(cert.contents().to_vec())
168+
);
160169
}
161170
v
162171
}

src/consumer/mod.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -437,11 +437,11 @@ mod tests {
437437
};
438438
use log::LevelFilter;
439439
use regex::Regex;
440-
#[cfg(feature = "tokio-runtime")]
440+
#[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))]
441441
use tokio::time::timeout;
442442

443443
use super::*;
444-
#[cfg(feature = "tokio-runtime")]
444+
#[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))]
445445
use crate::executor::TokioExecutor;
446446
use crate::{
447447
consumer::initial_position::InitialPosition, producer, proto, tests::TEST_LOGGER,
@@ -476,7 +476,7 @@ mod tests {
476476
tag: "multi_consumer",
477477
};
478478
#[tokio::test]
479-
#[cfg(feature = "tokio-runtime")]
479+
#[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))]
480480
async fn multi_consumer() {
481481
let _result = log::set_logger(&MULTI_LOGGER);
482482
log::set_max_level(LevelFilter::Debug);
@@ -567,7 +567,7 @@ mod tests {
567567
}
568568

569569
#[tokio::test]
570-
#[cfg(feature = "tokio-runtime")]
570+
#[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))]
571571
async fn consumer_dropped_with_lingering_acks() {
572572
use rand::{distributions::Alphanumeric, Rng};
573573
let _result = log::set_logger(&TEST_LOGGER);
@@ -664,7 +664,7 @@ mod tests {
664664
}
665665

666666
#[tokio::test]
667-
#[cfg(feature = "tokio-runtime")]
667+
#[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))]
668668
async fn dead_letter_queue() {
669669
let _result = log::set_logger(&TEST_LOGGER);
670670
log::set_max_level(LevelFilter::Debug);
@@ -738,7 +738,7 @@ mod tests {
738738
}
739739

740740
#[tokio::test]
741-
#[cfg(feature = "tokio-runtime")]
741+
#[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))]
742742
async fn failover() {
743743
let _result = log::set_logger(&MULTI_LOGGER);
744744
log::set_max_level(LevelFilter::Debug);
@@ -798,7 +798,7 @@ mod tests {
798798
}
799799

800800
#[tokio::test]
801-
#[cfg(feature = "tokio-runtime")]
801+
#[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))]
802802
async fn seek_single_consumer() {
803803
let _result = log::set_logger(&MULTI_LOGGER);
804804
log::set_max_level(LevelFilter::Debug);
@@ -917,7 +917,7 @@ mod tests {
917917
}
918918

919919
#[tokio::test]
920-
#[cfg(feature = "tokio-runtime")]
920+
#[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))]
921921
async fn schema_test() {
922922
#[derive(Serialize, Deserialize)]
923923
struct TestData {

src/error.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,12 @@ pub enum ConnectionError {
8888
Encoding(String),
8989
SocketAddr(String),
9090
UnexpectedResponse(String),
91+
#[cfg(any(feature = "tokio-runtime", feature = "async-std-runtime"))]
9192
Tls(native_tls::Error),
93+
#[cfg(all(any(feature = "tokio-rustls-runtime", feature = "async-std--rustls-runtime"), not(any(feature = "tokio-runtime", feature = "async-std-runtime"))))]
94+
Tls(rustls::Error),
95+
#[cfg(any(feature = "tokio-rustls-runtime", feature = "async-std-rustls-runtime"))]
96+
DnsName(rustls::client::InvalidDnsNameError),
9297
Authentication(AuthenticationError),
9398
NotFound,
9499
Canceled,
@@ -113,13 +118,30 @@ impl From<io::Error> for ConnectionError {
113118
}
114119
}
115120

121+
#[cfg(any(feature = "tokio-runtime", feature = "async-std-runtime"))]
116122
impl From<native_tls::Error> for ConnectionError {
117123
#[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
118124
fn from(err: native_tls::Error) -> Self {
119125
ConnectionError::Tls(err)
120126
}
121127
}
122128

129+
#[cfg(all(any(feature = "tokio-rustls-runtime", feature = "async-std--rustls-runtime"), not(any(feature = "tokio-runtime", feature = "async-std-runtime"))))]
130+
impl From<rustls::Error> for ConnectionError {
131+
#[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
132+
fn from(err: rustls::Error) -> Self {
133+
ConnectionError::Tls(err)
134+
}
135+
}
136+
137+
#[cfg(any(feature = "tokio-rustls-runtime", feature = "async-std-rustls-runtime"))]
138+
impl From<rustls::client::InvalidDnsNameError> for ConnectionError {
139+
#[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
140+
fn from(err: rustls::client::InvalidDnsNameError) -> Self {
141+
ConnectionError::DnsName(err)
142+
}
143+
}
144+
123145
impl From<AuthenticationError> for ConnectionError {
124146
#[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
125147
fn from(err: AuthenticationError) -> Self {
@@ -141,6 +163,8 @@ impl fmt::Display for ConnectionError {
141163
ConnectionError::Encoding(e) => write!(f, "Error encoding message: {e}"),
142164
ConnectionError::SocketAddr(e) => write!(f, "Error obtaining socket address: {e}"),
143165
ConnectionError::Tls(e) => write!(f, "Error connecting TLS stream: {e}"),
166+
#[cfg(any(feature = "tokio-rustls-runtime", feature = "async-std-rustls-runtime"))]
167+
ConnectionError::DnsName(e) => write!(f, "Error resolving hostname: {e}"),
144168
ConnectionError::Authentication(e) => write!(f, "Error authentication: {e}"),
145169
ConnectionError::UnexpectedResponse(e) => {
146170
write!(f, "Unexpected response from pulsar: {e}")

0 commit comments

Comments
 (0)