Skip to content

Commit 1cabcfc

Browse files
authored
Merge pull request #19 from serprex/not-infallible
Fix MakeTlsConnect error to be rustls::pki_types::InvalidDnsNameError
2 parents a588474 + 2c6cfad commit 1cabcfc

File tree

1 file changed

+15
-20
lines changed

1 file changed

+15
-20
lines changed

src/lib.rs

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ use DigestAlgorithm::{Sha1, Sha256, Sha384, Sha512};
1010

1111
use futures::future::{FutureExt, TryFutureExt};
1212
use ring::digest;
13-
use rustls::ClientConfig;
1413
use rustls::pki_types::ServerName;
14+
use rustls::ClientConfig;
1515
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
1616
use tokio_postgres::tls::{ChannelBinding, MakeTlsConnect, TlsConnect};
1717
use tokio_rustls::{client::TlsStream, TlsConnector};
@@ -39,21 +39,19 @@ where
3939
{
4040
type Stream = RustlsStream<S>;
4141
type TlsConnect = RustlsConnect;
42-
type Error = io::Error;
42+
type Error = rustls::pki_types::InvalidDnsNameError;
4343

44-
fn make_tls_connect(&mut self, hostname: &str) -> io::Result<RustlsConnect> {
45-
ServerName::try_from(hostname)
46-
.map(|dns_name| {
47-
RustlsConnect(Some(RustlsConnectData {
48-
hostname: dns_name.to_owned(),
49-
connector: Arc::clone(&self.config).into(),
50-
}))
44+
fn make_tls_connect(&mut self, hostname: &str) -> Result<RustlsConnect, Self::Error> {
45+
ServerName::try_from(hostname).map(|dns_name| {
46+
RustlsConnect(RustlsConnectData {
47+
hostname: dns_name.to_owned(),
48+
connector: Arc::clone(&self.config).into(),
5149
})
52-
.or(Ok(RustlsConnect(None)))
50+
})
5351
}
5452
}
5553

56-
pub struct RustlsConnect(Option<RustlsConnectData>);
54+
pub struct RustlsConnect(RustlsConnectData);
5755

5856
struct RustlsConnectData {
5957
hostname: ServerName<'static>,
@@ -69,14 +67,11 @@ where
6967
type Future = Pin<Box<dyn Future<Output = io::Result<RustlsStream<S>>> + Send>>;
7068

7169
fn connect(self, stream: S) -> Self::Future {
72-
match self.0 {
73-
None => Box::pin(core::future::ready(Err(io::ErrorKind::InvalidInput.into()))),
74-
Some(c) => c
75-
.connector
76-
.connect(c.hostname, stream)
77-
.map_ok(|s| RustlsStream(Box::pin(s)))
78-
.boxed(),
79-
}
70+
self.0
71+
.connector
72+
.connect(self.0.hostname, stream)
73+
.map_ok(|s| RustlsStream(Box::pin(s)))
74+
.boxed()
8075
}
8176
}
8277

@@ -152,12 +147,12 @@ where
152147
mod tests {
153148
use super::*;
154149
use futures::future::TryFutureExt;
150+
use rustls::pki_types::{CertificateDer, UnixTime};
155151
use rustls::{
156152
client::danger::ServerCertVerifier,
157153
client::danger::{HandshakeSignatureValid, ServerCertVerified},
158154
Error, SignatureScheme,
159155
};
160-
use rustls::pki_types::{CertificateDer, UnixTime};
161156

162157
#[derive(Debug)]
163158
struct AcceptAllVerifier {}

0 commit comments

Comments
 (0)