diff --git a/.gitignore b/.gitignore index 401319c..2bc5772 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ /storage /target +.DS_Store \ No newline at end of file diff --git a/src/db.rs b/src/db.rs index 423ece3..9c52ea8 100644 --- a/src/db.rs +++ b/src/db.rs @@ -12,6 +12,7 @@ use std::{ use tokio::sync::RwLock; use crate::similarity::{get_cache_attr, get_distance_fn, normalize, Distance, ScoreIndex}; +use crate::search::Filter; lazy_static! { pub static ref STORE_PATH: PathBuf = PathBuf::from("./storage/db"); @@ -55,7 +56,7 @@ pub struct Collection { } impl Collection { - pub fn get_similarity(&self, query: &[f32], k: usize) -> Vec { + pub fn get_similarity(&self, query: &[f32], k: usize, comparate: Option) -> Vec { let memo_attr = get_cache_attr(self.distance, query); let distance_fn = get_distance_fn(self.distance); @@ -63,6 +64,12 @@ impl Collection { .embeddings .par_iter() .enumerate() + .filter(|(_, embedding)| { + match comparate { + Some(ref comparate) => (*comparate).compare(embedding), + _ => true, + } + }) .map(|(index, embedding)| { let score = distance_fn(&embedding.vector, query, memo_attr); ScoreIndex { score, index } diff --git a/src/main.rs b/src/main.rs index 07fd97b..6ba2e21 100644 --- a/src/main.rs +++ b/src/main.rs @@ -8,6 +8,7 @@ use tracing_subscriber::{ mod db; mod errors; mod routes; +mod search; mod server; mod shutdown; mod similarity; diff --git a/src/routes/collection.rs b/src/routes/collection.rs index c0f8f47..5cead7e 100644 --- a/src/routes/collection.rs +++ b/src/routes/collection.rs @@ -10,6 +10,7 @@ use std::time::Instant; use crate::{ db::{self, Collection, DbExtension, Embedding, Error as DbError, SimilarityResult}, errors::HTTPError, + search::Filter, similarity::Distance, }; @@ -56,6 +57,8 @@ struct QueryCollectionQuery { query: Vec, /// Number of results to return k: Option, + /// Filter results by metadata + filter: Option, } /// Query a collection @@ -76,8 +79,9 @@ async fn query_collection( return Err(HTTPError::new("Query dimension mismatch").with_status(StatusCode::BAD_REQUEST)); } + let instant = Instant::now(); - let results = collection.get_similarity(&req.query, req.k.unwrap_or(1)); + let results = collection.get_similarity(&req.query, req.k.unwrap_or(1), req.filter); drop(db); tracing::trace!("Query to {collection_name} took {:?}", instant.elapsed()); @@ -125,7 +129,6 @@ async fn delete_collection( tracing::trace!("Deleting collection {collection_name}"); let mut db = db.write().await; - let delete_result = db.delete_collection(&collection_name); drop(db); diff --git a/src/search.rs b/src/search.rs new file mode 100644 index 0000000..fd65be6 --- /dev/null +++ b/src/search.rs @@ -0,0 +1,201 @@ +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +trait Compare { + fn compare(&self, metadata: &HashMap) -> bool; +} + +#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, PartialEq, Eq)] +pub enum EqualityCompOp { + #[serde(rename = "eq")] + Eq, + #[serde(rename = "ne")] + Ne, + #[serde(rename = "gt")] + Gt, + #[serde(rename = "gte")] + Gte, + #[serde(rename = "lt")] + Lt, + #[serde(rename = "lte")] + Lte, +} + +fn eq(lhs: String, rhs: String) -> bool { lhs == rhs } +fn ne(lhs: String, rhs: String) -> bool { lhs != rhs } +fn gt(lhs: String, rhs: String) -> bool { lhs > rhs } +fn gte(lhs: String, rhs: String) -> bool { lhs >= rhs } +fn lt(lhs: String, rhs: String) -> bool { lhs < rhs } +fn lte(lhs: String, rhs: String) -> bool { lhs <= rhs } + +fn get_equality_comp_op_fn(op: EqualityCompOp) -> impl Fn(String, String) -> bool { + match op { + EqualityCompOp::Eq => eq, + EqualityCompOp::Ne => ne, + EqualityCompOp::Gt => gt, + EqualityCompOp::Gte => gte, + EqualityCompOp::Lt => lt, + EqualityCompOp::Lte => lte, + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)] + +pub struct Comparator { + pub metadata_field: String, + pub op: EqualityCompOp, + pub comp_value: String, +} + +impl Compare for Comparator { + fn compare(&self, metadata: &HashMap) -> bool { + let metadata_value = metadata.get(&self.metadata_field).unwrap_or(&"".to_string()); + let op = get_equality_comp_op_fn(self.op); + op(*metadata_value, self.comp_value) + } +} + +#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, PartialEq, Eq)] + +pub enum LogicalCompOp { + #[serde(rename = "and")] + And, + #[serde(rename = "or")] + Or, +} + +fn and(lhs: bool, rhs: bool) -> bool { lhs && rhs } +fn or(lhs: bool, rhs: bool) -> bool { lhs || rhs } + +fn get_logical_comp_op_fn(op: LogicalCompOp) -> impl Fn(bool, bool) -> bool { + match op { + LogicalCompOp::And => and, + LogicalCompOp::Or => or, + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)] +pub struct Logic { + pub lhs: Box, + pub op: LogicalCompOp, + pub rhs: Box, +} + +impl Compare for Logic { + fn compare(&self, metadata: &HashMap) -> bool { + let lhs = self.lhs.compare(metadata); + let rhs = self.rhs.compare(metadata); + let op = get_logical_comp_op_fn(self.op); + op(lhs, rhs) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)] +pub enum Filter { + Comparator(Comparator), + Logic(Logic), +} + +impl Compare for Filter { + fn compare(&self, metadata: &HashMap) -> bool { + match self { + Filter::Comparator(c) => c.compare(metadata), + Filter::Logic(l) => l.compare(metadata), + } + } +} + +/*** + +{ + "filter": { + "$and": [ + ...opt1, + ...opt2, + ], + "$or": [ + ...opt1, + ...opt2, + ], + "$eq": { + "field": "value" + }, + "$ne": { + "field": "value" + }, + "$gt": { + "field": "value" + }, + "$gte": { + "field": "value" + }, + "$lt": { + "field": "value" + }, + "$lte": { + "field": "value" + }, + } +} + + ***/ + + #[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("Filter incorrectly formatted")] + InvalidFilter, +} + +fn parse_logic_helper(input: HashMap, key: &str) -> Result { + match input.get(key) { + Some(s) => { + match parse(s) { + Ok(f) => { Ok(f) }, + Err(_) => { Err(Error::InvalidFilter) } + } + }, + None => { Err(Error::InvalidFilter) }, + } +} + + + +fn parse_logic(input: HashMap, op: LogicalCompOp) -> Result { + let lhs = parse_logic_helper(input, "lhs"); + let rhs = parse_logic_helper(input, "rhs"); + Ok(Filter::Logic(Logic { lhs, op, rhs })) +} + +fn parse_comparator(input: HashMap, op: EqualityCompOp) -> Result { + fn parse_field(input: HashMap, key: &str) -> Result { + match input.get(key) { + Some(s) => Ok(s.to_string()), + None => return Err(Error::InvalidFilter), + } + } + + let metadata_field = match input.keys { + Some(s) => Ok(s.to_string()), + None => return Err(Error::InvalidFilter), + }; + Ok(Filter::Comparator(Comparator { metadata_field, op: EqualityCompOp::Eq, comp_value })) +} + +pub fn parse(input: HashMap) -> Result { + if input.keys().len() != 1 { + return Err(Error::InvalidFilter); + } + + match input.keys().next().unwrap().as_str() { + "$and" => parse_logic(input.get("$and").unwrap().to_string(), LogicalCompOp::And), + "$or" => parse_logic(input.get("$or").unwrap().to_string(), LogicalCompOp::Or), + "$eq" => parse_comparator(input.get("$eq").unwrap().to_string(), EqualityCompOp::Eq), + "$ne" => parse_comparator(input.get("$ne").unwrap().to_string(), EqualityCompOp::Ne), + "$gt" => parse_comparator(input.get("$gt").unwrap().to_string(), EqualityCompOp::Gt), + "$gte" => parse_comparator(input.get("$gte").unwrap().to_string(), EqualityCompOp::Gte), + "$lt" => parse_comparator(input.get("$lt").unwrap().to_string(), EqualityCompOp::Lt), + "$lte" => parse_comparator(input.get("$lte").unwrap().to_string(), EqualityCompOp::Lte), + _ => Err(Error::InvalidFilter), + } +}