Skip to content

Commit 61fb2c7

Browse files
committed
validator/tests: add vector_similarity() function test
Add tests to validate the results of vector_similarity() function. Refs: scylladb/scylladb#25993
1 parent 3291db0 commit 61fb2c7

File tree

2 files changed

+91
-0
lines changed

2 files changed

+91
-0
lines changed

crates/validator/src/tests/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ mod crud;
88
mod full_scan;
99
mod reconnect;
1010
mod serde;
11+
mod vector_similarity;
1112

1213
use crate::ServicesSubnet;
1314
use crate::dns::Dns;
@@ -221,6 +222,7 @@ pub(crate) async fn register() -> Vec<(String, TestCase)> {
221222
("full_scan", full_scan::new().await),
222223
("reconnect", reconnect::new().await),
223224
("serde", serde::new().await),
225+
("vector_similarity", vector_similarity::new().await),
224226
]
225227
.into_iter()
226228
.map(|(name, test_case)| (name.to_string(), test_case))
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
/*
2+
* Copyright 2025-present ScyllaDB
3+
* SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0
4+
*/
5+
6+
use crate::common::*;
7+
use crate::tests::*;
8+
use std::time::Duration;
9+
use tracing::info;
10+
11+
pub(crate) async fn new() -> TestCase {
12+
let timeout = Duration::from_secs(30);
13+
TestCase::empty()
14+
.with_init(timeout, init)
15+
.with_cleanup(timeout, cleanup)
16+
.with_test(
17+
"vector_similarity_function_returns_expected_results",
18+
timeout,
19+
vector_similarity_function_returns_expected_results,
20+
)
21+
}
22+
23+
async fn vector_similarity_function_returns_expected_results(actors: TestActors) {
24+
info!("started");
25+
26+
let (session, client) = prepare_connection(&actors).await;
27+
28+
let keyspace = create_keyspace(&session).await;
29+
let table = create_table(&session, "pk INT PRIMARY KEY, v VECTOR<FLOAT, 3>", None).await;
30+
31+
let embeddings: Vec<Vec<f32>> = vec![
32+
vec![1.0, 2.0, 3.0],
33+
vec![4.0, 5.0, 6.0],
34+
vec![7.0, 8.0, 9.0],
35+
];
36+
for (i, embedding) in embeddings.into_iter().enumerate() {
37+
session
38+
.query_unpaged(
39+
format!("INSERT INTO {table} (pk, v) VALUES (?, ?)"),
40+
(i as i32, &embedding),
41+
)
42+
.await
43+
.expect("failed to insert data");
44+
}
45+
46+
let index = create_index(
47+
&session,
48+
&client,
49+
&table,
50+
"v",
51+
Some("{'similarity_function' : 'EUCLIDEAN'}"),
52+
)
53+
.await;
54+
55+
wait_for(
56+
|| async { client.count(&index.keyspace, &index.index).await == Some(3) },
57+
"Waiting for 3 vectors to be indexed",
58+
Duration::from_secs(5),
59+
)
60+
.await;
61+
62+
// Check if the query returns the expected results (recall at least 85%)
63+
let rows = get_query_results(
64+
format!(
65+
"SELECT pk, vector_similarity() FROM {table} ORDER BY v ANN OF [0.0, 0.0, 0.0] LIMIT 5"
66+
),
67+
&session,
68+
)
69+
.await;
70+
assert_eq!(rows.len(), 3);
71+
let results: Vec<(i32, f32)> = vec![(0, 14.0), (1, 77.0), (2, 194.0)];
72+
for (i, row) in rows.iter().enumerate() {
73+
let pk: i32 = row.columns[0].as_ref().unwrap().as_int().unwrap();
74+
let similarity: f32 = row.columns[1].as_ref().unwrap().as_float().unwrap();
75+
assert_eq!(
76+
(pk, similarity),
77+
results[i],
78+
"Row {i} does not match expected result"
79+
);
80+
}
81+
82+
// Drop keyspace
83+
session
84+
.query_unpaged(format!("DROP KEYSPACE {keyspace}"), ())
85+
.await
86+
.expect("failed to drop a keyspace");
87+
88+
info!("finished");
89+
}

0 commit comments

Comments
 (0)