Skip to content

refactor: polish session in HTTP handler #18527

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

21 changes: 9 additions & 12 deletions src/query/service/src/servers/http/middleware/session_header.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,18 +117,15 @@ impl ClientSession {
headers: &HeaderMap,
caps: &mut ClientCapabilities,
) -> Result<Option<ClientSession>, String> {
if let Some(v) = headers.get(HEADER_SESSION) {
caps.session_header = true;
let v = v.to_str().unwrap().to_string().trim().to_owned();
let s = if v.is_empty() {
// note that curl -H "X-xx:" not work
Self::new_session(false)
} else {
let header = decode_json_header(HEADER_SESSION, v.as_str())?;
Self::old_session(false, header)
};
Ok(Some(s))
} else if caps.session_header {
if caps.session_header {
if let Some(v) = headers.get(HEADER_SESSION) {
caps.session_header = true;
let v = v.to_str().unwrap().to_string().trim().to_owned();
if !v.is_empty() {
let header = decode_json_header(HEADER_SESSION, &v)?;
return Ok(Some(Self::old_session(false, header)));
};
}
Ok(Some(Self::new_session(false)))
} else {
Ok(None)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,6 @@ async fn query_state_handler(
let http_query_manager = HttpQueryManager::instance();
match http_query_manager.get_query(&query_id) {
Some(query) => {
query.check_client_session_id(&ctx.client_session_id)?;
if let Some(reason) = query.check_removed() {
Err(query_id_removed(&query_id, reason))
} else {
Expand Down
2 changes: 2 additions & 0 deletions tests/sqllogictests/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ async-recursion = { workspace = true }
async-trait = { workspace = true }
bollard = { workspace = true }
clap = { workspace = true }
cookie = { workspace = true }
databend-common-exception = { workspace = true }
env_logger = { workspace = true }
futures-util = { workspace = true }
Expand All @@ -34,6 +35,7 @@ testcontainers = { workspace = true }
testcontainers-modules = { workspace = true, features = ["mysql", "redis"] }
thiserror = { workspace = true }
tokio = { workspace = true }
url = { workspace = true }
walkdir = { workspace = true }

[lints]
Expand Down
62 changes: 62 additions & 0 deletions tests/sqllogictests/src/client/global_cookie_store.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// Copyright 2021 Datafuse Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashMap;
use std::sync::RwLock;

use cookie::Cookie;
use reqwest::cookie::CookieStore;
use reqwest::header::HeaderValue;
use url::Url;

pub(crate) struct GlobalCookieStore {
cookies: RwLock<HashMap<String, Cookie<'static>>>,
}

impl GlobalCookieStore {
pub fn new() -> Self {
GlobalCookieStore {
cookies: RwLock::new(HashMap::new()),
}
}
}

impl CookieStore for GlobalCookieStore {
fn set_cookies(&self, cookie_headers: &mut dyn Iterator<Item = &HeaderValue>, _url: &Url) {
let iter = cookie_headers
.filter_map(|val| std::str::from_utf8(val.as_bytes()).ok())
.filter_map(|kv| Cookie::parse(kv).map(|c| c.into_owned()).ok());

let mut guard = self.cookies.write().unwrap();
for cookie in iter {
guard.insert(cookie.name().to_string(), cookie);
}
}

fn cookies(&self, _url: &Url) -> Option<HeaderValue> {
let guard = self.cookies.read().unwrap();
let s: String = guard
.values()
.map(|cookie| cookie.name_value())
.map(|(name, value)| format!("{name}={value}"))
.collect::<Vec<_>>()
.join("; ");

if s.is_empty() {
return None;
}

HeaderValue::from_str(&s).ok()
}
}
91 changes: 32 additions & 59 deletions tests/sqllogictests/src/client/http_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,29 +13,26 @@
// limitations under the License.

use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use std::time::Instant;

use reqwest::header::HeaderMap;
use reqwest::header::HeaderValue;
use reqwest::Client;
use reqwest::ClientBuilder;
use reqwest::Response;
use serde::Deserialize;
use sqllogictest::DBOutput;
use sqllogictest::DefaultColumnType;

use crate::error::DSqlLogicTestError::Databend;
use crate::client::global_cookie_store::GlobalCookieStore;
use crate::error::Result;
use crate::util::parser_rows;
use crate::util::HttpSessionConf;

const SESSION_HEADER: &str = "X-DATABEND-SESSION";

pub struct HttpClient {
pub client: Client,
pub session_token: String,
pub session_headers: HeaderMap,
pub debug: bool,
pub session: Option<HttpSessionConf>,
pub port: u16,
Expand Down Expand Up @@ -86,59 +83,28 @@ impl HttpClient {
header.insert("Accept", HeaderValue::from_str("application/json").unwrap());
header.insert(
"X-DATABEND-CLIENT-CAPS",
HeaderValue::from_str("session_header").unwrap(),
HeaderValue::from_str("session_cookie").unwrap(),
);
let cookie_provider = GlobalCookieStore::new();
let client = ClientBuilder::new()
.cookie_provider(Arc::new(cookie_provider))
.default_headers(header)
// https://github.com/hyperium/hyper/issues/2136#issuecomment-589488526
.http2_keep_alive_timeout(Duration::from_secs(15))
.pool_max_idle_per_host(0)
.build()?;
let mut session_headers = HeaderMap::new();
session_headers.insert(SESSION_HEADER, HeaderValue::from_str("").unwrap());
let mut res = Self {
client,
session_token: "".to_string(),
session_headers,
session: None,
debug: false,
port,
};
res.login().await?;
Ok(res)
}

async fn update_session_header(&mut self, response: Response) -> Result<Response> {
if let Some(value) = response.headers().get(SESSION_HEADER) {
let session_header = value.to_str().unwrap().to_owned();
if !session_header.is_empty() {
self.session_headers
.insert(SESSION_HEADER, value.to_owned());
return Ok(response);
}
}
let meta = format!("response={response:?}");
let data = response.text().await.unwrap();
Err(Databend(
format!("{} is empty, {meta}, {data}", SESSION_HEADER,).into(),
))
}
let url = format!("http://127.0.0.1:{}/v1/session/login", port);

async fn login(&mut self) -> Result<()> {
let url = format!("http://127.0.0.1:{}/v1/session/login", self.port);
let response = self
.client
let session_token = client
.post(&url)
.headers(self.session_headers.clone())
.body("{}")
.basic_auth("root", Some(""))
.send()
.await
.inspect_err(|e| {
println!("fail to send to {}: {:?}", url, e);
})?;
let response = self.update_session_header(response).await?;
self.session_token = response
})?
.json::<LoginResponse>()
.await
.inspect_err(|e| {
Expand All @@ -147,7 +113,14 @@ impl HttpClient {
.tokens
.unwrap()
.session_token;
Ok(())

Ok(Self {
client,
session_token,
session: None,
debug: false,
port,
})
}

pub async fn query(&mut self, sql: &str) -> Result<DBOutput<DefaultColumnType>> {
Expand Down Expand Up @@ -204,43 +177,43 @@ impl HttpClient {
}

// Send request and get response by json format
async fn post_query(&mut self, sql: &str, url: &str) -> Result<QueryResponse> {
async fn post_query(&self, sql: &str, url: &str) -> Result<QueryResponse> {
let mut query = HashMap::new();
query.insert("sql", serde_json::to_value(sql)?);
if let Some(session) = &self.session {
query.insert("session", serde_json::to_value(session)?);
}
let response = self
Ok(self
.client
.post(url)
.headers(self.session_headers.clone())
.json(&query)
.bearer_auth(&self.session_token)
.send()
.await
.inspect_err(|e| {
println!("fail to send to {}: {:?}", url, e);
})?;
let response = self.update_session_header(response).await?;
Ok(response.json::<QueryResponse>().await.inspect_err(|e| {
println!("fail to decode json when call {}: {:?}", url, e);
})?)
})?
.json::<QueryResponse>()
.await
.inspect_err(|e| {
println!("fail to decode json when call {}: {:?}", url, e);
})?)
}

async fn poll_query_result(&mut self, url: &str) -> Result<QueryResponse> {
let response = self
async fn poll_query_result(&self, url: &str) -> Result<QueryResponse> {
Ok(self
.client
.get(url)
.bearer_auth(&self.session_token)
.headers(self.session_headers.clone())
.send()
.await
.inspect_err(|e| {
println!("fail to send to {}: {:?}", url, e);
})?;
let response = self.update_session_header(response).await?;
Ok(response.json::<QueryResponse>().await.inspect_err(|e| {
println!("fail to decode json when call {}: {:?}", url, e);
})?)
})?
.json::<QueryResponse>()
.await
.inspect_err(|e| {
println!("fail to decode json when call {}: {:?}", url, e);
})?)
}
}
1 change: 1 addition & 0 deletions tests/sqllogictests/src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

mod global_cookie_store;
mod http_client;
mod mysql_client;
mod ttc_client;
Expand Down
Loading