1+ extern crate linked_hash_set;
2+ extern crate once_cell;
13extern crate openssl;
24extern crate openssl_probe;
35
6+ use self :: linked_hash_set:: LinkedHashSet ;
7+ use self :: once_cell:: sync:: OnceCell ;
48use self :: openssl:: error:: ErrorStack ;
9+ use self :: openssl:: ex_data:: Index ;
510use self :: openssl:: hash:: MessageDigest ;
611use self :: openssl:: nid:: Nid ;
712use self :: openssl:: pkcs12:: Pkcs12 ;
813use self :: openssl:: pkey:: PKey ;
914use self :: openssl:: ssl:: {
10- self , MidHandshakeSslStream , SslAcceptor , SslConnector , SslContextBuilder , SslMethod ,
11- SslVerifyMode ,
15+ self , MidHandshakeSslStream , Ssl , SslAcceptor , SslConnector , SslContextBuilder , SslMethod ,
16+ SslSession , SslSessionCacheMode , SslSessionRef , SslVerifyMode ,
1217} ;
1318use self :: openssl:: x509:: { store:: X509StoreBuilder , X509VerifyResult , X509 } ;
19+ use std:: borrow:: Borrow ;
20+ use std:: collections:: hash_map:: { Entry , HashMap } ;
1421use std:: error;
1522use std:: fmt;
23+ use std:: hash:: { Hash , Hasher } ;
1624use std:: io;
17- use std:: sync:: Once ;
25+ use std:: sync:: { Arc , Mutex , Once } ;
1826
1927use self :: openssl:: pkey:: Private ;
2028use { Protocol , TlsAcceptorBuilder , TlsConnectorBuilder } ;
@@ -248,6 +256,8 @@ pub struct TlsConnector {
248256 use_sni : bool ,
249257 accept_invalid_hostnames : bool ,
250258 accept_invalid_certs : bool ,
259+ session_tickets_enabled : bool ,
260+ session_cache : Arc < Mutex < SessionCache > > ,
251261}
252262
253263impl TlsConnector {
@@ -297,11 +307,37 @@ impl TlsConnector {
297307 #[ cfg( target_os = "android" ) ]
298308 load_android_root_certs ( & mut connector) ?;
299309
310+ let session_cache = Arc :: new ( Mutex :: new ( SessionCache :: new ( ) ) ) ;
311+ if builder. session_tickets_enabled {
312+ connector. set_session_cache_mode ( SslSessionCacheMode :: CLIENT ) ;
313+
314+ connector. set_new_session_callback ( {
315+ let session_cache = session_cache. clone ( ) ;
316+ move |ssl, session| {
317+ if let Some ( key) = key_index ( ) . ok ( ) . and_then ( |idx| ssl. ex_data ( idx) ) {
318+ if let Ok ( mut session_cache) = session_cache. lock ( ) {
319+ session_cache. insert ( key. clone ( ) , session) ;
320+ }
321+ }
322+ }
323+ } ) ;
324+ connector. set_remove_session_callback ( {
325+ let session_cache = session_cache. clone ( ) ;
326+ move |_, session| {
327+ if let Ok ( mut session_cache) = session_cache. lock ( ) {
328+ session_cache. remove ( session) ;
329+ }
330+ }
331+ } ) ;
332+ }
333+
300334 Ok ( TlsConnector {
301335 connector : connector. build ( ) ,
302336 use_sni : builder. use_sni ,
303337 accept_invalid_hostnames : builder. accept_invalid_hostnames ,
304338 accept_invalid_certs : builder. accept_invalid_certs ,
339+ session_tickets_enabled : builder. session_tickets_enabled ,
340+ session_cache,
305341 } )
306342 }
307343
@@ -317,6 +353,23 @@ impl TlsConnector {
317353 if self . accept_invalid_certs {
318354 ssl. set_verify ( SslVerifyMode :: NONE ) ;
319355 }
356+ if self . session_tickets_enabled {
357+ let key = SessionKey {
358+ host : domain. to_string ( ) ,
359+ } ;
360+
361+ if let Ok ( mut session_cache) = self . session_cache . lock ( ) {
362+ if let Some ( session) = session_cache. get ( & key) {
363+ // Note: the `unsafe`-ty here is because the `session` is required to come from the
364+ // same SSL_CTX that the ssl object (`ssl`) is from, since it maintains internal
365+ // pointers and refcounts. Here, we only have one SSL_CTX, so this is safe.
366+ unsafe { ssl. set_session ( & session) ? } ;
367+ }
368+ }
369+
370+ let idx = key_index ( ) ?;
371+ ssl. set_ex_data ( idx, key) ;
372+ }
320373
321374 let s = ssl. connect ( domain, stream) ?;
322375 Ok ( TlsStream ( s) )
@@ -452,3 +505,151 @@ impl<S: io::Read + io::Write> io::Write for TlsStream<S> {
452505 self . 0 . flush ( )
453506 }
454507}
508+
509+ fn key_index ( ) -> Result < Index < Ssl , SessionKey > , ErrorStack > {
510+ static IDX : OnceCell < Index < Ssl , SessionKey > > = OnceCell :: new ( ) ;
511+ IDX . get_or_try_init ( || Ssl :: new_ex_index ( ) ) . map ( |v| * v)
512+ }
513+
514+ #[ derive( Hash , PartialEq , Eq , Clone ) ]
515+ pub struct SessionKey {
516+ pub host : String ,
517+ }
518+
519+ #[ derive( Clone ) ]
520+ struct HashSession ( SslSession ) ;
521+
522+ impl PartialEq for HashSession {
523+ fn eq ( & self , other : & HashSession ) -> bool {
524+ self . 0 . id ( ) == other. 0 . id ( )
525+ }
526+ }
527+
528+ impl Eq for HashSession { }
529+
530+ impl Hash for HashSession {
531+ fn hash < H > ( & self , state : & mut H )
532+ where
533+ H : Hasher ,
534+ {
535+ self . 0 . id ( ) . hash ( state) ;
536+ }
537+ }
538+
539+ impl Borrow < [ u8 ] > for HashSession {
540+ fn borrow ( & self ) -> & [ u8 ] {
541+ self . 0 . id ( )
542+ }
543+ }
544+
545+ pub struct SessionCache {
546+ sessions : HashMap < SessionKey , LinkedHashSet < HashSession > > ,
547+ reverse : HashMap < HashSession , SessionKey > ,
548+ }
549+
550+ impl SessionCache {
551+ pub fn new ( ) -> SessionCache {
552+ SessionCache {
553+ sessions : HashMap :: new ( ) ,
554+ reverse : HashMap :: new ( ) ,
555+ }
556+ }
557+
558+ pub fn insert ( & mut self , key : SessionKey , session : SslSession ) {
559+ let session = HashSession ( session) ;
560+
561+ self . sessions
562+ . entry ( key. clone ( ) )
563+ . or_insert_with ( LinkedHashSet :: new)
564+ . insert ( session. clone ( ) ) ;
565+ self . reverse . insert ( session. clone ( ) , key) ;
566+ }
567+
568+ pub fn get ( & mut self , key : & SessionKey ) -> Option < SslSession > {
569+ let session = {
570+ let sessions = self . sessions . get_mut ( key) ?;
571+ sessions. front ( ) . cloned ( ) ?. 0
572+ } ;
573+
574+ #[ cfg( ossl111) ]
575+ {
576+ use self :: openssl:: ssl:: SslVersion ;
577+
578+ // https://tools.ietf.org/html/rfc8446#appendix-C.4
579+ // OpenSSL will remove the session from its cache after the handshake completes anyway, but this ensures
580+ // that concurrent handshakes don't end up with the same session.
581+ if session. protocol_version ( ) == SslVersion :: TLS1_3 {
582+ self . remove ( & session) ;
583+ }
584+ }
585+
586+ Some ( session)
587+ }
588+
589+ pub fn remove ( & mut self , session : & SslSessionRef ) {
590+ let key = match self . reverse . remove ( session. id ( ) ) {
591+ Some ( key) => key,
592+ None => return ,
593+ } ;
594+
595+ if let Entry :: Occupied ( mut sessions) = self . sessions . entry ( key) {
596+ sessions. get_mut ( ) . remove ( session. id ( ) ) ;
597+ if sessions. get ( ) . is_empty ( ) {
598+ sessions. remove ( ) ;
599+ }
600+ }
601+ }
602+ }
603+
604+ #[ cfg( test) ]
605+ mod tests {
606+ use std:: io:: { Read , Write } ;
607+ use std:: net:: TcpStream ;
608+
609+ use crate :: TlsConnector ;
610+
611+ fn connect_and_assert ( tls : & TlsConnector , domain : & str , port : u16 , should_resume : bool ) {
612+ let s = TcpStream :: connect ( ( domain, port) ) . unwrap ( ) ;
613+ let mut stream = tls. connect ( domain, s) . unwrap ( ) ;
614+
615+ // Must write to the stream, as OpenSSL doesn't appear to call the
616+ // session callback until we do.
617+ stream. write_all ( b"GET / HTTP/1.0\r \n \r \n " ) . unwrap ( ) ;
618+ let mut result = vec ! [ ] ;
619+ stream. read_to_end ( & mut result) . unwrap ( ) ;
620+
621+ assert_eq ! ( ( stream. 0 ) . 0 . ssl( ) . session_reused( ) , should_resume) ;
622+
623+ // Must shut down properly, or OpenSSL will invalidate the session.
624+ stream. shutdown ( ) . unwrap ( ) ;
625+ }
626+
627+ #[ test]
628+ fn connect_no_session_ticket_resumption ( ) {
629+ let tls = TlsConnector :: new ( ) . unwrap ( ) ;
630+ connect_and_assert ( & tls, "google.com" , 443 , false ) ;
631+ connect_and_assert ( & tls, "google.com" , 443 , false ) ;
632+ }
633+
634+ #[ test]
635+ fn connect_session_ticket_resumption ( ) {
636+ let mut builder = TlsConnector :: builder ( ) ;
637+ builder. session_tickets_enabled ( true ) ;
638+ let tls = builder. build ( ) . unwrap ( ) ;
639+
640+ connect_and_assert ( & tls, "google.com" , 443 , false ) ;
641+ connect_and_assert ( & tls, "google.com" , 443 , true ) ;
642+ }
643+
644+ #[ test]
645+ fn connect_session_ticket_resumption_two_sites ( ) {
646+ let mut builder = TlsConnector :: builder ( ) ;
647+ builder. session_tickets_enabled ( true ) ;
648+ let tls = builder. build ( ) . unwrap ( ) ;
649+
650+ connect_and_assert ( & tls, "google.com" , 443 , false ) ;
651+ connect_and_assert ( & tls, "mozilla.org" , 443 , false ) ;
652+ connect_and_assert ( & tls, "google.com" , 443 , true ) ;
653+ connect_and_assert ( & tls, "mozilla.org" , 443 , true ) ;
654+ }
655+ }
0 commit comments