55
66use crate :: common:: * ;
77use crate :: tests:: * ;
8+ use std:: collections:: HashMap ;
89use std:: time:: Duration ;
9- use tracing:: debug;
1010use tracing:: info;
1111
1212pub ( crate ) async fn new ( ) -> TestCase {
@@ -39,17 +39,23 @@ async fn ann_query_returns_expected_results(actors: TestActors) {
3939 let keyspace = create_keyspace ( & session) . await ;
4040 let table = create_table ( & session, "pk INT PRIMARY KEY, v VECTOR<FLOAT, 3>" , None ) . await ;
4141
42- // Insert 1000 vectors
42+ // Create a map of pk -> embedding
43+ let mut embeddings: HashMap < i32 , Vec < f32 > > = HashMap :: new ( ) ;
4344 for i in 0 ..1000 {
44- let embedding: Vec < f32 > = vec ! [
45+ let embedding = vec ! [
4546 if i < 100 { 0.0 } else { ( i % 3 ) as f32 } ,
4647 if i < 100 { 0.0 } else { ( i % 5 ) as f32 } ,
4748 if i < 100 { 0.0 } else { ( i % 7 ) as f32 } ,
4849 ] ;
50+ embeddings. insert ( i, embedding) ;
51+ }
52+
53+ // Insert 1000 vectors from the map
54+ for ( pk, embedding) in & embeddings {
4955 session
5056 . query_unpaged (
5157 format ! ( "INSERT INTO {table} (pk, v) VALUES (?, ?)" ) ,
52- ( i , embedding) ,
58+ ( pk , embedding) ,
5359 )
5460 . await
5561 . expect ( "failed to insert data" ) ;
@@ -65,24 +71,25 @@ async fn ann_query_returns_expected_results(actors: TestActors) {
6571 . await ;
6672
6773 // Check if the query returns the expected results (recall at least 85%)
68- let rows = get_query_results (
69- format ! ( "SELECT pk FROM {table} ORDER BY v ANN OF [0.0, 0.0, 0.0] LIMIT 100" ) ,
74+ let results = get_query_results (
75+ format ! ( "SELECT pk, v FROM {table} ORDER BY v ANN OF [0.0, 0.0, 0.0] LIMIT 100" ) ,
7076 & session,
7177 )
7278 . await ;
73- assert_eq ! ( rows. len( ) , 100 ) ;
74- let correct = rows
75- . iter ( )
76- . filter ( |row| {
77- let pk: i32 = row. columns [ 0 ] . as_ref ( ) . unwrap ( ) . as_int ( ) . unwrap ( ) ;
78- pk < 100
79- } )
80- . count ( ) ;
81- debug ! ( "Number of matching results: {}" , correct) ;
82- assert ! (
83- correct >= 85 ,
84- "Expected more than 85 matching results, got {correct}"
85- ) ;
79+ let rows = results
80+ . rows :: < ( i32 , Vec < f32 > ) > ( )
81+ . expect ( "failed to get rows" ) ;
82+ assert ! ( rows. rows_remaining( ) <= 100 ) ;
83+ for row in rows {
84+ let row = row. expect ( "failed to get row" ) ;
85+ let ( pk, v) = row;
86+ assert ! (
87+ embeddings. contains_key( & pk) ,
88+ "pk {pk} not found in embeddings"
89+ ) ;
90+ let expected = embeddings. get ( & pk) . unwrap ( ) ;
91+ assert_eq ! ( & v, expected, "Returned vector does not match for pk={pk}" ) ;
92+ }
8693
8794 // Drop keyspace
8895 session
@@ -124,19 +131,25 @@ async fn ann_query_respects_limit(actors: TestActors) {
124131 . await ;
125132
126133 // Check if queries return the expected number of results
127- let rows = get_query_results (
134+ let results = get_query_results (
128135 format ! ( "SELECT * FROM {table} ORDER BY v ANN OF [0.0, 0.0, 0.0] LIMIT 10" ) ,
129136 & session,
130137 )
131138 . await ;
132- assert_eq ! ( rows. len( ) , 10 ) ;
139+ let rows = results
140+ . rows :: < ( i32 , Vec < f32 > ) > ( )
141+ . expect ( "failed to get rows" ) ;
142+ assert ! ( rows. rows_remaining( ) <= 10 ) ;
133143
134- let rows = get_query_results (
144+ let results = get_query_results (
135145 format ! ( "SELECT * FROM {table} ORDER BY v ANN OF [0.0, 0.0, 0.0] LIMIT 1000" ) ,
136146 & session,
137147 )
138148 . await ;
139- assert_eq ! ( rows. len( ) , 10 ) ; // Should return only 10, as there are only 10 vectors
149+ let rows = results
150+ . rows :: < ( i32 , Vec < f32 > ) > ( )
151+ . expect ( "failed to get rows" ) ;
152+ assert ! ( rows. rows_remaining( ) <= 10 ) ; // Should return only 10, as there are only 10 vectors
140153
141154 // Check if LIMIT over 1000 fails
142155 session
@@ -186,20 +199,25 @@ async fn ann_query_respects_limit_over_1000_vectors(actors: TestActors) {
186199 . await ;
187200
188201 // Check if queries return the expected number of results
189- let rows = get_query_results (
202+ let results = get_query_results (
190203 format ! ( "SELECT * FROM {table} ORDER BY v ANN OF [0.0, 0.0, 0.0] LIMIT 10" ) ,
191204 & session,
192205 )
193206 . await ;
194- assert_eq ! ( rows. len( ) , 10 ) ;
195-
196- // Due to VECTOR-221 the test fails. Uncomment after fixing.
197- // let rows = get_query_results(
198- // format!("SELECT * FROM {table} ORDER BY v ANN OF [0.0, 0.0, 0.0] LIMIT 1000"),
199- // &session,
200- // )
201- // .await;
202- // assert_eq!(rows.len(), 1000);
207+ let rows = results
208+ . rows :: < ( i32 , Vec < f32 > ) > ( )
209+ . expect ( "failed to get rows" ) ;
210+ assert ! ( rows. rows_remaining( ) <= 10 ) ;
211+
212+ let results = get_query_results (
213+ format ! ( "SELECT * FROM {table} ORDER BY v ANN OF [0.0, 0.0, 0.0] LIMIT 1000" ) ,
214+ & session,
215+ )
216+ . await ;
217+ let rows = results
218+ . rows :: < ( i32 , Vec < f32 > ) > ( )
219+ . expect ( "failed to get rows" ) ;
220+ assert ! ( rows. rows_remaining( ) <= 1000 ) ;
203221
204222 // Check if LIMIT over 1000 fails
205223 session
0 commit comments