Skip to content

Commit 6734d6c

Browse files
build: allow to use rustls instead of native-tls
* This is used in an effort to remove all dependencies to openssl. Which could be interesting in embedded system or on environment which is difficult to know on which openssl version the software will run it and breaks deployments. * It introduces two compiler feature flags which are `tokio-rustls-runtime` and `async-std-rustls-runtime` that have the same meaning as `tokio-runtime` and `async-std-runtime` except that they use rustls. * There is a safe guard, if we enable both runtimes, this is the native-tls ones that are used to keep consistent with the current behaviour. Signed-off-by: Florentin Dubois <[email protected]>
1 parent 3964ed7 commit 6734d6c

File tree

8 files changed

+265
-51
lines changed

8 files changed

+265
-51
lines changed

Cargo.toml

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,17 @@ regex = "^1.9.1"
3030
bit-vec = "^0.6.3"
3131
futures = "^0.3.28"
3232
futures-io = "^0.3.28"
33-
native-tls = "^0.2.11"
33+
native-tls = { version = "^0.2.11", optional = true }
34+
rustls = { version = "^0.21.5", optional = true }
35+
webpki-roots = { version = "^0.24.0", optional = true }
3436
pem = "^3.0.0"
3537
tokio = { version = "^1.29.1", features = ["rt", "net", "time"], optional = true }
3638
tokio-util = { version = "^0.7.8", features = ["codec"], optional = true }
39+
tokio-rustls = { version = "^0.24.1", optional = true }
3740
tokio-native-tls = { version = "^0.3.1", optional = true }
38-
async-std = {version = "^1.12.0", features = [ "attributes", "unstable" ], optional = true }
41+
async-std = { version = "^1.12.0", features = [ "attributes", "unstable" ], optional = true }
3942
asynchronous-codec = { version = "^0.6.2", optional = true }
43+
async-rustls = { version = "^0.4.0", optional = true }
4044
async-native-tls = { version = "^0.5.0", optional = true }
4145
lz4 = { version = "^1.24.0", optional = true }
4246
flate2 = { version = "^1.0.26", optional = true }
@@ -49,7 +53,7 @@ serde_json = { version = "^1.0.103", optional = true }
4953
tracing = { version = "^0.1.37", optional = true }
5054
async-trait = "^0.1.72"
5155
data-url = { version = "^0.3.0", optional = true }
52-
uuid = {version = "^1.4.1", features = ["v4", "fast-rng"] }
56+
uuid = { version = "^1.4.1", features = ["v4", "fast-rng"] }
5357

5458
[dev-dependencies]
5559
serde = { version = "^1.0.175", features = ["derive"] }
@@ -64,8 +68,10 @@ protobuf-src = { version = "1.1.0", optional = true }
6468
[features]
6569
default = [ "compression", "tokio-runtime", "async-std-runtime", "auth-oauth2"]
6670
compression = [ "lz4", "flate2", "zstd", "snap" ]
67-
tokio-runtime = [ "tokio", "tokio-util", "tokio-native-tls" ]
68-
async-std-runtime = [ "async-std", "asynchronous-codec", "async-native-tls" ]
71+
tokio-runtime = [ "tokio", "tokio-util", "native-tls", "tokio-native-tls" ]
72+
tokio-rustls-runtime = ["tokio", "tokio-util", "tokio-rustls", "rustls", "webpki-roots" ]
73+
async-std-runtime = [ "async-std", "asynchronous-codec", "native-tls", "async-native-tls" ]
74+
async-std-rustls-runtime = ["async-std", "asynchronous-codec", "async-rustls", "rustls", "webpki-roots" ]
6975
auth-oauth2 = [ "openidconnect", "oauth2", "serde", "serde_json", "data-url" ]
7076
telemetry = ["tracing"]
7177
protobuf-src = ["dep:protobuf-src"]

src/connection.rs

Lines changed: 142 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,15 @@ use futures::{
2020
task::{Context, Poll},
2121
Future, FutureExt, Sink, SinkExt, Stream, StreamExt,
2222
};
23+
#[cfg(any(feature = "tokio-runtime", feature = "async-std-runtime"))]
2324
use native_tls::Certificate;
2425
use proto::MessageIdData;
2526
use rand::{seq::SliceRandom, thread_rng};
27+
#[cfg(all(
28+
any(feature = "tokio-rustls-runtime", feature = "async-std-rustls-runtime"),
29+
not(any(feature = "tokio-runtime", feature = "async-std-runtime"))
30+
))]
31+
use rustls::Certificate;
2632
use url::Url;
2733
use uuid::Uuid;
2834

@@ -934,7 +940,69 @@ impl<Exe: Executor> Connection<Exe> {
934940
.await
935941
}
936942
}
937-
#[cfg(not(feature = "tokio-runtime"))]
943+
#[cfg(all(feature = "tokio-rustls-runtime", not(feature = "tokio-runtime")))]
944+
ExecutorKind::Tokio => {
945+
if tls {
946+
let stream = tokio::net::TcpStream::connect(&address).await?;
947+
let mut root_store = rustls::RootCertStore::empty();
948+
for certificate in certificate_chain {
949+
root_store.add(certificate)?;
950+
}
951+
952+
let trust_anchors = webpki_roots::TLS_SERVER_ROOTS.0.iter().fold(
953+
vec![],
954+
|mut acc, trust_anchor| {
955+
acc.push(
956+
rustls::OwnedTrustAnchor::from_subject_spki_name_constraints(
957+
trust_anchor.subject,
958+
trust_anchor.spki,
959+
trust_anchor.name_constraints,
960+
),
961+
);
962+
acc
963+
},
964+
);
965+
966+
root_store.add_server_trust_anchors(trust_anchors.into_iter());
967+
let config = rustls::ClientConfig::builder()
968+
.with_safe_default_cipher_suites()
969+
.with_safe_default_kx_groups()
970+
.with_safe_default_protocol_versions()?
971+
.with_root_certificates(root_store)
972+
.with_no_client_auth();
973+
974+
let cx = tokio_rustls::TlsConnector::from(Arc::new(config));
975+
let stream = cx
976+
.connect(rustls::ServerName::try_from(hostname.as_str())?, stream)
977+
.await
978+
.map(|stream| tokio_util::codec::Framed::new(stream, Codec))?;
979+
980+
Connection::connect(
981+
connection_id,
982+
stream,
983+
auth,
984+
proxy_to_broker_url,
985+
executor,
986+
operation_timeout,
987+
)
988+
.await
989+
} else {
990+
let stream = tokio::net::TcpStream::connect(&address)
991+
.await
992+
.map(|stream| tokio_util::codec::Framed::new(stream, Codec))?;
993+
994+
Connection::connect(
995+
connection_id,
996+
stream,
997+
auth,
998+
proxy_to_broker_url,
999+
executor,
1000+
operation_timeout,
1001+
)
1002+
.await
1003+
}
1004+
}
1005+
#[cfg(all(not(feature = "tokio-runtime"), not(feature = "tokio-rustls-runtime")))]
9381006
ExecutorKind::Tokio => {
9391007
unimplemented!("the tokio-runtime cargo feature is not active");
9401008
}
@@ -980,7 +1048,75 @@ impl<Exe: Executor> Connection<Exe> {
9801048
.await
9811049
}
9821050
}
983-
#[cfg(not(feature = "async-std-runtime"))]
1051+
#[cfg(all(
1052+
feature = "async-std-rustls-runtime",
1053+
not(feature = "async-std-runtime")
1054+
))]
1055+
ExecutorKind::AsyncStd => {
1056+
if tls {
1057+
let stream = async_std::net::TcpStream::connect(&address).await?;
1058+
let mut root_store = rustls::RootCertStore::empty();
1059+
for certificate in certificate_chain {
1060+
root_store.add(certificate)?;
1061+
}
1062+
1063+
let trust_anchors = webpki_roots::TLS_SERVER_ROOTS.0.iter().fold(
1064+
vec![],
1065+
|mut acc, trust_anchor| {
1066+
acc.push(
1067+
rustls::OwnedTrustAnchor::from_subject_spki_name_constraints(
1068+
trust_anchor.subject,
1069+
trust_anchor.spki,
1070+
trust_anchor.name_constraints,
1071+
),
1072+
);
1073+
acc
1074+
},
1075+
);
1076+
1077+
root_store.add_server_trust_anchors(trust_anchors.into_iter());
1078+
let config = rustls::ClientConfig::builder()
1079+
.with_safe_default_cipher_suites()
1080+
.with_safe_default_kx_groups()
1081+
.with_safe_default_protocol_versions()?
1082+
.with_root_certificates(root_store)
1083+
.with_no_client_auth();
1084+
1085+
let connector = async_rustls::TlsConnector::from(Arc::new(config));
1086+
let stream = connector
1087+
.connect(rustls::ServerName::try_from(hostname.as_str())?, stream)
1088+
.await
1089+
.map(|stream| asynchronous_codec::Framed::new(stream, Codec))?;
1090+
1091+
Connection::connect(
1092+
connection_id,
1093+
stream,
1094+
auth,
1095+
proxy_to_broker_url,
1096+
executor,
1097+
operation_timeout,
1098+
)
1099+
.await
1100+
} else {
1101+
let stream = async_std::net::TcpStream::connect(&address)
1102+
.await
1103+
.map(|stream| asynchronous_codec::Framed::new(stream, Codec))?;
1104+
1105+
Connection::connect(
1106+
connection_id,
1107+
stream,
1108+
auth,
1109+
proxy_to_broker_url,
1110+
executor,
1111+
operation_timeout,
1112+
)
1113+
.await
1114+
}
1115+
}
1116+
#[cfg(all(
1117+
not(feature = "async-std-runtime"),
1118+
not(feature = "async-std-rustls-runtime")
1119+
))]
9841120
ExecutorKind::AsyncStd => {
9851121
unimplemented!("the async-std-runtime cargo feature is not active");
9861122
}
@@ -1623,16 +1759,17 @@ mod tests {
16231759
use uuid::Uuid;
16241760

16251761
use super::{Connection, Receiver};
1762+
#[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))]
1763+
use crate::TokioExecutor;
16261764
use crate::{
16271765
authentication::Authentication,
16281766
error::{AuthenticationError, SharedError},
16291767
message::{BaseCommand, Codec, Message},
16301768
proto::{AuthData, CommandAuthChallenge, CommandAuthResponse, CommandConnected},
1631-
TokioExecutor,
16321769
};
16331770

16341771
#[tokio::test]
1635-
#[cfg(feature = "tokio-runtime")]
1772+
#[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))]
16361773
async fn receiver_auth_challenge_test() {
16371774
let (message_tx, message_rx) = mpsc::unbounded();
16381775
let (tx, _) = mpsc::unbounded();
@@ -1690,7 +1827,7 @@ mod tests {
16901827
}
16911828

16921829
#[tokio::test]
1693-
#[cfg(feature = "tokio-runtime")]
1830+
#[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))]
16941831
async fn connection_auth_challenge_test() {
16951832
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
16961833

src/connection_manager.rs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
11
use std::{collections::HashMap, sync::Arc, time::Duration};
22

33
use futures::{channel::oneshot, lock::Mutex};
4+
#[cfg(any(feature = "tokio-runtime", feature = "async-std-runtime"))]
45
use native_tls::Certificate;
56
use rand::Rng;
7+
#[cfg(all(
8+
any(feature = "tokio-rustls-runtime", feature = "async-std-rustls-runtime"),
9+
not(any(feature = "tokio-runtime", feature = "async-std-runtime"))
10+
))]
11+
use rustls::Certificate;
612
use url::Url;
713

814
use crate::{connection::Connection, error::ConnectionError, executor::Executor};
@@ -153,10 +159,20 @@ impl<Exe: Executor> ConnectionManager<Exe> {
153159
.iter()
154160
.rev()
155161
{
162+
#[cfg(any(feature = "tokio-runtime", feature = "async-std-runtime"))]
156163
v.push(
157-
Certificate::from_der(&cert.contents())
164+
Certificate::from_der(cert.contents())
158165
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?,
159166
);
167+
168+
#[cfg(all(
169+
any(
170+
feature = "tokio-rustls-runtime",
171+
feature = "async-std-rustls-runtime"
172+
),
173+
not(any(feature = "tokio-runtime", feature = "async-std-runtime"))
174+
))]
175+
v.push(Certificate(cert.contents().to_vec()));
160176
}
161177
v
162178
}

src/consumer/mod.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -437,11 +437,11 @@ mod tests {
437437
};
438438
use log::LevelFilter;
439439
use regex::Regex;
440-
#[cfg(feature = "tokio-runtime")]
440+
#[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))]
441441
use tokio::time::timeout;
442442

443443
use super::*;
444-
#[cfg(feature = "tokio-runtime")]
444+
#[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))]
445445
use crate::executor::TokioExecutor;
446446
use crate::{
447447
consumer::initial_position::InitialPosition, producer, proto, tests::TEST_LOGGER,
@@ -476,7 +476,7 @@ mod tests {
476476
tag: "multi_consumer",
477477
};
478478
#[tokio::test]
479-
#[cfg(feature = "tokio-runtime")]
479+
#[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))]
480480
async fn multi_consumer() {
481481
let _result = log::set_logger(&MULTI_LOGGER);
482482
log::set_max_level(LevelFilter::Debug);
@@ -567,7 +567,7 @@ mod tests {
567567
}
568568

569569
#[tokio::test]
570-
#[cfg(feature = "tokio-runtime")]
570+
#[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))]
571571
async fn consumer_dropped_with_lingering_acks() {
572572
use rand::{distributions::Alphanumeric, Rng};
573573
let _result = log::set_logger(&TEST_LOGGER);
@@ -664,7 +664,7 @@ mod tests {
664664
}
665665

666666
#[tokio::test]
667-
#[cfg(feature = "tokio-runtime")]
667+
#[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))]
668668
async fn dead_letter_queue() {
669669
let _result = log::set_logger(&TEST_LOGGER);
670670
log::set_max_level(LevelFilter::Debug);
@@ -738,7 +738,7 @@ mod tests {
738738
}
739739

740740
#[tokio::test]
741-
#[cfg(feature = "tokio-runtime")]
741+
#[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))]
742742
async fn failover() {
743743
let _result = log::set_logger(&MULTI_LOGGER);
744744
log::set_max_level(LevelFilter::Debug);
@@ -798,7 +798,7 @@ mod tests {
798798
}
799799

800800
#[tokio::test]
801-
#[cfg(feature = "tokio-runtime")]
801+
#[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))]
802802
async fn seek_single_consumer() {
803803
let _result = log::set_logger(&MULTI_LOGGER);
804804
log::set_max_level(LevelFilter::Debug);
@@ -917,7 +917,7 @@ mod tests {
917917
}
918918

919919
#[tokio::test]
920-
#[cfg(feature = "tokio-runtime")]
920+
#[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))]
921921
async fn schema_test() {
922922
#[derive(Serialize, Deserialize)]
923923
struct TestData {

src/error.rs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,15 @@ pub enum ConnectionError {
8888
Encoding(String),
8989
SocketAddr(String),
9090
UnexpectedResponse(String),
91+
#[cfg(any(feature = "tokio-runtime", feature = "async-std-runtime"))]
9192
Tls(native_tls::Error),
93+
#[cfg(all(
94+
any(feature = "tokio-rustls-runtime", feature = "async-std-rustls-runtime"),
95+
not(any(feature = "tokio-runtime", feature = "async-std-runtime"))
96+
))]
97+
Tls(rustls::Error),
98+
#[cfg(any(feature = "tokio-rustls-runtime", feature = "async-std-rustls-runtime"))]
99+
DnsName(rustls::client::InvalidDnsNameError),
92100
Authentication(AuthenticationError),
93101
NotFound,
94102
Canceled,
@@ -113,13 +121,33 @@ impl From<io::Error> for ConnectionError {
113121
}
114122
}
115123

124+
#[cfg(any(feature = "tokio-runtime", feature = "async-std-runtime"))]
116125
impl From<native_tls::Error> for ConnectionError {
117126
#[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
118127
fn from(err: native_tls::Error) -> Self {
119128
ConnectionError::Tls(err)
120129
}
121130
}
122131

132+
#[cfg(all(
133+
any(feature = "tokio-rustls-runtime", feature = "async-std-rustls-runtime"),
134+
not(any(feature = "tokio-runtime", feature = "async-std-runtime"))
135+
))]
136+
impl From<rustls::Error> for ConnectionError {
137+
#[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
138+
fn from(err: rustls::Error) -> Self {
139+
ConnectionError::Tls(err)
140+
}
141+
}
142+
143+
#[cfg(any(feature = "tokio-rustls-runtime", feature = "async-std-rustls-runtime"))]
144+
impl From<rustls::client::InvalidDnsNameError> for ConnectionError {
145+
#[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
146+
fn from(err: rustls::client::InvalidDnsNameError) -> Self {
147+
ConnectionError::DnsName(err)
148+
}
149+
}
150+
123151
impl From<AuthenticationError> for ConnectionError {
124152
#[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
125153
fn from(err: AuthenticationError) -> Self {
@@ -141,6 +169,8 @@ impl fmt::Display for ConnectionError {
141169
ConnectionError::Encoding(e) => write!(f, "Error encoding message: {e}"),
142170
ConnectionError::SocketAddr(e) => write!(f, "Error obtaining socket address: {e}"),
143171
ConnectionError::Tls(e) => write!(f, "Error connecting TLS stream: {e}"),
172+
#[cfg(any(feature = "tokio-rustls-runtime", feature = "async-std-rustls-runtime"))]
173+
ConnectionError::DnsName(e) => write!(f, "Error resolving hostname: {e}"),
144174
ConnectionError::Authentication(e) => write!(f, "Error authentication: {e}"),
145175
ConnectionError::UnexpectedResponse(e) => {
146176
write!(f, "Unexpected response from pulsar: {e}")

0 commit comments

Comments
 (0)