11use std:: collections:: HashMap ;
2+ use std:: sync:: Arc ;
23
34use crate :: error:: MutinyError ;
45use crate :: storage:: MutinyStorage ;
56use core:: time:: Duration ;
7+ use gloo_net:: websocket:: futures:: WebSocket ;
68use hex_conservative:: DisplayHex ;
79use once_cell:: sync:: Lazy ;
810use payjoin:: receive:: v2:: Enrolled ;
@@ -69,16 +71,67 @@ impl<S: MutinyStorage> PayjoinStorage for S {
6971 }
7072}
7173
72- pub async fn fetch_ohttp_keys ( _ohttp_relay : Url , directory : Url ) -> Result < OhttpKeys , Error > {
73- let http_client = reqwest :: Client :: builder ( ) . build ( ) . unwrap ( ) ;
74+ pub async fn fetch_ohttp_keys ( ohttp_relay : Url , directory : Url ) -> Result < OhttpKeys , Error > {
75+ use futures_util :: { AsyncReadExt , AsyncWriteExt } ;
7476
75- let ohttp_keys_res = http_client
76- . get ( format ! ( "{}/ohttp-keys" , directory. as_ref( ) ) )
77- . send ( )
78- . await ?
79- . bytes ( )
80- . await ?;
81- Ok ( OhttpKeys :: decode ( ohttp_keys_res. as_ref ( ) ) . map_err ( |_| Error :: OhttpDecodeFailed ) ?)
77+ let tls_connector = {
78+ let root_store = futures_rustls:: rustls:: RootCertStore {
79+ roots : webpki_roots:: TLS_SERVER_ROOTS . iter ( ) . cloned ( ) . collect ( ) ,
80+ } ;
81+ let config = futures_rustls:: rustls:: ClientConfig :: builder ( )
82+ . with_root_certificates ( root_store)
83+ . with_no_client_auth ( ) ;
84+ futures_rustls:: TlsConnector :: from ( Arc :: new ( config) )
85+ } ;
86+ let directory_host = directory. host_str ( ) . ok_or ( Error :: BadDirectoryHost ) ?;
87+ let domain = futures_rustls:: rustls:: pki_types:: ServerName :: try_from ( directory_host)
88+ . map_err ( |_| Error :: BadDirectoryHost ) ?
89+ . to_owned ( ) ;
90+
91+ let ws = WebSocket :: open ( & format ! (
92+ "wss://{}:443" ,
93+ ohttp_relay. host_str( ) . ok_or( Error :: BadOhttpWsHost ) ?
94+ ) )
95+ . map_err ( |_| Error :: BadOhttpWsHost ) ?;
96+
97+ let mut tls_stream = tls_connector
98+ . connect ( domain, ws)
99+ . await
100+ . map_err ( |e| Error :: RequestFailed ( e. to_string ( ) ) ) ?;
101+ let ohttp_keys_req = format ! (
102+ "GET /ohttp-keys HTTP/1.1\r \n Host: {}\r \n Connection: close\r \n \r \n " ,
103+ directory_host
104+ ) ;
105+ tls_stream
106+ . write_all ( ohttp_keys_req. as_bytes ( ) )
107+ . await
108+ . map_err ( |e| Error :: RequestFailed ( e. to_string ( ) ) ) ?;
109+ tls_stream. flush ( ) . await . unwrap ( ) ;
110+ let mut response_bytes = Vec :: new ( ) ;
111+ tls_stream. read_to_end ( & mut response_bytes) . await . unwrap ( ) ;
112+ let ( _headers, res_body) = separate_headers_and_body ( & response_bytes) ?;
113+ payjoin:: OhttpKeys :: decode ( & res_body) . map_err ( |_| Error :: OhttpDecodeFailed )
114+ }
115+
116+ fn separate_headers_and_body ( response_bytes : & [ u8 ] ) -> Result < ( & [ u8 ] , & [ u8 ] ) , Error > {
117+ let separator = b"\r \n \r \n " ;
118+
119+ // Search for the separator
120+ if let Some ( position) = response_bytes
121+ . windows ( separator. len ( ) )
122+ . position ( |window| window == separator)
123+ {
124+ // The body starts immediately after the separator
125+ let body_start_index = position + separator. len ( ) ;
126+ let headers = & response_bytes[ ..position] ;
127+ let body = & response_bytes[ body_start_index..] ;
128+
129+ Ok ( ( headers, body) )
130+ } else {
131+ Err ( Error :: RequestFailed (
132+ "No header-body separator found in the response" . to_string ( ) ,
133+ ) )
134+ }
82135}
83136
84137#[ derive( Debug ) ]
@@ -90,6 +143,9 @@ pub enum Error {
90143 OhttpDecodeFailed ,
91144 Shutdown ,
92145 SessionExpired ,
146+ BadDirectoryHost ,
147+ BadOhttpWsHost ,
148+ RequestFailed ( String ) ,
93149}
94150
95151impl std:: error:: Error for Error { }
@@ -104,6 +160,9 @@ impl std::fmt::Display for Error {
104160 Error :: OhttpDecodeFailed => write ! ( f, "Failed to decode ohttp keys" ) ,
105161 Error :: Shutdown => write ! ( f, "Payjoin stopped by application shutdown" ) ,
106162 Error :: SessionExpired => write ! ( f, "Payjoin session expired. Create a new payment request and have the sender try again." ) ,
163+ Error :: BadDirectoryHost => write ! ( f, "Bad directory host" ) ,
164+ Error :: BadOhttpWsHost => write ! ( f, "Bad ohttp ws host" ) ,
165+ Error :: RequestFailed ( e) => write ! ( f, "Request failed: {}" , e) ,
107166 }
108167 }
109168}
0 commit comments