Skip to content

Commit 34a805b

Browse files
committed
Add endpoints for complete management of embeddings
* `GET /collections` - list collection names
* `GET /collections/:collection_name/embeddings` - get embedding identifiers
* `POST /collections/:collection_name/embeddings` - filter embeddings with metadata
* `DELETE /collections/:collection_name/embeddings` - delete embeddings by metadata
* `GET /collections/:collection_name/embeddings/:embedding_id` - get an embedding
* `DELETE /collections/:collection_name/embeddings/:embedding_id` - delete an embedding
1 parent 404bdcc commit 34a805b

File tree

2 files changed

+208
-2
lines changed

2 files changed

+208
-2
lines changed

src/db.rs

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,36 @@ pub struct Collection {
5555
}
5656

5757
impl Collection {
58+
pub fn list(&self) -> Vec<String> {
59+
self
60+
.embeddings
61+
.iter()
62+
.map(|e| e.id.to_owned())
63+
.collect()
64+
}
65+
66+
/// Looks up an embedding by identifier.
///
/// Returns a reference to the first embedding whose `id` equals `id`,
/// or `None` when no embedding has that identifier.
pub fn get(&self, id: &str) -> Option<&Embedding> {
    for candidate in self.embeddings.iter() {
        if candidate.id == id {
            return Some(candidate);
        }
    }
    None
}
72+
73+
pub fn get_by_metadata(&self, filter: &[HashMap<String, String>], k: usize) -> Vec<Embedding> {
74+
self
75+
.embeddings
76+
.iter()
77+
.filter_map(|embedding| {
78+
if match_embedding(embedding, filter) {
79+
Some(embedding.clone())
80+
} else {
81+
None
82+
}
83+
})
84+
.take(k)
85+
.collect()
86+
}
87+
5888
pub fn get_by_metadata_and_similarity(&self, filter: &[HashMap<String, String>], query: &[f32], k: usize) -> Vec<SimilarityResult> {
5989
let memo_attr = get_cache_attr(self.distance, query);
6090
let distance_fn = get_distance_fn(self.distance);
@@ -92,6 +122,41 @@ impl Collection {
92122
})
93123
.collect()
94124
}
125+
126+
pub fn delete(&mut self, id: &str) -> bool {
127+
let index_opt = self.embeddings
128+
.iter()
129+
.position(|e| e.id == id);
130+
131+
match index_opt {
132+
None => false,
133+
Some(index) => { self.embeddings.remove(index); true }
134+
}
135+
}
136+
137+
pub fn delete_by_metadata(&mut self, filter: &[HashMap<String, String>]) {
138+
if filter.len() == 0 {
139+
self.embeddings.clear();
140+
return
141+
}
142+
143+
let indexes = self
144+
.embeddings
145+
.par_iter()
146+
.enumerate()
147+
.filter_map(|(index, embedding)| {
148+
if match_embedding(embedding, filter) {
149+
Some(index)
150+
} else {
151+
None
152+
}
153+
})
154+
.collect::<Vec<_>>();
155+
156+
for index in indexes {
157+
self.embeddings.remove(index);
158+
}
159+
}
95160
}
96161

97162
fn match_embedding(embedding: &Embedding, filter: &[HashMap<String, String>]) -> bool {
@@ -104,7 +169,7 @@ fn match_embedding(embedding: &Embedding, filter: &[HashMap<String, String>]) ->
104169
// no metadata in an embedding cannot be matched by a not empty filter
105170
None => false,
106171
Some(metadata) => {
107-
// enumerate criteria with OR semantics; look for the first one matching
172+
// enumerate criteria with OR semantics; look for the first one matching
108173
for criteria in filter {
109174
let mut matches = true;
110175
// enumerate entries with AND semantics; look for the first one failing
@@ -211,6 +276,18 @@ impl Db {
211276
self.collections.get(name)
212277
}
213278

279+
/// Returns a mutable reference to the collection registered under `name`,
/// or `None` when no such collection exists.
pub fn get_collection_mut(&mut self, name: &str) -> Option<&mut Collection> {
    let collections = &mut self.collections;
    collections.get_mut(name)
}
282+
283+
pub fn list(&self) -> Vec<String> {
284+
self
285+
.collections
286+
.keys()
287+
.map(|name| name.to_owned())
288+
.collect()
289+
}
290+
214291
fn load_from_store() -> anyhow::Result<Self> {
215292
if !STORE_PATH.exists() {
216293
tracing::debug!("Creating database store");

src/routes/collection.rs

Lines changed: 130 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,33 @@ pub fn handler() -> ApiRouter {
2020
ApiRouter::new().nest(
2121
"/collections",
2222
ApiRouter::new()
23+
.api_route("/", get(get_collections))
2324
.api_route("/:collection_name", put(create_collection))
2425
.api_route("/:collection_name", post(query_collection))
2526
.api_route("/:collection_name", get(get_collection_info))
2627
.api_route("/:collection_name", delete(delete_collection))
27-
.api_route("/:collection_name/insert", post(insert_into_collection)),
28+
.api_route("/:collection_name/insert", post(insert_into_collection))
29+
.api_route("/:collection_name/embeddings", get(get_embeddings))
30+
.api_route("/:collection_name/embeddings", post(query_embeddings))
31+
.api_route("/:collection_name/embeddings", delete(delete_embeddings))
32+
.api_route("/:collection_name/embeddings/:embedding_id", get(get_embedding))
33+
.api_route("/:collection_name/embeddings/:embedding_id", delete(delete_embedding)),
2834
)
2935
}
3036

37+
/// Get collection names
38+
async fn get_collections(
39+
Extension(db): DbExtension,
40+
) -> Result<Json<Vec<String>>, HTTPError> {
41+
tracing::trace!("Getting collection names");
42+
43+
let db = db.read().await;
44+
45+
let results = db.list();
46+
47+
Ok(Json(results))
48+
}
49+
3150
/// Create a new collection
3251
async fn create_collection(
3352
Path(collection_name): Path<String>,
@@ -170,3 +189,113 @@ async fn insert_into_collection(
170189
.with_status(StatusCode::BAD_REQUEST)),
171190
}
172191
}
192+
193+
/// Query embeddings in a collection
194+
async fn get_embeddings(
195+
Path(collection_name): Path<String>,
196+
Extension(db): DbExtension,
197+
) -> Result<Json<Vec<String>>, HTTPError> {
198+
tracing::trace!("Querying embeddings from collection {collection_name}");
199+
200+
let db = db.read().await;
201+
let collection = db
202+
.get_collection(&collection_name)
203+
.ok_or_else(|| HTTPError::new("Collection not found").with_status(StatusCode::NOT_FOUND))?;
204+
205+
let results = collection.list();
206+
drop(db);
207+
208+
Ok(Json(results))
209+
}
210+
211+
// Request body deserialized by the `query_embeddings` and `delete_embeddings`
// handlers. Plain `//` comments are used for these notes so the `///` text on
// the fields below stays exactly as written.
#[derive(Debug, serde::Deserialize, JsonSchema)]
212+
struct EmbeddingsQuery {
213+
/// Metadata to filter with
214+
// Handed unchanged to `Collection::get_by_metadata` / `delete_by_metadata`.
// Per the comments in `match_embedding`, criteria in the list are combined
// with OR and the entries of one criteria with AND. On the delete path an
// empty list clears the whole collection.
filter: Vec<HashMap<String, String>>,
215+
/// Number of results to return
216+
// `query_embeddings` falls back to 1 when this is absent;
// `delete_embeddings` does not read it.
k: Option<usize>,
217+
}
218+
219+
/// Query embeddings in a collection
220+
async fn query_embeddings(
221+
Path(collection_name): Path<String>,
222+
Extension(db): DbExtension,
223+
Json(req): Json<EmbeddingsQuery>,
224+
) -> Result<Json<Vec<Embedding>>, HTTPError> {
225+
tracing::trace!("Querying embeddings from collection {collection_name}");
226+
227+
let db = db.read().await;
228+
let collection = db
229+
.get_collection(&collection_name)
230+
.ok_or_else(|| HTTPError::new("Collection not found").with_status(StatusCode::NOT_FOUND))?;
231+
232+
let instant = Instant::now();
233+
let results = collection.get_by_metadata(&req.filter, req.k.unwrap_or(1));
234+
drop(db);
235+
236+
tracing::trace!("Query embeddings from {collection_name} took {:?}", instant.elapsed());
237+
Ok(Json(results))
238+
}
239+
240+
/// Delete embeddings in a collection
241+
async fn delete_embeddings(
242+
Path(collection_name): Path<String>,
243+
Extension(db): DbExtension,
244+
Json(req): Json<EmbeddingsQuery>,
245+
) -> Result<StatusCode, HTTPError> {
246+
tracing::trace!("Querying embeddings from collection {collection_name}");
247+
248+
let mut db = db.write().await;
249+
let collection = db
250+
.get_collection_mut(&collection_name)
251+
.ok_or_else(|| HTTPError::new("Collection not found").with_status(StatusCode::NOT_FOUND))?;
252+
253+
collection.delete_by_metadata(&req.filter);
254+
drop(db);
255+
256+
Ok(StatusCode::NO_CONTENT)
257+
}
258+
259+
/// Get an embedding from a collection
260+
async fn get_embedding(
261+
Path((collection_name, embedding_id)): Path<(String, String)>,
262+
Extension(db): DbExtension,
263+
) -> Result<Json<Embedding>, HTTPError> {
264+
tracing::trace!("Getting {embedding_id} from collection {collection_name}");
265+
266+
if embedding_id.len() == 0 {
267+
return Err(HTTPError::new("Embedding identifier empty").with_status(StatusCode::BAD_REQUEST));
268+
}
269+
270+
let db = db.read().await;
271+
let collection = db
272+
.get_collection(&collection_name)
273+
.ok_or_else(|| HTTPError::new("Collection not found").with_status(StatusCode::NOT_FOUND))?;
274+
275+
let embedding = collection
276+
.get(&embedding_id)
277+
.ok_or_else(|| HTTPError::new("Embedding not found").with_status(StatusCode::NOT_FOUND))?;
278+
279+
Ok(Json(embedding.to_owned()))
280+
}
281+
282+
/// Delete an embedding from a collection
283+
async fn delete_embedding(
284+
Path((collection_name, embedding_id)): Path<(String, String)>,
285+
Extension(db): DbExtension,
286+
) -> Result<StatusCode, HTTPError> {
287+
tracing::trace!("Removing embedding {embedding_id} from collection {collection_name}");
288+
289+
let mut db = db.write().await;
290+
let collection = db
291+
.get_collection_mut(&collection_name)
292+
.ok_or_else(|| HTTPError::new("Collection not found").with_status(StatusCode::NOT_FOUND))?;
293+
294+
let delete_result = collection.delete(&embedding_id);
295+
drop(db);
296+
297+
match delete_result {
298+
true => Ok(StatusCode::NO_CONTENT),
299+
false => Err(HTTPError::new("Embedding not found").with_status(StatusCode::NOT_FOUND)),
300+
}
301+
}

0 commit comments

Comments
 (0)