Skip to content

Commit ba163f9

Browse files
build: allow to use rustls instead of native-tls
* This is used in an effort to remove all dependencies to openssl. Which could be interesting in embedded system or on environment which is difficult to know on which openssl version the software will run it and breaks deployments. * It introduces two compiler feature flags which are `tokio-rustls-runtime` and `async-std-rustls-runtime` that have the same meaning as `tokio-runtime` and `async-std-runtime` except that they use rustls. * There is a safe guard, if we enable both runtimes, this is the native-tls ones that are used to keep consistent with the current behaviour. Signed-off-by: Florentin Dubois <[email protected]>
1 parent 3964ed7 commit ba163f9

File tree

8 files changed

+187
-51
lines changed

8 files changed

+187
-51
lines changed

Cargo.toml

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,16 @@ regex = "^1.9.1"
3030
bit-vec = "^0.6.3"
3131
futures = "^0.3.28"
3232
futures-io = "^0.3.28"
33-
native-tls = "^0.2.11"
33+
native-tls = { version = "^0.2.11", optional = true }
34+
rustls = { version = "^0.21.5", optional = true }
3435
pem = "^3.0.0"
3536
tokio = { version = "^1.29.1", features = ["rt", "net", "time"], optional = true }
3637
tokio-util = { version = "^0.7.8", features = ["codec"], optional = true }
38+
tokio-rustls = { version = "^0.24.1", optional = true }
3739
tokio-native-tls = { version = "^0.3.1", optional = true }
38-
async-std = {version = "^1.12.0", features = [ "attributes", "unstable" ], optional = true }
40+
async-std = { version = "^1.12.0", features = [ "attributes", "unstable" ], optional = true }
3941
asynchronous-codec = { version = "^0.6.2", optional = true }
42+
async-rustls = { version = "^0.4.0", optional = true }
4043
async-native-tls = { version = "^0.5.0", optional = true }
4144
lz4 = { version = "^1.24.0", optional = true }
4245
flate2 = { version = "^1.0.26", optional = true }
@@ -49,7 +52,7 @@ serde_json = { version = "^1.0.103", optional = true }
4952
tracing = { version = "^0.1.37", optional = true }
5053
async-trait = "^0.1.72"
5154
data-url = { version = "^0.3.0", optional = true }
52-
uuid = {version = "^1.4.1", features = ["v4", "fast-rng"] }
55+
uuid = { version = "^1.4.1", features = ["v4", "fast-rng"] }
5356

5457
[dev-dependencies]
5558
serde = { version = "^1.0.175", features = ["derive"] }
@@ -64,8 +67,10 @@ protobuf-src = { version = "1.1.0", optional = true }
6467
[features]
6568
default = [ "compression", "tokio-runtime", "async-std-runtime", "auth-oauth2"]
6669
compression = [ "lz4", "flate2", "zstd", "snap" ]
67-
tokio-runtime = [ "tokio", "tokio-util", "tokio-native-tls" ]
68-
async-std-runtime = [ "async-std", "asynchronous-codec", "async-native-tls" ]
70+
tokio-runtime = [ "tokio", "tokio-util", "native-tls", "tokio-native-tls" ]
71+
tokio-rustls-runtime = ["tokio", "tokio-util", "tokio-rustls", "rustls" ]
72+
async-std-runtime = [ "async-std", "asynchronous-codec", "native-tls", "async-native-tls" ]
73+
async-std-rustls-runtime = ["async-std", "asynchronous-codec", "async-rustls", "rustls" ]
6974
auth-oauth2 = [ "openidconnect", "oauth2", "serde", "serde_json", "data-url" ]
7075
telemetry = ["tracing"]
7176
protobuf-src = ["dep:protobuf-src"]

src/connection.rs

Lines changed: 104 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@ use futures::{
2020
task::{Context, Poll},
2121
Future, FutureExt, Sink, SinkExt, Stream, StreamExt,
2222
};
23+
#[cfg(any(feature = "tokio-runtime", feature = "async-std-runtime"))]
2324
use native_tls::Certificate;
25+
#[cfg(all(any(feature = "tokio-rustls-runtime", feature = "async-std--rustls-runtime"), not(any(feature = "tokio-runtime", feature = "async-std-runtime"))))]
26+
use rustls::Certificate;
2427
use proto::MessageIdData;
2528
use rand::{seq::SliceRandom, thread_rng};
2629
use url::Url;
@@ -934,11 +937,58 @@ impl<Exe: Executor> Connection<Exe> {
934937
.await
935938
}
936939
}
937-
#[cfg(not(feature = "tokio-runtime"))]
940+
#[cfg(all(feature = "tokio-rustls-runtime", not(feature = "tokio-runtime")))]
941+
ExecutorKind::Tokio => {
942+
if tls {
943+
let stream = tokio::net::TcpStream::connect(&address).await?;
944+
let mut root_store = rustls::RootCertStore::empty();
945+
for certificate in certificate_chain {
946+
root_store.add(certificate)?;
947+
}
948+
949+
let config = rustls::ClientConfig::builder()
950+
.with_safe_default_cipher_suites()
951+
.with_safe_default_kx_groups()
952+
.with_safe_default_protocol_versions()?
953+
.with_root_certificates(root_store)
954+
.with_no_client_auth();
955+
956+
let cx = tokio_rustls::TlsConnector::from(Arc::new(config));
957+
let stream = cx
958+
.connect(rustls::ServerName::try_from(hostname.as_str())?, stream)
959+
.await
960+
.map(|stream| tokio_util::codec::Framed::new(stream, Codec))?;
961+
962+
Connection::connect(
963+
connection_id,
964+
stream,
965+
auth,
966+
proxy_to_broker_url,
967+
executor,
968+
operation_timeout,
969+
)
970+
.await
971+
} else {
972+
let stream = tokio::net::TcpStream::connect(&address)
973+
.await
974+
.map(|stream| tokio_util::codec::Framed::new(stream, Codec))?;
975+
976+
Connection::connect(
977+
connection_id,
978+
stream,
979+
auth,
980+
proxy_to_broker_url,
981+
executor,
982+
operation_timeout,
983+
)
984+
.await
985+
}
986+
}
987+
#[cfg(all(not(feature = "tokio-runtime"), not(feature = "tokio-rustls-runtime")))]
938988
ExecutorKind::Tokio => {
939989
unimplemented!("the tokio-runtime cargo feature is not active");
940990
}
941-
#[cfg(feature = "async-std-runtime")]
991+
#[cfg(feature = "async-std-runtime")]
942992
ExecutorKind::AsyncStd => {
943993
if tls {
944994
let stream = async_std::net::TcpStream::connect(&address).await?;
@@ -980,7 +1030,54 @@ impl<Exe: Executor> Connection<Exe> {
9801030
.await
9811031
}
9821032
}
983-
#[cfg(not(feature = "async-std-runtime"))]
1033+
#[cfg(all(feature = "async-std-rustls-runtime", not(feature = "async-std-runtime")))]
1034+
ExecutorKind::AsyncStd => {
1035+
if tls {
1036+
let stream = async_std::net::TcpStream::connect(&address).await?;
1037+
let mut root_store = rustls::RootCertStore::empty();
1038+
for certificate in certificate_chain {
1039+
root_store.add(certificate)?;
1040+
}
1041+
1042+
let config = rustls::ClientConfig::builder()
1043+
.with_safe_default_cipher_suites()
1044+
.with_safe_default_kx_groups()
1045+
.with_safe_default_protocol_versions()?
1046+
.with_root_certificates(root_store)
1047+
.with_no_client_auth();
1048+
1049+
let connector = async_rustls::TlsConnector::from(Arc::new(config));
1050+
let stream = connector
1051+
.connect(rustls::ServerName::try_from(hostname.as_str())?, stream)
1052+
.await
1053+
.map(|stream| asynchronous_codec::Framed::new(stream, Codec))?;
1054+
1055+
Connection::connect(
1056+
connection_id,
1057+
stream,
1058+
auth,
1059+
proxy_to_broker_url,
1060+
executor,
1061+
operation_timeout,
1062+
)
1063+
.await
1064+
} else {
1065+
let stream = async_std::net::TcpStream::connect(&address)
1066+
.await
1067+
.map(|stream| asynchronous_codec::Framed::new(stream, Codec))?;
1068+
1069+
Connection::connect(
1070+
connection_id,
1071+
stream,
1072+
auth,
1073+
proxy_to_broker_url,
1074+
executor,
1075+
operation_timeout,
1076+
)
1077+
.await
1078+
}
1079+
}
1080+
#[cfg(all(not(feature = "async-std-runtime"), not(feature = "async-std-rustls-runtime")))]
9841081
ExecutorKind::AsyncStd => {
9851082
unimplemented!("the async-std-runtime cargo feature is not active");
9861083
}
@@ -1628,11 +1725,12 @@ mod tests {
16281725
error::{AuthenticationError, SharedError},
16291726
message::{BaseCommand, Codec, Message},
16301727
proto::{AuthData, CommandAuthChallenge, CommandAuthResponse, CommandConnected},
1631-
TokioExecutor,
16321728
};
1729+
#[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))]
1730+
use crate::TokioExecutor;
16331731

16341732
#[tokio::test]
1635-
#[cfg(feature = "tokio-runtime")]
1733+
#[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))]
16361734
async fn receiver_auth_challenge_test() {
16371735
let (message_tx, message_rx) = mpsc::unbounded();
16381736
let (tx, _) = mpsc::unbounded();
@@ -1690,7 +1788,7 @@ mod tests {
16901788
}
16911789

16921790
#[tokio::test]
1693-
#[cfg(feature = "tokio-runtime")]
1791+
#[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))]
16941792
async fn connection_auth_challenge_test() {
16951793
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
16961794

src/connection_manager.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
use std::{collections::HashMap, sync::Arc, time::Duration};
22

33
use futures::{channel::oneshot, lock::Mutex};
4+
#[cfg(any(feature = "tokio-runtime", feature = "async-std-runtime"))]
45
use native_tls::Certificate;
6+
#[cfg(all(any(feature = "tokio-rustls-runtime", feature = "async-std--rustls-runtime"), not(any(feature = "tokio-runtime", feature = "async-std-runtime"))))]
7+
use rustls::Certificate;
58
use rand::Rng;
69
use url::Url;
710

@@ -153,10 +156,16 @@ impl<Exe: Executor> ConnectionManager<Exe> {
153156
.iter()
154157
.rev()
155158
{
159+
#[cfg(any(feature = "tokio-runtime", feature = "async-std-runtime"))]
156160
v.push(
157161
Certificate::from_der(&cert.contents())
158162
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?,
159163
);
164+
165+
#[cfg(all(any(feature = "tokio-rustls-runtime", feature = "async-std--rustls-runtime"), not(any(feature = "tokio-runtime", feature = "async-std-runtime"))))]
166+
v.push(
167+
Certificate(cert.contents().to_vec())
168+
);
160169
}
161170
v
162171
}

src/consumer/mod.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -437,11 +437,11 @@ mod tests {
437437
};
438438
use log::LevelFilter;
439439
use regex::Regex;
440-
#[cfg(feature = "tokio-runtime")]
440+
#[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))]
441441
use tokio::time::timeout;
442442

443443
use super::*;
444-
#[cfg(feature = "tokio-runtime")]
444+
#[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))]
445445
use crate::executor::TokioExecutor;
446446
use crate::{
447447
consumer::initial_position::InitialPosition, producer, proto, tests::TEST_LOGGER,
@@ -476,7 +476,7 @@ mod tests {
476476
tag: "multi_consumer",
477477
};
478478
#[tokio::test]
479-
#[cfg(feature = "tokio-runtime")]
479+
#[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))]
480480
async fn multi_consumer() {
481481
let _result = log::set_logger(&MULTI_LOGGER);
482482
log::set_max_level(LevelFilter::Debug);
@@ -567,7 +567,7 @@ mod tests {
567567
}
568568

569569
#[tokio::test]
570-
#[cfg(feature = "tokio-runtime")]
570+
#[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))]
571571
async fn consumer_dropped_with_lingering_acks() {
572572
use rand::{distributions::Alphanumeric, Rng};
573573
let _result = log::set_logger(&TEST_LOGGER);
@@ -664,7 +664,7 @@ mod tests {
664664
}
665665

666666
#[tokio::test]
667-
#[cfg(feature = "tokio-runtime")]
667+
#[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))]
668668
async fn dead_letter_queue() {
669669
let _result = log::set_logger(&TEST_LOGGER);
670670
log::set_max_level(LevelFilter::Debug);
@@ -738,7 +738,7 @@ mod tests {
738738
}
739739

740740
#[tokio::test]
741-
#[cfg(feature = "tokio-runtime")]
741+
#[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))]
742742
async fn failover() {
743743
let _result = log::set_logger(&MULTI_LOGGER);
744744
log::set_max_level(LevelFilter::Debug);
@@ -798,7 +798,7 @@ mod tests {
798798
}
799799

800800
#[tokio::test]
801-
#[cfg(feature = "tokio-runtime")]
801+
#[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))]
802802
async fn seek_single_consumer() {
803803
let _result = log::set_logger(&MULTI_LOGGER);
804804
log::set_max_level(LevelFilter::Debug);
@@ -917,7 +917,7 @@ mod tests {
917917
}
918918

919919
#[tokio::test]
920-
#[cfg(feature = "tokio-runtime")]
920+
#[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))]
921921
async fn schema_test() {
922922
#[derive(Serialize, Deserialize)]
923923
struct TestData {

src/error.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,12 @@ pub enum ConnectionError {
8888
Encoding(String),
8989
SocketAddr(String),
9090
UnexpectedResponse(String),
91+
#[cfg(any(feature = "tokio-runtime", feature = "async-std-runtime"))]
9192
Tls(native_tls::Error),
93+
#[cfg(all(any(feature = "tokio-rustls-runtime", feature = "async-std--rustls-runtime"), not(any(feature = "tokio-runtime", feature = "async-std-runtime"))))]
94+
Tls(rustls::Error),
95+
#[cfg(any(feature = "tokio-rustls-runtime", feature = "async-std-rustls-runtime"))]
96+
DnsName(rustls::client::InvalidDnsNameError),
9297
Authentication(AuthenticationError),
9398
NotFound,
9499
Canceled,
@@ -113,13 +118,30 @@ impl From<io::Error> for ConnectionError {
113118
}
114119
}
115120

121+
#[cfg(any(feature = "tokio-runtime", feature = "async-std-runtime"))]
116122
impl From<native_tls::Error> for ConnectionError {
117123
#[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
118124
fn from(err: native_tls::Error) -> Self {
119125
ConnectionError::Tls(err)
120126
}
121127
}
122128

129+
#[cfg(all(any(feature = "tokio-rustls-runtime", feature = "async-std--rustls-runtime"), not(any(feature = "tokio-runtime", feature = "async-std-runtime"))))]
130+
impl From<rustls::Error> for ConnectionError {
131+
#[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
132+
fn from(err: rustls::Error) -> Self {
133+
ConnectionError::Tls(err)
134+
}
135+
}
136+
137+
#[cfg(any(feature = "tokio-rustls-runtime", feature = "async-std-rustls-runtime"))]
138+
impl From<rustls::client::InvalidDnsNameError> for ConnectionError {
139+
#[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
140+
fn from(err: rustls::client::InvalidDnsNameError) -> Self {
141+
ConnectionError::DnsName(err)
142+
}
143+
}
144+
123145
impl From<AuthenticationError> for ConnectionError {
124146
#[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
125147
fn from(err: AuthenticationError) -> Self {
@@ -141,6 +163,8 @@ impl fmt::Display for ConnectionError {
141163
ConnectionError::Encoding(e) => write!(f, "Error encoding message: {e}"),
142164
ConnectionError::SocketAddr(e) => write!(f, "Error obtaining socket address: {e}"),
143165
ConnectionError::Tls(e) => write!(f, "Error connecting TLS stream: {e}"),
166+
#[cfg(any(feature = "tokio-rustls-runtime", feature = "async-std-rustls-runtime"))]
167+
ConnectionError::DnsName(e) => write!(f, "Error resolving hostname: {e}"),
144168
ConnectionError::Authentication(e) => write!(f, "Error authentication: {e}"),
145169
ConnectionError::UnexpectedResponse(e) => {
146170
write!(f, "Unexpected response from pulsar: {e}")

0 commit comments

Comments
 (0)