Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 78 additions & 1 deletion src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,36 @@ pub struct Collection {
}

impl Collection {
pub fn list(&self) -> Vec<String> {
self
.embeddings
.iter()
.map(|e| e.id.to_owned())
.collect()
}

pub fn get(&self, id: &str) -> Option<&Embedding> {
self
.embeddings
.iter()
.find(|e| e.id == id)
}

pub fn get_by_metadata(&self, filter: &[HashMap<String, String>], k: usize) -> Vec<Embedding> {
self
.embeddings
.iter()
.filter_map(|embedding| {
if match_embedding(embedding, filter) {
Some(embedding.clone())
} else {
None
}
})
.take(k)
.collect()
}

pub fn get_by_metadata_and_similarity(&self, filter: &[HashMap<String, String>], query: &[f32], k: usize) -> Vec<SimilarityResult> {
let memo_attr = get_cache_attr(self.distance, query);
let distance_fn = get_distance_fn(self.distance);
Expand Down Expand Up @@ -92,6 +122,41 @@ impl Collection {
})
.collect()
}

pub fn delete(&mut self, id: &str) -> bool {
let index_opt = self.embeddings
.iter()
.position(|e| e.id == id);

match index_opt {
None => false,
Some(index) => { self.embeddings.remove(index); true }
}
}

pub fn delete_by_metadata(&mut self, filter: &[HashMap<String, String>]) {
if filter.len() == 0 {
self.embeddings.clear();
return
}

let indexes = self
.embeddings
.par_iter()
.enumerate()
.filter_map(|(index, embedding)| {
if match_embedding(embedding, filter) {
Some(index)
} else {
None
}
})
.collect::<Vec<_>>();

for index in indexes {
self.embeddings.remove(index);
}
}
}

fn match_embedding(embedding: &Embedding, filter: &[HashMap<String, String>]) -> bool {
Expand All @@ -104,7 +169,7 @@ fn match_embedding(embedding: &Embedding, filter: &[HashMap<String, String>]) ->
// no metadata in an embedding cannot be matched by a not empty filter
None => false,
Some(metadata) => {
// enumerate criteria with OR semantics; look for the first one matching
// enumerate criteria with OR semantics; look for the first one matching
for criteria in filter {
let mut matches = true;
// enumerate entries with AND semantics; look for the first one failing
Expand Down Expand Up @@ -211,6 +276,18 @@ impl Db {
self.collections.get(name)
}

pub fn get_collection_mut(&mut self, name: &str) -> Option<&mut Collection> {
self.collections.get_mut(name)
}

pub fn list(&self) -> Vec<String> {
self
.collections
.keys()
.map(|name| name.to_owned())
.collect()
}

fn load_from_store() -> anyhow::Result<Self> {
if !STORE_PATH.exists() {
tracing::debug!("Creating database store");
Expand Down
131 changes: 130 additions & 1 deletion src/routes/collection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,33 @@ pub fn handler() -> ApiRouter {
ApiRouter::new().nest(
"/collections",
ApiRouter::new()
.api_route("/", get(get_collections))
.api_route("/:collection_name", put(create_collection))
.api_route("/:collection_name", post(query_collection))
.api_route("/:collection_name", get(get_collection_info))
.api_route("/:collection_name", delete(delete_collection))
.api_route("/:collection_name/insert", post(insert_into_collection)),
.api_route("/:collection_name/insert", post(insert_into_collection))
.api_route("/:collection_name/embeddings", get(get_embeddings))
.api_route("/:collection_name/embeddings", post(query_embeddings))
.api_route("/:collection_name/embeddings", delete(delete_embeddings))
.api_route("/:collection_name/embeddings/:embedding_id", get(get_embedding))
.api_route("/:collection_name/embeddings/:embedding_id", delete(delete_embedding)),
)
}

/// Get collection names
async fn get_collections(
Extension(db): DbExtension,
) -> Result<Json<Vec<String>>, HTTPError> {
tracing::trace!("Getting collection names");

let db = db.read().await;

let results = db.list();

Ok(Json(results))
}

/// Create a new collection
async fn create_collection(
Path(collection_name): Path<String>,
Expand Down Expand Up @@ -170,3 +189,113 @@ async fn insert_into_collection(
.with_status(StatusCode::BAD_REQUEST)),
}
}

/// Query embeddings in a collection
async fn get_embeddings(
Path(collection_name): Path<String>,
Extension(db): DbExtension,
) -> Result<Json<Vec<String>>, HTTPError> {
tracing::trace!("Querying embeddings from collection {collection_name}");

let db = db.read().await;
let collection = db
.get_collection(&collection_name)
.ok_or_else(|| HTTPError::new("Collection not found").with_status(StatusCode::NOT_FOUND))?;

let results = collection.list();
drop(db);

Ok(Json(results))
}

#[derive(Debug, serde::Deserialize, JsonSchema)]
struct EmbeddingsQuery {
/// Metadata to filter with
filter: Vec<HashMap<String, String>>,
/// Number of results to return
k: Option<usize>,
}

/// Query embeddings in a collection
async fn query_embeddings(
Path(collection_name): Path<String>,
Extension(db): DbExtension,
Json(req): Json<EmbeddingsQuery>,
) -> Result<Json<Vec<Embedding>>, HTTPError> {
tracing::trace!("Querying embeddings from collection {collection_name}");

let db = db.read().await;
let collection = db
.get_collection(&collection_name)
.ok_or_else(|| HTTPError::new("Collection not found").with_status(StatusCode::NOT_FOUND))?;

let instant = Instant::now();
let results = collection.get_by_metadata(&req.filter, req.k.unwrap_or(1));
drop(db);

tracing::trace!("Query embeddings from {collection_name} took {:?}", instant.elapsed());
Ok(Json(results))
}

/// Delete embeddings in a collection
async fn delete_embeddings(
Path(collection_name): Path<String>,
Extension(db): DbExtension,
Json(req): Json<EmbeddingsQuery>,
) -> Result<StatusCode, HTTPError> {
tracing::trace!("Querying embeddings from collection {collection_name}");

let mut db = db.write().await;
let collection = db
.get_collection_mut(&collection_name)
.ok_or_else(|| HTTPError::new("Collection not found").with_status(StatusCode::NOT_FOUND))?;

collection.delete_by_metadata(&req.filter);
drop(db);

Ok(StatusCode::NO_CONTENT)
}

/// Get an embedding from a collection
async fn get_embedding(
Path((collection_name, embedding_id)): Path<(String, String)>,
Extension(db): DbExtension,
) -> Result<Json<Embedding>, HTTPError> {
tracing::trace!("Getting {embedding_id} from collection {collection_name}");

if embedding_id.len() == 0 {
return Err(HTTPError::new("Embedding identifier empty").with_status(StatusCode::BAD_REQUEST));
}

let db = db.read().await;
let collection = db
.get_collection(&collection_name)
.ok_or_else(|| HTTPError::new("Collection not found").with_status(StatusCode::NOT_FOUND))?;

let embedding = collection
.get(&embedding_id)
.ok_or_else(|| HTTPError::new("Embedding not found").with_status(StatusCode::NOT_FOUND))?;

Ok(Json(embedding.to_owned()))
}

/// Delete an embedding from a collection
async fn delete_embedding(
Path((collection_name, embedding_id)): Path<(String, String)>,
Extension(db): DbExtension,
) -> Result<StatusCode, HTTPError> {
tracing::trace!("Removing embedding {embedding_id} from collection {collection_name}");

let mut db = db.write().await;
let collection = db
.get_collection_mut(&collection_name)
.ok_or_else(|| HTTPError::new("Collection not found").with_status(StatusCode::NOT_FOUND))?;

let delete_result = collection.delete(&embedding_id);
drop(db);

match delete_result {
true => Ok(StatusCode::NO_CONTENT),
false => Err(HTTPError::new("Embedding not found").with_status(StatusCode::NOT_FOUND)),
}
}