Reconstructed patch. The source text had been collapsed onto a few long lines
and generic arguments such as `<String>` had been stripped from it. Line
breaks, four-space indentation and the stripped generic arguments are restored
on a best-effort basis; verify context lines against the repository before
applying. The whitespace-only change to the "OR semantics" comment is kept, but
the original whitespace difference between its two lines could not be recovered.

Review changes made to the added code:
* Collection::delete_by_metadata now uses Vec::retain. Removing previously
  collected indexes in ascending order shifts the remaining elements after each
  removal, so the wrong embeddings were deleted and a later removal could index
  out of bounds.
* The delete_embeddings handler traces "Deleting" instead of "Querying".
* The get_embeddings handler doc comment no longer duplicates query_embeddings.
* Emptiness checks use is_empty(); get_by_metadata uses filter/take/cloned.
* Added public methods carry doc comments; hunk headers are recomputed.

diff --git a/src/db.rs b/src/db.rs
index 2aa42bd..fc22b83 100644
--- a/src/db.rs
+++ b/src/db.rs
@@ -55,6 +55,27 @@ pub struct Collection {
 }
 
 impl Collection {
+    /// Returns the identifiers of all embeddings stored in this collection.
+    pub fn list(&self) -> Vec<String> {
+        self.embeddings.iter().map(|e| e.id.to_owned()).collect()
+    }
+
+    /// Returns the first embedding whose identifier equals `id`, if any.
+    pub fn get(&self, id: &str) -> Option<&Embedding> {
+        self.embeddings.iter().find(|e| e.id == id)
+    }
+
+    /// Returns clones of at most `k` embeddings for which `match_embedding`
+    /// accepts `filter`, in storage order.
+    pub fn get_by_metadata(&self, filter: &[HashMap<String, String>], k: usize) -> Vec<Embedding> {
+        self.embeddings
+            .iter()
+            .filter(|embedding| match_embedding(embedding, filter))
+            .take(k)
+            .cloned()
+            .collect()
+    }
+
     pub fn get_by_metadata_and_similarity(&self, filter: &[HashMap<String, String>], query: &[f32], k: usize) -> Vec<SimilarityResult> {
         let memo_attr = get_cache_attr(self.distance, query);
         let distance_fn = get_distance_fn(self.distance);
@@ -92,6 +113,35 @@ impl Collection {
             })
             .collect()
     }
+
+    /// Removes the first embedding whose identifier equals `id`.
+    ///
+    /// Returns `true` if an embedding was removed, `false` if none matched.
+    pub fn delete(&mut self, id: &str) -> bool {
+        match self.embeddings.iter().position(|e| e.id == id) {
+            Some(index) => {
+                self.embeddings.remove(index);
+                true
+            }
+            None => false,
+        }
+    }
+
+    /// Removes every embedding for which `match_embedding` accepts `filter`.
+    ///
+    /// An empty `filter` clears the whole collection.
+    pub fn delete_by_metadata(&mut self, filter: &[HashMap<String, String>]) {
+        if filter.is_empty() {
+            self.embeddings.clear();
+            return;
+        }
+
+        // Drop the matches in one in-place pass. Removing previously collected
+        // positions one by one in ascending order shifts the remaining elements
+        // after each removal, which deletes the wrong embeddings and can index
+        // past the end of the vector.
+        self.embeddings.retain(|embedding| !match_embedding(embedding, filter));
+    }
 }
 
 fn match_embedding(embedding: &Embedding, filter: &[HashMap<String, String>]) -> bool {
@@ -104,7 +154,7 @@ fn match_embedding(embedding: &Embedding, filter: &[HashMap<String, String>]) ->
         // no metadata in an embedding cannot be matched by a not empty filter
         None => false,
         Some(metadata) => {
-            // enumerate criteria with OR semantics; look for the first one matching
+            // enumerate criteria with OR semantics; look for the first one matching
             for criteria in filter {
                 let mut matches = true;
                 // enumerate entries with AND semantics; look for the first one failing
@@ -211,6 +261,16 @@ impl Db {
         self.collections.get(name)
     }
 
+    /// Returns a mutable reference to the collection named `name`, if it exists.
+    pub fn get_collection_mut(&mut self, name: &str) -> Option<&mut Collection> {
+        self.collections.get_mut(name)
+    }
+
+    /// Returns the names of all collections in the database.
+    pub fn list(&self) -> Vec<String> {
+        self.collections.keys().cloned().collect()
+    }
+
     fn load_from_store() -> anyhow::Result<Self> {
         if !STORE_PATH.exists() {
             tracing::debug!("Creating database store");
diff --git a/src/routes/collection.rs b/src/routes/collection.rs
index 8aa7633..e4c1f15 100644
--- a/src/routes/collection.rs
+++ b/src/routes/collection.rs
@@ -20,14 +20,33 @@ pub fn handler() -> ApiRouter {
     ApiRouter::new().nest(
         "/collections",
         ApiRouter::new()
+            .api_route("/", get(get_collections))
             .api_route("/:collection_name", put(create_collection))
             .api_route("/:collection_name", post(query_collection))
             .api_route("/:collection_name", get(get_collection_info))
             .api_route("/:collection_name", delete(delete_collection))
-            .api_route("/:collection_name/insert", post(insert_into_collection)),
+            .api_route("/:collection_name/insert", post(insert_into_collection))
+            .api_route("/:collection_name/embeddings", get(get_embeddings))
+            .api_route("/:collection_name/embeddings", post(query_embeddings))
+            .api_route("/:collection_name/embeddings", delete(delete_embeddings))
+            .api_route("/:collection_name/embeddings/:embedding_id", get(get_embedding))
+            .api_route("/:collection_name/embeddings/:embedding_id", delete(delete_embedding)),
     )
 }
 
+/// Get collection names
+async fn get_collections(
+    Extension(db): DbExtension,
+) -> Result<Json<Vec<String>>, HTTPError> {
+    tracing::trace!("Getting collection names");
+
+    let db = db.read().await;
+
+    let results = db.list();
+
+    Ok(Json(results))
+}
+
 /// Create a new collection
 async fn create_collection(
     Path(collection_name): Path<String>,
@@ -170,3 +189,114 @@ async fn insert_into_collection(
             .with_status(StatusCode::BAD_REQUEST)),
     }
 }
+
+/// List the identifiers of the embeddings in a collection
+async fn get_embeddings(
+    Path(collection_name): Path<String>,
+    Extension(db): DbExtension,
+) -> Result<Json<Vec<String>>, HTTPError> {
+    tracing::trace!("Querying embeddings from collection {collection_name}");
+
+    let db = db.read().await;
+    let collection = db
+        .get_collection(&collection_name)
+        .ok_or_else(|| HTTPError::new("Collection not found").with_status(StatusCode::NOT_FOUND))?;
+
+    let results = collection.list();
+    drop(db);
+
+    Ok(Json(results))
+}
+
+#[derive(Debug, serde::Deserialize, JsonSchema)]
+struct EmbeddingsQuery {
+    /// Metadata to filter with
+    filter: Vec<HashMap<String, String>>,
+    /// Number of results to return
+    k: Option<usize>,
+}
+
+/// Query embeddings in a collection
+async fn query_embeddings(
+    Path(collection_name): Path<String>,
+    Extension(db): DbExtension,
+    Json(req): Json<EmbeddingsQuery>,
+) -> Result<Json<Vec<Embedding>>, HTTPError> {
+    tracing::trace!("Querying embeddings from collection {collection_name}");
+
+    let db = db.read().await;
+    let collection = db
+        .get_collection(&collection_name)
+        .ok_or_else(|| HTTPError::new("Collection not found").with_status(StatusCode::NOT_FOUND))?;
+
+    let instant = Instant::now();
+    let results = collection.get_by_metadata(&req.filter, req.k.unwrap_or(1));
+    drop(db);
+
+    tracing::trace!("Query embeddings from {collection_name} took {:?}", instant.elapsed());
+    Ok(Json(results))
+}
+
+/// Delete embeddings in a collection
+async fn delete_embeddings(
+    Path(collection_name): Path<String>,
+    Extension(db): DbExtension,
+    Json(req): Json<EmbeddingsQuery>,
+) -> Result<StatusCode, HTTPError> {
+    tracing::trace!("Deleting embeddings from collection {collection_name}");
+
+    let mut db = db.write().await;
+    let collection = db
+        .get_collection_mut(&collection_name)
+        .ok_or_else(|| HTTPError::new("Collection not found").with_status(StatusCode::NOT_FOUND))?;
+
+    collection.delete_by_metadata(&req.filter);
+    drop(db);
+
+    Ok(StatusCode::NO_CONTENT)
+}
+
+/// Get an embedding from a collection
+async fn get_embedding(
+    Path((collection_name, embedding_id)): Path<(String, String)>,
+    Extension(db): DbExtension,
+) -> Result<Json<Embedding>, HTTPError> {
+    tracing::trace!("Getting {embedding_id} from collection {collection_name}");
+
+    if embedding_id.is_empty() {
+        return Err(HTTPError::new("Embedding identifier empty").with_status(StatusCode::BAD_REQUEST));
+    }
+
+    let db = db.read().await;
+    let collection = db
+        .get_collection(&collection_name)
+        .ok_or_else(|| HTTPError::new("Collection not found").with_status(StatusCode::NOT_FOUND))?;
+
+    let embedding = collection
+        .get(&embedding_id)
+        .ok_or_else(|| HTTPError::new("Embedding not found").with_status(StatusCode::NOT_FOUND))?;
+
+    Ok(Json(embedding.to_owned()))
+}
+
+/// Delete an embedding from a collection
+async fn delete_embedding(
+    Path((collection_name, embedding_id)): Path<(String, String)>,
+    Extension(db): DbExtension,
+) -> Result<StatusCode, HTTPError> {
+    tracing::trace!("Removing embedding {embedding_id} from collection {collection_name}");
+
+    let mut db = db.write().await;
+    let collection = db
+        .get_collection_mut(&collection_name)
+        .ok_or_else(|| HTTPError::new("Collection not found").with_status(StatusCode::NOT_FOUND))?;
+
+    let deleted = collection.delete(&embedding_id);
+    drop(db);
+
+    if deleted {
+        Ok(StatusCode::NO_CONTENT)
+    } else {
+        Err(HTTPError::new("Embedding not found").with_status(StatusCode::NOT_FOUND))
+    }
+}