Skip to content

Commit e91c2a1

Browse files
committed
feat: add chunk_capacity CLI option
Signed-off-by: Xin Liu <[email protected]>
1 parent 9aa1109 commit e91c2a1

File tree

3 files changed: 25 additions and 8 deletions.

src/backend/ggml.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -713,7 +713,7 @@ pub(crate) async fn chunks_handler(mut req: Request<Body>) -> Result<Response<Bo
713713
));
714714
}
715715

716-
match llama_core::rag::chunk_text(&contents, extension) {
716+
match llama_core::rag::chunk_text(&contents, extension, chunks_request.chunk_capacity) {
717717
Ok(chunks) => {
718718
let chunks_response = ChunksResponse {
719719
id: chunks_request.id,
@@ -745,7 +745,10 @@ pub(crate) async fn chunks_handler(mut req: Request<Body>) -> Result<Response<Bo
745745
}
746746
}
747747

748-
pub(crate) async fn doc_to_embeddings(req: Request<Body>) -> Result<Response<Body>, hyper::Error> {
748+
pub(crate) async fn doc_to_embeddings(
749+
req: Request<Body>,
750+
chunk_capacity: usize,
751+
) -> Result<Response<Body>, hyper::Error> {
749752
// upload the target rag document
750753
let file_object = if req.method() == Method::POST {
751754
let boundary = "boundary=";
@@ -907,7 +910,7 @@ pub(crate) async fn doc_to_embeddings(req: Request<Body>) -> Result<Response<Bod
907910
));
908911
}
909912

910-
match llama_core::rag::chunk_text(&contents, extension) {
913+
match llama_core::rag::chunk_text(&contents, extension, chunk_capacity) {
911914
Ok(chunks) => chunks,
912915
Err(e) => return error::internal_server_error(e.to_string()),
913916
}

src/backend/mod.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use hyper::{Body, Request, Response};
55

66
pub(crate) async fn handle_llama_request(
77
req: Request<Body>,
8+
chunk_capacity: usize,
89
) -> Result<Response<Body>, hyper::Error> {
910
match req.uri().path() {
1011
"/v1/chat/completions" => match QDRANT_CONFIG.get() {
@@ -18,7 +19,7 @@ pub(crate) async fn handle_llama_request(
1819
},
1920
"/v1/files" => ggml::files_handler(req).await,
2021
"/v1/chunks" => ggml::chunks_handler(req).await,
21-
"/v1/create/rag" => ggml::doc_to_embeddings(req).await,
22+
"/v1/create/rag" => ggml::doc_to_embeddings(req, chunk_capacity).await,
2223
_ => error::invalid_endpoint(req.uri().path()),
2324
}
2425
}

src/main.rs

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,12 +72,15 @@ struct Cli {
7272
/// Name of Qdrant collection
7373
#[arg(long, default_value = "default")]
7474
qdrant_collection_name: String,
75-
/// Max number of retrieved result
76-
#[arg(long, default_value = "3", value_parser = clap::value_parser!(u64))]
75+
/// Max number of retrieved result (no less than 1)
76+
#[arg(long, default_value = "5", value_parser = clap::value_parser!(u64))]
7777
qdrant_limit: u64,
7878
/// Minimal score threshold for the search result
7979
#[arg(long, default_value = "0.4", value_parser = clap::value_parser!(f32))]
8080
qdrant_score_threshold: f32,
81+
/// Maximum number of tokens each chunk contains
82+
#[arg(long, default_value = "100", value_parser = clap::value_parser!(usize))]
83+
chunk_capacity: usize,
8184
/// Print prompt strings to stdout
8285
#[arg(long)]
8386
log_prompts: bool,
@@ -181,6 +184,10 @@ async fn main() -> Result<(), ServerError> {
181184
.set(qdrant_config)
182185
.map_err(|_| ServerError::Operation("Failed to set `QDRANT_CONFIG`.".to_string()))?;
183186

187+
log(format!(
188+
"[INFO] Chunk capacity (in tokens): {}",
189+
&cli.chunk_capacity
190+
));
184191
log(format!("[INFO] Enable prompt log: {}", &cli.log_prompts));
185192
log(format!("[INFO] Enable plugin log: {}", &cli.log_stat));
186193
log(format!("[INFO] Socket address: {}", &cli.socket_addr));
@@ -230,8 +237,13 @@ async fn main() -> Result<(), ServerError> {
230237

231238
let new_service = make_service_fn(move |_| {
232239
let web_ui = cli.web_ui.to_string_lossy().to_string();
240+
let chunk_capacity = cli.chunk_capacity;
233241

234-
async move { Ok::<_, Error>(service_fn(move |req| handle_request(req, web_ui.clone()))) }
242+
async move {
243+
Ok::<_, Error>(service_fn(move |req| {
244+
handle_request(req, chunk_capacity, web_ui.clone())
245+
}))
246+
}
235247
});
236248

237249
// socket address
@@ -255,6 +267,7 @@ async fn main() -> Result<(), ServerError> {
255267

256268
async fn handle_request(
257269
req: Request<Body>,
270+
chunk_capacity: usize,
258271
web_ui: String,
259272
) -> Result<Response<Body>, hyper::Error> {
260273
let path_str = req.uri().path();
@@ -266,7 +279,7 @@ async fn handle_request(
266279

267280
match root_path.as_str() {
268281
"/echo" => Ok(Response::new(Body::from("echo test"))),
269-
"/v1" => backend::handle_llama_request(req).await,
282+
"/v1" => backend::handle_llama_request(req, chunk_capacity).await,
270283
_ => Ok(static_response(path_str, web_ui)),
271284
}
272285
}

0 commit comments

Comments
 (0)