Skip to content

Commit 35c6a3a

Browse files
committed
ci: http client in sql logic test use cookie.
1 parent f91a5c6 commit 35c6a3a

File tree

5 files changed

+101
-59
lines changed

5 files changed

+101
-59
lines changed

Cargo.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/sqllogictests/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ async-recursion = { workspace = true }
1818
async-trait = { workspace = true }
1919
bollard = { workspace = true }
2020
clap = { workspace = true }
21+
cookie = { workspace = true }
2122
databend-common-exception = { workspace = true }
2223
env_logger = { workspace = true }
2324
futures-util = { workspace = true }
@@ -34,6 +35,7 @@ testcontainers = { workspace = true }
3435
testcontainers-modules = { workspace = true, features = ["mysql", "redis"] }
3536
thiserror = { workspace = true }
3637
tokio = { workspace = true }
38+
url = { workspace = true }
3739
walkdir = { workspace = true }
3840

3941
[lints]
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
// Copyright 2021 Datafuse Labs
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
use std::collections::HashMap;
16+
use std::sync::RwLock;
17+
18+
use cookie::Cookie;
19+
use reqwest::cookie::CookieStore;
20+
use reqwest::header::HeaderValue;
21+
use url::Url;
22+
23+
pub(crate) struct GlobalCookieStore {
24+
cookies: RwLock<HashMap<String, Cookie<'static>>>,
25+
}
26+
27+
impl GlobalCookieStore {
28+
pub fn new() -> Self {
29+
GlobalCookieStore {
30+
cookies: RwLock::new(HashMap::new()),
31+
}
32+
}
33+
}
34+
35+
impl CookieStore for GlobalCookieStore {
36+
fn set_cookies(&self, cookie_headers: &mut dyn Iterator<Item = &HeaderValue>, _url: &Url) {
37+
let iter = cookie_headers
38+
.filter_map(|val| std::str::from_utf8(val.as_bytes()).ok())
39+
.filter_map(|kv| Cookie::parse(kv).map(|c| c.into_owned()).ok());
40+
41+
let mut guard = self.cookies.write().unwrap();
42+
for cookie in iter {
43+
guard.insert(cookie.name().to_string(), cookie);
44+
}
45+
}
46+
47+
fn cookies(&self, _url: &Url) -> Option<HeaderValue> {
48+
let guard = self.cookies.read().unwrap();
49+
let s: String = guard
50+
.values()
51+
.map(|cookie| cookie.name_value())
52+
.map(|(name, value)| format!("{name}={value}"))
53+
.collect::<Vec<_>>()
54+
.join("; ");
55+
56+
if s.is_empty() {
57+
return None;
58+
}
59+
60+
HeaderValue::from_str(&s).ok()
61+
}
62+
}

tests/sqllogictests/src/client/http_client.rs

Lines changed: 34 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -13,29 +13,28 @@
1313
// limitations under the License.
1414

1515
use std::collections::HashMap;
16+
use std::sync::Arc;
1617
use std::time::Duration;
1718
use std::time::Instant;
1819

20+
use reqwest::cookie::CookieStore;
1921
use reqwest::header::HeaderMap;
2022
use reqwest::header::HeaderValue;
2123
use reqwest::Client;
2224
use reqwest::ClientBuilder;
23-
use reqwest::Response;
2425
use serde::Deserialize;
2526
use sqllogictest::DBOutput;
2627
use sqllogictest::DefaultColumnType;
28+
use url::Url;
2729

28-
use crate::error::DSqlLogicTestError::Databend;
30+
use crate::client::global_cookie_store::GlobalCookieStore;
2931
use crate::error::Result;
3032
use crate::util::parser_rows;
3133
use crate::util::HttpSessionConf;
3234

33-
const SESSION_HEADER: &str = "X-DATABEND-SESSION";
34-
3535
pub struct HttpClient {
3636
pub client: Client,
3737
pub session_token: String,
38-
pub session_headers: HeaderMap,
3938
pub debug: bool,
4039
pub session: Option<HttpSessionConf>,
4140
pub port: u16,
@@ -86,59 +85,28 @@ impl HttpClient {
8685
header.insert("Accept", HeaderValue::from_str("application/json").unwrap());
8786
header.insert(
8887
"X-DATABEND-CLIENT-CAPS",
89-
HeaderValue::from_str("session_header").unwrap(),
88+
HeaderValue::from_str("session_cookie").unwrap(),
9089
);
90+
let cookie_provider = GlobalCookieStore::new();
9191
let client = ClientBuilder::new()
92+
.cookie_provider(Arc::new(cookie_provider))
9293
.default_headers(header)
9394
// https://github.com/hyperium/hyper/issues/2136#issuecomment-589488526
9495
.http2_keep_alive_timeout(Duration::from_secs(15))
9596
.pool_max_idle_per_host(0)
9697
.build()?;
97-
let mut session_headers = HeaderMap::new();
98-
session_headers.insert(SESSION_HEADER, HeaderValue::from_str("").unwrap());
99-
let mut res = Self {
100-
client,
101-
session_token: "".to_string(),
102-
session_headers,
103-
session: None,
104-
debug: false,
105-
port,
106-
};
107-
res.login().await?;
108-
Ok(res)
109-
}
11098

111-
async fn update_session_header(&mut self, response: Response) -> Result<Response> {
112-
if let Some(value) = response.headers().get(SESSION_HEADER) {
113-
let session_header = value.to_str().unwrap().to_owned();
114-
if !session_header.is_empty() {
115-
self.session_headers
116-
.insert(SESSION_HEADER, value.to_owned());
117-
return Ok(response);
118-
}
119-
}
120-
let meta = format!("response={response:?}");
121-
let data = response.text().await.unwrap();
122-
Err(Databend(
123-
format!("{} is empty, {meta}, {data}", SESSION_HEADER,).into(),
124-
))
125-
}
99+
let url = format!("http://127.0.0.1:{}/v1/session/login", port);
126100

127-
async fn login(&mut self) -> Result<()> {
128-
let url = format!("http://127.0.0.1:{}/v1/session/login", self.port);
129-
let response = self
130-
.client
101+
let session_token = client
131102
.post(&url)
132-
.headers(self.session_headers.clone())
133103
.body("{}")
134104
.basic_auth("root", Some(""))
135105
.send()
136106
.await
137107
.inspect_err(|e| {
138108
println!("fail to send to {}: {:?}", url, e);
139-
})?;
140-
let response = self.update_session_header(response).await?;
141-
self.session_token = response
109+
})?
142110
.json::<LoginResponse>()
143111
.await
144112
.inspect_err(|e| {
@@ -147,7 +115,14 @@ impl HttpClient {
147115
.tokens
148116
.unwrap()
149117
.session_token;
150-
Ok(())
118+
119+
Ok(Self {
120+
client,
121+
session_token,
122+
session: None,
123+
debug: false,
124+
port,
125+
})
151126
}
152127

153128
pub async fn query(&mut self, sql: &str) -> Result<DBOutput<DefaultColumnType>> {
@@ -204,43 +179,43 @@ impl HttpClient {
204179
}
205180

206181
// Send request and get response by json format
207-
async fn post_query(&mut self, sql: &str, url: &str) -> Result<QueryResponse> {
182+
async fn post_query(&self, sql: &str, url: &str) -> Result<QueryResponse> {
208183
let mut query = HashMap::new();
209184
query.insert("sql", serde_json::to_value(sql)?);
210185
if let Some(session) = &self.session {
211186
query.insert("session", serde_json::to_value(session)?);
212187
}
213-
let response = self
188+
Ok(self
214189
.client
215190
.post(url)
216-
.headers(self.session_headers.clone())
217191
.json(&query)
218192
.bearer_auth(&self.session_token)
219193
.send()
220194
.await
221195
.inspect_err(|e| {
222196
println!("fail to send to {}: {:?}", url, e);
223-
})?;
224-
let response = self.update_session_header(response).await?;
225-
Ok(response.json::<QueryResponse>().await.inspect_err(|e| {
226-
println!("fail to decode json when call {}: {:?}", url, e);
227-
})?)
197+
})?
198+
.json::<QueryResponse>()
199+
.await
200+
.inspect_err(|e| {
201+
println!("fail to decode json when call {}: {:?}", url, e);
202+
})?)
228203
}
229204

230-
async fn poll_query_result(&mut self, url: &str) -> Result<QueryResponse> {
231-
let response = self
205+
async fn poll_query_result(&self, url: &str) -> Result<QueryResponse> {
206+
Ok(self
232207
.client
233208
.get(url)
234209
.bearer_auth(&self.session_token)
235-
.headers(self.session_headers.clone())
236210
.send()
237211
.await
238212
.inspect_err(|e| {
239213
println!("fail to send to {}: {:?}", url, e);
240-
})?;
241-
let response = self.update_session_header(response).await?;
242-
Ok(response.json::<QueryResponse>().await.inspect_err(|e| {
243-
println!("fail to decode json when call {}: {:?}", url, e);
244-
})?)
214+
})?
215+
.json::<QueryResponse>()
216+
.await
217+
.inspect_err(|e| {
218+
println!("fail to decode json when call {}: {:?}", url, e);
219+
})?)
245220
}
246221
}

tests/sqllogictests/src/client/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
mod global_cookie_store;
1516
mod http_client;
1617
mod mysql_client;
1718
mod ttc_client;

0 commit comments

Comments
 (0)