1515
1616//! S3 client to perform bucket and object operations
1717
18+ use std:: fmt:: Debug ;
1819use std:: fs:: File ;
1920use std:: io:: prelude:: * ;
2021use std:: mem;
2122use std:: path:: { Path , PathBuf } ;
2223use std:: sync:: { Arc , OnceLock } ;
2324
2425use crate :: s3:: builders:: { BucketExists , ComposeSource } ;
26+ pub use crate :: s3:: client:: hooks:: RequestLifecycleHooks ;
2527use crate :: s3:: creds:: Provider ;
2628use crate :: s3:: error:: { Error , ErrorCode , ErrorResponse } ;
27- use crate :: s3:: http:: BaseUrl ;
29+ use crate :: s3:: http:: { BaseUrl , Url } ;
2830use crate :: s3:: multimap:: { Multimap , MultimapExt } ;
2931use crate :: s3:: response:: a_response_traits:: { HasEtagFromHeaders , HasS3Fields } ;
3032use crate :: s3:: response:: * ;
@@ -35,9 +37,10 @@ use crate::s3::utils::{EMPTY_SHA256, sha256_hash_sb, to_amz_date, utc_now};
3537use bytes:: Bytes ;
3638use dashmap:: DashMap ;
3739use http:: HeaderMap ;
38- use hyper:: http:: Method ;
40+ pub use hyper:: http:: Method ;
3941use rand:: Rng ;
4042use reqwest:: Body ;
43+ pub use reqwest:: { Error as ReqwestError , Response } ;
4144
4245mod append_object;
4346mod bucket_exists;
@@ -69,6 +72,7 @@ mod get_object_tagging;
6972mod get_presigned_object_url;
7073mod get_presigned_post_form_data;
7174mod get_region;
75+ pub mod hooks;
7276mod list_buckets;
7377mod list_objects;
7478mod listen_bucket_notification;
@@ -123,6 +127,7 @@ pub const MAX_MULTIPART_COUNT: u16 = 10_000;
123127pub struct ClientBuilder {
124128 base_url : BaseUrl ,
125129 provider : Option < Arc < dyn Provider + Send + Sync + ' static > > ,
130+ client_hooks : Vec < Arc < dyn RequestLifecycleHooks + Send + Sync + ' static > > ,
126131 ssl_cert_file : Option < PathBuf > ,
127132 ignore_cert_check : Option < bool > ,
128133 app_info : Option < ( String , String ) > ,
@@ -138,6 +143,13 @@ impl ClientBuilder {
138143 }
139144 }
140145
146+ /// Add a client hook to the builder. Hooks will be called after each other in
147+ /// order they were added.
148+ pub fn hook ( mut self , hooks : Arc < dyn RequestLifecycleHooks + Send + Sync + ' static > ) -> Self {
149+ self . client_hooks . push ( hooks) ;
150+ self
151+ }
152+
141153 /// Set the credential provider. If not, set anonymous access is used.
142154 pub fn provider < P : Provider + Send + Sync + ' static > ( mut self , provider : Option < P > ) -> Self {
143155 self . provider = provider. map ( |p| Arc :: new ( p) as Arc < dyn Provider + Send + Sync + ' static > ) ;
@@ -209,6 +221,7 @@ impl ClientBuilder {
209221 shared : Arc :: new ( SharedClientItems {
210222 base_url : self . base_url ,
211223 provider : self . provider ,
224+ client_hooks : self . client_hooks ,
212225 ..Default :: default ( )
213226 } ) ,
214227 } )
@@ -427,55 +440,69 @@ impl Client {
427440 body : Option < Arc < SegmentedBytes > > ,
428441 retry : bool ,
429442 ) -> Result < reqwest:: Response , Error > {
430- let url = self . shared . base_url . build_url (
443+ let mut url = self . shared . base_url . build_url (
431444 method,
432445 region,
433446 query_params,
434447 bucket_name,
435448 object_name,
436449 ) ?;
450+ let mut extensions = http:: Extensions :: default ( ) ;
451+ headers. add ( "Host" , url. host_header_value ( ) ) ;
437452
438- {
439- headers. add ( "Host" , url. host_header_value ( ) ) ;
440- let sha256: String = match * method {
441- Method :: PUT | Method :: POST => {
442- if !headers. contains_key ( "Content-Type" ) {
443- headers. add ( "Content-Type" , "application/octet-stream" ) ;
444- }
445- let len: usize = body. as_ref ( ) . map_or ( 0 , |b| b. len ( ) ) ;
446- headers. add ( "Content-Length" , len. to_string ( ) ) ;
447- match body {
448- None => EMPTY_SHA256 . into ( ) ,
449- Some ( ref v) => {
450- let clone = v. clone ( ) ;
451- async_std:: task:: spawn_blocking ( move || sha256_hash_sb ( clone) ) . await
452- }
453- }
453+ let sha256: String = match * method {
454+ Method :: PUT | Method :: POST => {
455+ if !headers. contains_key ( "Content-Type" ) {
456+ headers. add ( "Content-Type" , "application/octet-stream" ) ;
454457 }
455- _ => EMPTY_SHA256 . into ( ) ,
456- } ;
457- headers. add ( "x-amz-content-sha256" , sha256. clone ( ) ) ;
458-
459- let date = utc_now ( ) ;
460- headers. add ( "x-amz-date" , to_amz_date ( date) ) ;
461- if let Some ( p) = & self . shared . provider {
462- let creds = p. fetch ( ) ;
463- if creds. session_token . is_some ( ) {
464- headers. add ( "X-Amz-Security-Token" , creds. session_token . unwrap ( ) ) ;
458+ let len: usize = body. as_ref ( ) . map_or ( 0 , |b| b. len ( ) ) ;
459+ headers. add ( "Content-Length" , len. to_string ( ) ) ;
460+ match body {
461+ None => EMPTY_SHA256 . into ( ) ,
462+ Some ( ref v) => {
463+ let clone = v. clone ( ) ;
464+ async_std:: task:: spawn_blocking ( move || sha256_hash_sb ( clone) ) . await
465+ }
465466 }
466- sign_v4_s3 (
467- method,
468- & url. path ,
469- region,
470- headers,
471- query_params,
472- & creds. access_key ,
473- & creds. secret_key ,
474- & sha256,
475- date,
476- ) ;
477467 }
468+ _ => EMPTY_SHA256 . into ( ) ,
469+ } ;
470+ headers. add ( "x-amz-content-sha256" , sha256. clone ( ) ) ;
471+
472+ let date = utc_now ( ) ;
473+ headers. add ( "x-amz-date" , to_amz_date ( date) ) ;
474+
475+ self . run_before_signing_hooks (
476+ method,
477+ & mut url,
478+ region,
479+ headers,
480+ query_params,
481+ bucket_name,
482+ object_name,
483+ body. clone ( ) ,
484+ & mut extensions,
485+ )
486+ . await ?;
487+
488+ if let Some ( p) = & self . shared . provider {
489+ let creds = p. fetch ( ) ;
490+ if creds. session_token . is_some ( ) {
491+ headers. add ( "X-Amz-Security-Token" , creds. session_token . unwrap ( ) ) ;
492+ }
493+ sign_v4_s3 (
494+ method,
495+ & url. path ,
496+ region,
497+ headers,
498+ query_params,
499+ & creds. access_key ,
500+ & creds. secret_key ,
501+ & sha256,
502+ date,
503+ ) ;
478504 }
505+
479506 let mut req = self . http_client . request ( method. clone ( ) , url. to_string ( ) ) ;
480507
481508 for ( key, values) in headers. iter_all ( ) {
@@ -504,7 +531,7 @@ impl Client {
504531
505532 if ( * method == Method :: PUT ) || ( * method == Method :: POST ) {
506533 //TODO: why-oh-why first collect into a vector and then iterate to a stream?
507- let bytes_vec: Vec < Bytes > = match body {
534+ let bytes_vec: Vec < Bytes > = match body. clone ( ) {
508535 Some ( v) => v. iter ( ) . collect ( ) ,
509536 None => Vec :: new ( ) ,
510537 } ;
@@ -516,8 +543,22 @@ impl Client {
516543 req = req. body ( Body :: wrap_stream ( stream) ) ;
517544 }
518545
519- let resp: reqwest:: Response = req. send ( ) . await ?;
546+ let resp = req. send ( ) . await ;
547+
548+ self . run_after_execute_hooks (
549+ method,
550+ & url,
551+ region,
552+ headers,
553+ query_params,
554+ bucket_name,
555+ object_name,
556+ & resp,
557+ & mut extensions,
558+ )
559+ . await ;
520560
561+ let resp = resp?;
521562 if resp. status ( ) . is_success ( ) {
522563 return Ok ( resp) ;
523564 }
@@ -596,12 +637,71 @@ impl Client {
596637 )
597638 . await
598639 }
640+
641+ async fn run_after_execute_hooks (
642+ & self ,
643+ method : & Method ,
644+ url : & Url ,
645+ region : & str ,
646+ headers : & mut Multimap ,
647+ query_params : & Multimap ,
648+ bucket_name : Option < & str > ,
649+ object_name : Option < & str > ,
650+ resp : & Result < Response , reqwest:: Error > ,
651+ extensions : & mut http:: Extensions ,
652+ ) {
653+ for hook in self . shared . client_hooks . iter ( ) {
654+ hook. after_execute (
655+ method,
656+ url,
657+ region,
658+ headers,
659+ query_params,
660+ bucket_name,
661+ object_name,
662+ resp,
663+ extensions,
664+ )
665+ . await ;
666+ }
667+ }
668+
669+ async fn run_before_signing_hooks (
670+ & self ,
671+ method : & Method ,
672+ url : & mut Url ,
673+ region : & str ,
674+ headers : & mut Multimap ,
675+ query_params : & Multimap ,
676+ bucket_name : Option < & str > ,
677+ object_name : Option < & str > ,
678+ body : Option < Arc < SegmentedBytes > > ,
679+ extensions : & mut http:: Extensions ,
680+ ) -> Result < ( ) , Error > {
681+ for hook in self . shared . client_hooks . iter ( ) {
682+ hook. before_signing_mut (
683+ method,
684+ url,
685+ region,
686+ headers,
687+ query_params,
688+ bucket_name,
689+ object_name,
690+ body. as_deref ( ) ,
691+ extensions,
692+ )
693+ . await
694+ . inspect_err ( |e| log:: warn!( "Hook {} failed {e}" , hook. name( ) ) ) ?;
695+ }
696+ Ok ( ( ) )
697+ }
599698}
600699
601700#[ derive( Clone , Debug , Default ) ]
602701pub ( crate ) struct SharedClientItems {
603702 pub ( crate ) base_url : BaseUrl ,
604703 pub ( crate ) provider : Option < Arc < dyn Provider + Send + Sync + ' static > > ,
704+ client_hooks : Vec < Arc < dyn RequestLifecycleHooks + Send + Sync + ' static > > ,
605705 region_map : DashMap < String , String > ,
606706 express : OnceLock < bool > ,
607707}
0 commit comments