55
66use crate :: common:: * ;
77use crate :: tests:: * ;
8+ use std:: collections:: HashMap ;
89use std:: time:: Duration ;
9- use tracing:: debug;
1010use tracing:: info;
1111
1212pub ( crate ) async fn new ( ) -> TestCase {
@@ -39,17 +39,23 @@ async fn ann_query_returns_expected_results(actors: TestActors) {
3939 let keyspace = create_keyspace ( & session) . await ;
4040 let table = create_table ( & session, "pk INT PRIMARY KEY, v VECTOR<FLOAT, 3>" , None ) . await ;
4141
42- // Insert 1000 vectors
42+ // Create a map of pk -> embedding
43+ let mut embeddings: HashMap < i32 , Vec < f32 > > = HashMap :: new ( ) ;
4344 for i in 0 ..1000 {
44- let embedding: Vec < f32 > = vec ! [
45+ let embedding = vec ! [
4546 if i < 100 { 0.0 } else { ( i % 3 ) as f32 } ,
4647 if i < 100 { 0.0 } else { ( i % 5 ) as f32 } ,
4748 if i < 100 { 0.0 } else { ( i % 7 ) as f32 } ,
4849 ] ;
50+ embeddings. insert ( i, embedding) ;
51+ }
52+
53+ // Insert 1000 vectors from the map
54+ for ( pk, embedding) in & embeddings {
4955 session
5056 . query_unpaged (
5157 format ! ( "INSERT INTO {table} (pk, v) VALUES (?, ?)" ) ,
52- ( i , embedding) ,
58+ ( pk , embedding) ,
5359 )
5460 . await
5561 . expect ( "failed to insert data" ) ;
@@ -66,23 +72,24 @@ async fn ann_query_returns_expected_results(actors: TestActors) {
6672
6773 // Check if the query returns the expected results (recall at least 85%)
6874 let results = get_query_results (
69- format ! ( "SELECT pk FROM {table} ORDER BY v ANN OF [0.0, 0.0, 0.0] LIMIT 100" ) ,
75+ format ! ( "SELECT pk, v FROM {table} ORDER BY v ANN OF [0.0, 0.0, 0.0] LIMIT 100" ) ,
7076 & session,
7177 )
7278 . await ;
73- let rows = results. rows :: < ( i32 , ) > ( ) . expect ( "failed to get rows" ) ;
74- assert_eq ! ( rows. rows_remaining( ) , 100 ) ;
75- let correct = rows
76- . filter ( |row| {
77- let pk = row. expect ( "failed to get row" ) . 0 ;
78- pk < 100
79- } )
80- . count ( ) ;
81- debug ! ( "Number of matching results: {}" , correct) ;
82- assert ! (
83- correct >= 85 ,
84- "Expected more than 85 matching results, got {correct}"
85- ) ;
79+ let rows = results
80+ . rows :: < ( i32 , Vec < f32 > ) > ( )
81+ . expect ( "failed to get rows" ) ;
82+ assert ! ( rows. rows_remaining( ) <= 100 ) ;
83+ for row in rows {
84+ let row = row. expect ( "failed to get row" ) ;
85+ let ( pk, v) = row;
86+ assert ! (
87+ embeddings. contains_key( & pk) ,
88+ "pk {pk} not found in embeddings"
89+ ) ;
90+ let expected = embeddings. get ( & pk) . unwrap ( ) ;
91+ assert_eq ! ( & v, expected, "Returned vector does not match for pk={pk}" ) ;
92+ }
8693
8794 // Drop keyspace
8895 session
0 commit comments