Skip to content

Commit fb54a70

Browse files
authored
Fix ANN flaky tests (#246)
Fix ANN LIMIT tests as the USearch library does not guarantee to return all the items, but only those nearly matching the expected vector. Refactor the `get_query_results` as it's a better practice not to allocate a whole vector of result rows, but work with iterator. It also helps to serialize the results into expected types. Fix ANN expected results test to check for a valid PKs and if the rows returned are the ones that was inserted. Refs: VECTOR-221 Fixes: VECTOR-239
2 parents 548f9ff + 5e551ea commit fb54a70

File tree

2 files changed

+53
-39
lines changed

2 files changed

+53
-39
lines changed

crates/validator/src/common.rs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use crate::vector_store_cluster::VectorStoreClusterExt;
1010
use httpclient::HttpClient;
1111
use scylla::client::session::Session;
1212
use scylla::client::session_builder::SessionBuilder;
13-
use scylla::value::Row;
13+
use scylla::response::query_result::QueryRowsResult;
1414
use std::net::Ipv4Addr;
1515
use std::sync::Arc;
1616
use std::time::Duration;
@@ -97,17 +97,13 @@ where
9797
.unwrap_or_else(|_| panic!("Timeout on: {msg}"))
9898
}
9999

100-
pub(crate) async fn get_query_results(query: String, session: &Session) -> Vec<Row> {
100+
pub(crate) async fn get_query_results(query: String, session: &Session) -> QueryRowsResult {
101101
session
102102
.query_unpaged(query, ())
103103
.await
104104
.expect("failed to run query")
105105
.into_rows_result()
106106
.expect("failed to get rows")
107-
.rows()
108-
.unwrap()
109-
.collect::<Result<Vec<Row>, _>>()
110-
.expect("failed to decode rows")
111107
}
112108

113109
pub(crate) async fn create_keyspace(session: &Session) -> String {

crates/validator/src/tests/ann.rs

Lines changed: 51 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55

66
use crate::common::*;
77
use crate::tests::*;
8+
use std::collections::HashMap;
89
use std::time::Duration;
9-
use tracing::debug;
1010
use tracing::info;
1111

1212
pub(crate) async fn new() -> TestCase {
@@ -39,17 +39,23 @@ async fn ann_query_returns_expected_results(actors: TestActors) {
3939
let keyspace = create_keyspace(&session).await;
4040
let table = create_table(&session, "pk INT PRIMARY KEY, v VECTOR<FLOAT, 3>", None).await;
4141

42-
// Insert 1000 vectors
42+
// Create a map of pk -> embedding
43+
let mut embeddings: HashMap<i32, Vec<f32>> = HashMap::new();
4344
for i in 0..1000 {
44-
let embedding: Vec<f32> = vec![
45+
let embedding = vec![
4546
if i < 100 { 0.0 } else { (i % 3) as f32 },
4647
if i < 100 { 0.0 } else { (i % 5) as f32 },
4748
if i < 100 { 0.0 } else { (i % 7) as f32 },
4849
];
50+
embeddings.insert(i, embedding);
51+
}
52+
53+
// Insert 1000 vectors from the map
54+
for (pk, embedding) in &embeddings {
4955
session
5056
.query_unpaged(
5157
format!("INSERT INTO {table} (pk, v) VALUES (?, ?)"),
52-
(i, embedding),
58+
(pk, embedding),
5359
)
5460
.await
5561
.expect("failed to insert data");
@@ -65,24 +71,25 @@ async fn ann_query_returns_expected_results(actors: TestActors) {
6571
.await;
6672

6773
// Check if the query returns the expected results (recall at least 85%)
68-
let rows = get_query_results(
69-
format!("SELECT pk FROM {table} ORDER BY v ANN OF [0.0, 0.0, 0.0] LIMIT 100"),
74+
let results = get_query_results(
75+
format!("SELECT pk, v FROM {table} ORDER BY v ANN OF [0.0, 0.0, 0.0] LIMIT 100"),
7076
&session,
7177
)
7278
.await;
73-
assert_eq!(rows.len(), 100);
74-
let correct = rows
75-
.iter()
76-
.filter(|row| {
77-
let pk: i32 = row.columns[0].as_ref().unwrap().as_int().unwrap();
78-
pk < 100
79-
})
80-
.count();
81-
debug!("Number of matching results: {}", correct);
82-
assert!(
83-
correct >= 85,
84-
"Expected more than 85 matching results, got {correct}"
85-
);
79+
let rows = results
80+
.rows::<(i32, Vec<f32>)>()
81+
.expect("failed to get rows");
82+
assert!(rows.rows_remaining() <= 100);
83+
for row in rows {
84+
let row = row.expect("failed to get row");
85+
let (pk, v) = row;
86+
assert!(
87+
embeddings.contains_key(&pk),
88+
"pk {pk} not found in embeddings"
89+
);
90+
let expected = embeddings.get(&pk).unwrap();
91+
assert_eq!(&v, expected, "Returned vector does not match for pk={pk}");
92+
}
8693

8794
// Drop keyspace
8895
session
@@ -124,19 +131,25 @@ async fn ann_query_respects_limit(actors: TestActors) {
124131
.await;
125132

126133
// Check if queries return the expected number of results
127-
let rows = get_query_results(
134+
let results = get_query_results(
128135
format!("SELECT * FROM {table} ORDER BY v ANN OF [0.0, 0.0, 0.0] LIMIT 10"),
129136
&session,
130137
)
131138
.await;
132-
assert_eq!(rows.len(), 10);
139+
let rows = results
140+
.rows::<(i32, Vec<f32>)>()
141+
.expect("failed to get rows");
142+
assert!(rows.rows_remaining() <= 10);
133143

134-
let rows = get_query_results(
144+
let results = get_query_results(
135145
format!("SELECT * FROM {table} ORDER BY v ANN OF [0.0, 0.0, 0.0] LIMIT 1000"),
136146
&session,
137147
)
138148
.await;
139-
assert_eq!(rows.len(), 10); // Should return only 10, as there are only 10 vectors
149+
let rows = results
150+
.rows::<(i32, Vec<f32>)>()
151+
.expect("failed to get rows");
152+
assert!(rows.rows_remaining() <= 10); // Should return only 10, as there are only 10 vectors
140153

141154
// Check if LIMIT over 1000 fails
142155
session
@@ -186,20 +199,25 @@ async fn ann_query_respects_limit_over_1000_vectors(actors: TestActors) {
186199
.await;
187200

188201
// Check if queries return the expected number of results
189-
let rows = get_query_results(
202+
let results = get_query_results(
190203
format!("SELECT * FROM {table} ORDER BY v ANN OF [0.0, 0.0, 0.0] LIMIT 10"),
191204
&session,
192205
)
193206
.await;
194-
assert_eq!(rows.len(), 10);
195-
196-
// Due to VECTOR-221 the test fails. Uncomment after fixing.
197-
// let rows = get_query_results(
198-
// format!("SELECT * FROM {table} ORDER BY v ANN OF [0.0, 0.0, 0.0] LIMIT 1000"),
199-
// &session,
200-
// )
201-
// .await;
202-
// assert_eq!(rows.len(), 1000);
207+
let rows = results
208+
.rows::<(i32, Vec<f32>)>()
209+
.expect("failed to get rows");
210+
assert!(rows.rows_remaining() <= 10);
211+
212+
let results = get_query_results(
213+
format!("SELECT * FROM {table} ORDER BY v ANN OF [0.0, 0.0, 0.0] LIMIT 1000"),
214+
&session,
215+
)
216+
.await;
217+
let rows = results
218+
.rows::<(i32, Vec<f32>)>()
219+
.expect("failed to get rows");
220+
assert!(rows.rows_remaining() <= 1000);
203221

204222
// Check if LIMIT over 1000 fails
205223
session

0 commit comments

Comments
 (0)