Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
/storage
/target
.DS_Store
9 changes: 8 additions & 1 deletion src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use std::{
use tokio::sync::RwLock;

use crate::similarity::{get_cache_attr, get_distance_fn, normalize, Distance, ScoreIndex};
use crate::search::Filter;

lazy_static! {
pub static ref STORE_PATH: PathBuf = PathBuf::from("./storage/db");
Expand Down Expand Up @@ -55,14 +56,20 @@ pub struct Collection {
}

impl Collection {
pub fn get_similarity(&self, query: &[f32], k: usize) -> Vec<SimilarityResult> {
pub fn get_similarity(&self, query: &[f32], k: usize, comparate: Option<Filter>) -> Vec<SimilarityResult> {
let memo_attr = get_cache_attr(self.distance, query);
let distance_fn = get_distance_fn(self.distance);

let scores = self
.embeddings
.par_iter()
.enumerate()
.filter(|(_, embedding)| {
match comparate {
Some(ref comparate) => (*comparate).compare(embedding),
_ => true,
}
})
.map(|(index, embedding)| {
let score = distance_fn(&embedding.vector, query, memo_attr);
ScoreIndex { score, index }
Expand Down
1 change: 1 addition & 0 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use tracing_subscriber::{
mod db;
mod errors;
mod routes;
mod search;
mod server;
mod shutdown;
mod similarity;
Expand Down
7 changes: 5 additions & 2 deletions src/routes/collection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use std::time::Instant;
use crate::{
db::{self, Collection, DbExtension, Embedding, Error as DbError, SimilarityResult},
errors::HTTPError,
search::Filter,
similarity::Distance,
};

Expand Down Expand Up @@ -56,6 +57,8 @@ struct QueryCollectionQuery {
query: Vec<f32>,
/// Number of results to return
k: Option<usize>,
/// Filter results by metadata
filter: Option<Filter>,
}

/// Query a collection
Expand All @@ -76,8 +79,9 @@ async fn query_collection(
return Err(HTTPError::new("Query dimension mismatch").with_status(StatusCode::BAD_REQUEST));
}


let instant = Instant::now();
let results = collection.get_similarity(&req.query, req.k.unwrap_or(1));
let results = collection.get_similarity(&req.query, req.k.unwrap_or(1), req.filter);
drop(db);

tracing::trace!("Query to {collection_name} took {:?}", instant.elapsed());
Expand Down Expand Up @@ -125,7 +129,6 @@ async fn delete_collection(
tracing::trace!("Deleting collection {collection_name}");

let mut db = db.write().await;

let delete_result = db.delete_collection(&collection_name);
drop(db);

Expand Down
201 changes: 201 additions & 0 deletions src/search.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

trait Compare {
fn compare(&self, metadata: &HashMap<String, String>) -> bool;
}

#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
pub enum EqualityCompOp {
#[serde(rename = "eq")]
Eq,
#[serde(rename = "ne")]
Ne,
#[serde(rename = "gt")]
Gt,
#[serde(rename = "gte")]
Gte,
#[serde(rename = "lt")]
Lt,
#[serde(rename = "lte")]
Lte,
}

fn eq(lhs: String, rhs: String) -> bool { lhs == rhs }
fn ne(lhs: String, rhs: String) -> bool { lhs != rhs }
fn gt(lhs: String, rhs: String) -> bool { lhs > rhs }
fn gte(lhs: String, rhs: String) -> bool { lhs >= rhs }
fn lt(lhs: String, rhs: String) -> bool { lhs < rhs }
fn lte(lhs: String, rhs: String) -> bool { lhs <= rhs }

fn get_equality_comp_op_fn(op: EqualityCompOp) -> impl Fn(String, String) -> bool {
match op {
EqualityCompOp::Eq => eq,
EqualityCompOp::Ne => ne,
EqualityCompOp::Gt => gt,
EqualityCompOp::Gte => gte,
EqualityCompOp::Lt => lt,
EqualityCompOp::Lte => lte,
}
}

#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]

pub struct Comparator {
pub metadata_field: String,
pub op: EqualityCompOp,
pub comp_value: String,
}

impl Compare for Comparator {
fn compare(&self, metadata: &HashMap<String, String>) -> bool {
let metadata_value = metadata.get(&self.metadata_field).unwrap_or(&"".to_string());
let op = get_equality_comp_op_fn(self.op);
op(*metadata_value, self.comp_value)
}
}

#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]

pub enum LogicalCompOp {
#[serde(rename = "and")]
And,
#[serde(rename = "or")]
Or,
}

fn and(lhs: bool, rhs: bool) -> bool { lhs && rhs }
fn or(lhs: bool, rhs: bool) -> bool { lhs || rhs }

fn get_logical_comp_op_fn(op: LogicalCompOp) -> impl Fn(bool, bool) -> bool {
match op {
LogicalCompOp::And => and,
LogicalCompOp::Or => or,
}
}

#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
pub struct Logic {
pub lhs: Box<Filter>,
pub op: LogicalCompOp,
pub rhs: Box<Filter>,
}

impl Compare for Logic {
fn compare(&self, metadata: &HashMap<String, String>) -> bool {
let lhs = self.lhs.compare(metadata);
let rhs = self.rhs.compare(metadata);
let op = get_logical_comp_op_fn(self.op);
op(lhs, rhs)
}
}

#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
pub enum Filter {
Comparator(Comparator),
Logic(Logic),
}

impl Compare for Filter {
fn compare(&self, metadata: &HashMap<String, String>) -> bool {
match self {
Filter::Comparator(c) => c.compare(metadata),
Filter::Logic(l) => l.compare(metadata),
}
}
}

/***

{
"filter": {
"$and": [
...opt1,
...opt2,
],
"$or": [
...opt1,
...opt2,
],
"$eq": {
"field": "value"
},
"$ne": {
"field": "value"
},
"$gt": {
"field": "value"
},
"$gte": {
"field": "value"
},
"$lt": {
"field": "value"
},
"$lte": {
"field": "value"
},
}
}

***/

#[derive(Debug, thiserror::Error)]
pub enum Error {
#[error("Filter incorrectly formatted")]
InvalidFilter,
}

fn parse_logic_helper(input: HashMap<String, String>, key: &str) -> Result<Filter, Error> {
match input.get(key) {
Some(s) => {
match parse(s) {
Ok(f) => { Ok(f) },
Err(_) => { Err(Error::InvalidFilter) }
}
},
None => { Err(Error::InvalidFilter) },
}
}



fn parse_logic(input: HashMap<String, String>, op: LogicalCompOp) -> Result<Filter, Error> {
let lhs = parse_logic_helper(input, "lhs");
let rhs = parse_logic_helper(input, "rhs");
Ok(Filter::Logic(Logic { lhs, op, rhs }))
}

fn parse_comparator(input: HashMap<String, String>, op: EqualityCompOp) -> Result<Filter, Error> {
fn parse_field(input: HashMap<String, String>, key: &str) -> Result<String, Error> {
match input.get(key) {
Some(s) => Ok(s.to_string()),
None => return Err(Error::InvalidFilter),
}
}

let metadata_field = match input.keys {
Some(s) => Ok(s.to_string()),
None => return Err(Error::InvalidFilter),
};
Ok(Filter::Comparator(Comparator { metadata_field, op: EqualityCompOp::Eq, comp_value }))
}

pub fn parse(input: HashMap<String, String>) -> Result<Filter, Error> {
if input.keys().len() != 1 {
return Err(Error::InvalidFilter);
}

match input.keys().next().unwrap().as_str() {
"$and" => parse_logic(input.get("$and").unwrap().to_string(), LogicalCompOp::And),
"$or" => parse_logic(input.get("$or").unwrap().to_string(), LogicalCompOp::Or),
"$eq" => parse_comparator(input.get("$eq").unwrap().to_string(), EqualityCompOp::Eq),
"$ne" => parse_comparator(input.get("$ne").unwrap().to_string(), EqualityCompOp::Ne),
"$gt" => parse_comparator(input.get("$gt").unwrap().to_string(), EqualityCompOp::Gt),
"$gte" => parse_comparator(input.get("$gte").unwrap().to_string(), EqualityCompOp::Gte),
"$lt" => parse_comparator(input.get("$lt").unwrap().to_string(), EqualityCompOp::Lt),
"$lte" => parse_comparator(input.get("$lte").unwrap().to_string(), EqualityCompOp::Lte),
_ => Err(Error::InvalidFilter),
}
}