Skip to content

wip: minio client-hooks #168

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ version = "0.12.18"
default-features = false
features = ["stream"]


[features]
default = ["default-tls", "dep:hmac", "dep:sha2"]
default-tls = ["reqwest/default-tls"]
Expand Down
184 changes: 142 additions & 42 deletions src/s3/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,18 @@

//! S3 client to perform bucket and object operations

use std::fmt::Debug;
use std::fs::File;
use std::io::prelude::*;
use std::mem;
use std::path::{Path, PathBuf};
use std::sync::{Arc, OnceLock};

use crate::s3::builders::{BucketExists, ComposeSource};
pub use crate::s3::client::hooks::RequestLifecycleHooks;
use crate::s3::creds::Provider;
use crate::s3::error::{Error, ErrorCode, ErrorResponse};
use crate::s3::http::BaseUrl;
use crate::s3::http::{BaseUrl, Url};
use crate::s3::multimap::{Multimap, MultimapExt};
use crate::s3::response::a_response_traits::{HasEtagFromHeaders, HasS3Fields};
use crate::s3::response::*;
Expand All @@ -35,9 +37,10 @@ use crate::s3::utils::{EMPTY_SHA256, sha256_hash_sb, to_amz_date, utc_now};
use bytes::Bytes;
use dashmap::DashMap;
use http::HeaderMap;
use hyper::http::Method;
pub use hyper::http::Method;
use rand::Rng;
use reqwest::Body;
pub use reqwest::{Error as ReqwestError, Response};

mod append_object;
mod bucket_exists;
Expand Down Expand Up @@ -69,6 +72,7 @@ mod get_object_tagging;
mod get_presigned_object_url;
mod get_presigned_post_form_data;
mod get_region;
pub mod hooks;
mod list_buckets;
mod list_objects;
mod listen_bucket_notification;
Expand Down Expand Up @@ -123,6 +127,7 @@ pub const MAX_MULTIPART_COUNT: u16 = 10_000;
pub struct ClientBuilder {
base_url: BaseUrl,
provider: Option<Arc<dyn Provider + Send + Sync + 'static>>,
client_hooks: Vec<Arc<dyn RequestLifecycleHooks + Send + Sync + 'static>>,
ssl_cert_file: Option<PathBuf>,
ignore_cert_check: Option<bool>,
app_info: Option<(String, String)>,
Expand All @@ -138,6 +143,13 @@ impl ClientBuilder {
}
}

/// Add a client hook to the builder. Hooks will be called after each other in
/// order they were added.
pub fn hook(mut self, hooks: Arc<dyn RequestLifecycleHooks + Send + Sync + 'static>) -> Self {
self.client_hooks.push(hooks);
self
}

/// Set the credential provider. If not, set anonymous access is used.
pub fn provider<P: Provider + Send + Sync + 'static>(mut self, provider: Option<P>) -> Self {
self.provider = provider.map(|p| Arc::new(p) as Arc<dyn Provider + Send + Sync + 'static>);
Expand Down Expand Up @@ -209,6 +221,7 @@ impl ClientBuilder {
shared: Arc::new(SharedClientItems {
base_url: self.base_url,
provider: self.provider,
client_hooks: self.client_hooks,
..Default::default()
}),
})
Expand Down Expand Up @@ -427,55 +440,69 @@ impl Client {
body: Option<Arc<SegmentedBytes>>,
retry: bool,
) -> Result<reqwest::Response, Error> {
let url = self.shared.base_url.build_url(
let mut url = self.shared.base_url.build_url(
method,
region,
query_params,
bucket_name,
object_name,
)?;
let mut extensions = http::Extensions::default();
headers.add("Host", url.host_header_value());

{
headers.add("Host", url.host_header_value());
let sha256: String = match *method {
Method::PUT | Method::POST => {
if !headers.contains_key("Content-Type") {
headers.add("Content-Type", "application/octet-stream");
}
let len: usize = body.as_ref().map_or(0, |b| b.len());
headers.add("Content-Length", len.to_string());
match body {
None => EMPTY_SHA256.into(),
Some(ref v) => {
let clone = v.clone();
async_std::task::spawn_blocking(move || sha256_hash_sb(clone)).await
}
}
let sha256: String = match *method {
Method::PUT | Method::POST => {
if !headers.contains_key("Content-Type") {
headers.add("Content-Type", "application/octet-stream");
}
_ => EMPTY_SHA256.into(),
};
headers.add("x-amz-content-sha256", sha256.clone());

let date = utc_now();
headers.add("x-amz-date", to_amz_date(date));
if let Some(p) = &self.shared.provider {
let creds = p.fetch();
if creds.session_token.is_some() {
headers.add("X-Amz-Security-Token", creds.session_token.unwrap());
let len: usize = body.as_ref().map_or(0, |b| b.len());
headers.add("Content-Length", len.to_string());
match body {
None => EMPTY_SHA256.into(),
Some(ref v) => {
let clone = v.clone();
async_std::task::spawn_blocking(move || sha256_hash_sb(clone)).await
}
}
sign_v4_s3(
method,
&url.path,
region,
headers,
query_params,
&creds.access_key,
&creds.secret_key,
&sha256,
date,
);
}
_ => EMPTY_SHA256.into(),
};
headers.add("x-amz-content-sha256", sha256.clone());

let date = utc_now();
headers.add("x-amz-date", to_amz_date(date));

self.run_before_signing_hooks(
method,
&mut url,
region,
headers,
query_params,
bucket_name,
object_name,
body.clone(),
&mut extensions,
)
.await?;

if let Some(p) = &self.shared.provider {
let creds = p.fetch();
if creds.session_token.is_some() {
headers.add("X-Amz-Security-Token", creds.session_token.unwrap());
}
sign_v4_s3(
method,
&url.path,
region,
headers,
query_params,
&creds.access_key,
&creds.secret_key,
&sha256,
date,
);
}

let mut req = self.http_client.request(method.clone(), url.to_string());

for (key, values) in headers.iter_all() {
Expand Down Expand Up @@ -504,7 +531,7 @@ impl Client {

if (*method == Method::PUT) || (*method == Method::POST) {
//TODO: why-oh-why first collect into a vector and then iterate to a stream?
let bytes_vec: Vec<Bytes> = match body {
let bytes_vec: Vec<Bytes> = match body.clone() {
Some(v) => v.iter().collect(),
None => Vec::new(),
};
Expand All @@ -516,8 +543,22 @@ impl Client {
req = req.body(Body::wrap_stream(stream));
}

let resp: reqwest::Response = req.send().await?;
let resp = req.send().await;

self.run_after_execute_hooks(
method,
&url,
region,
headers,
query_params,
bucket_name,
object_name,
&resp,
&mut extensions,
)
.await;

let resp = resp?;
if resp.status().is_success() {
return Ok(resp);
}
Expand Down Expand Up @@ -596,12 +637,71 @@ impl Client {
)
.await
}

async fn run_after_execute_hooks(
&self,
method: &Method,
url: &Url,
region: &str,
headers: &mut Multimap,
query_params: &Multimap,
bucket_name: Option<&str>,
object_name: Option<&str>,
resp: &Result<Response, reqwest::Error>,
extensions: &mut http::Extensions,
) {
for hook in self.shared.client_hooks.iter() {
hook.after_execute(
method,
url,
region,
headers,
query_params,
bucket_name,
object_name,
resp,
extensions,
)
.await;
}
}

async fn run_before_signing_hooks(
&self,
method: &Method,
url: &mut Url,
region: &str,
headers: &mut Multimap,
query_params: &Multimap,
bucket_name: Option<&str>,
object_name: Option<&str>,
body: Option<Arc<SegmentedBytes>>,
extensions: &mut http::Extensions,
) -> Result<(), Error> {
for hook in self.shared.client_hooks.iter() {
hook.before_signing_mut(
method,
url,
region,
headers,
query_params,
bucket_name,
object_name,
body.as_deref(),
extensions,
)
.await
.inspect_err(|e| log::warn!("Hook {} failed {e}", hook.name()))?;
}
Ok(())
}
}

#[derive(Clone, Debug, Default)]
pub(crate) struct SharedClientItems {
pub(crate) base_url: BaseUrl,
pub(crate) provider: Option<Arc<dyn Provider + Send + Sync + 'static>>,
client_hooks: Vec<Arc<dyn RequestLifecycleHooks + Send + Sync + 'static>>,
region_map: DashMap<String, String>,
express: OnceLock<bool>,
}
Expand Down
43 changes: 43 additions & 0 deletions src/s3/client/hooks.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
pub use http::Extensions;

use crate::s3::error::Error;
use crate::s3::http::Url;
use crate::s3::multimap::Multimap;
use crate::s3::segmented_bytes::SegmentedBytes;
use http::Method;
use reqwest::Response;
use std::fmt::Debug;

#[async_trait::async_trait]
pub trait RequestLifecycleHooks: Debug {
fn name(&self) -> &'static str;

async fn before_signing_mut(
&self,
_method: &Method,
_url: &mut Url,
_region: &str,
_headers: &mut Multimap,
_query_params: &Multimap,
_bucket_name: Option<&str>,
_object_name: Option<&str>,
_body: Option<&SegmentedBytes>,
_extensions: &mut Extensions,
) -> Result<(), Error> {
Ok(())
}

async fn after_execute(
&self,
_method: &Method,
_url: &Url,
_region: &str,
_headers: &Multimap,
_query_params: &Multimap,
_bucket_name: Option<&str>,
_object_name: Option<&str>,
_resp: &Result<Response, reqwest::Error>,
_extensions: &mut Extensions,
) {
}
}
7 changes: 7 additions & 0 deletions src/s3/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,10 @@ pub enum Error {
NoClientProvided,
TagDecodingError(String, String),
ContentLengthUnknown,
Hook {
source: Box<dyn std::error::Error + Send + Sync>,
name: String,
},
}

impl std::error::Error for Error {}
Expand Down Expand Up @@ -343,6 +347,9 @@ impl fmt::Display for Error {
write!(f, "tag decoding failed: {error_message} on input '{input}'")
}
Error::ContentLengthUnknown => write!(f, "content length is unknown"),
Error::Hook { source, name } => {
write!(f, "{} interceptor failed: '{}'", name, source)
}
}
}
}
Expand Down