Skip to content

Commit 82e4110

Browse files
committed
[Store][MariaDB] Add support for custom WHERE clause
1 parent f78ac35 commit 82e4110

File tree

3 files changed

+220
-38
lines changed

3 files changed

+220
-38
lines changed

src/store/src/Bridge/MariaDb/Store.php

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,20 @@ public function add(VectorDocument ...$documents): void
143143
*/
144144
public function query(Vector $vector, array $options = []): array
145145
{
146+
$where = null;
147+
146148
$maxScore = $options['maxScore'] ?? null;
149+
if ($maxScore) {
150+
$where = \sprintf('WHERE VEC_DISTANCE_EUCLIDEAN(`%1$s`, VEC_FromText(:embedding)) <= :maxScore', $this->vectorFieldName);
151+
}
152+
153+
if ($options['where'] ?? false) {
154+
if ($where) {
155+
$where .= ' AND ('.$options['where'].')';
156+
} else {
157+
$where = 'WHERE '.$options['where'];
158+
}
159+
}
147160

148161
$statement = $this->connection->prepare(
149162
\sprintf(
@@ -156,12 +169,15 @@ public function query(Vector $vector, array $options = []): array
156169
SQL,
157170
$this->vectorFieldName,
158171
$this->tableName,
159-
null !== $maxScore ? \sprintf('WHERE VEC_DISTANCE_EUCLIDEAN(%1$s, VEC_FromText(:embedding)) <= :maxScore', $this->vectorFieldName) : '',
172+
$where ?? '',
160173
$options['limit'] ?? 5,
161174
),
162175
);
163176

164-
$params = ['embedding' => json_encode($vector->getData())];
177+
$params = [
178+
'embedding' => json_encode($vector->getData()),
179+
...$options['params'] ?? [],
180+
];
165181

166182
if (null !== $maxScore) {
167183
$params['maxScore'] = $maxScore;

src/store/tests/Bridge/MariaDb/StoreTest.php

Lines changed: 151 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ public function testQueryWithMaxScore()
3030
$expectedQuery = <<<'SQL'
3131
SELECT id, VEC_ToText(`embedding`) embedding, metadata, VEC_DISTANCE_EUCLIDEAN(`embedding`, VEC_FromText(:embedding)) AS score
3232
FROM embeddings_table
33-
WHERE VEC_DISTANCE_EUCLIDEAN(embedding, VEC_FromText(:embedding)) <= :maxScore
33+
WHERE VEC_DISTANCE_EUCLIDEAN(`embedding`, VEC_FromText(:embedding)) <= :maxScore
3434
ORDER BY score ASC
3535
LIMIT 5
3636
SQL;
@@ -155,6 +155,151 @@ public function testQueryWithCustomLimit()
155155
$this->assertCount(0, $results);
156156
}
157157

158+
public function testQueryWithCustomWhereExpression()
159+
{
160+
$pdo = $this->createMock(\PDO::class);
161+
$statement = $this->createMock(\PDOStatement::class);
162+
163+
$store = new Store($pdo, 'embeddings_table', 'embedding_idx', 'embedding');
164+
165+
$expectedQuery = <<<SQL
166+
SELECT id, VEC_ToText(`embedding`) embedding, metadata, VEC_DISTANCE_EUCLIDEAN(`embedding`, VEC_FromText(:embedding)) AS score
167+
FROM embeddings_table
168+
WHERE metadata->>'category' = 'products'
169+
ORDER BY score
170+
ASC LIMIT 5
171+
SQL;
172+
173+
$pdo->expects($this->once())
174+
->method('prepare')
175+
->with($this->callback(function ($sql) use ($expectedQuery) {
176+
$this->assertSame($this->normalizeQuery($expectedQuery), $this->normalizeQuery($sql));
177+
178+
return true;
179+
}))
180+
->willReturn($statement);
181+
182+
$statement->expects($this->once())
183+
->method('execute')
184+
->with(['embedding' => '[0.1,0.2,0.3]']);
185+
186+
$statement->expects($this->once())
187+
->method('fetchAll')
188+
->with(\PDO::FETCH_ASSOC)
189+
->willReturn([]);
190+
191+
$results = $store->query(new Vector([0.1, 0.2, 0.3]), ['where' => 'metadata->>\'category\' = \'products\'']);
192+
193+
$this->assertCount(0, $results);
194+
}
195+
196+
public function testQueryWithCustomWhereExpressionAndMaxScore()
197+
{
198+
$pdo = $this->createMock(\PDO::class);
199+
$statement = $this->createMock(\PDOStatement::class);
200+
201+
$store = new Store($pdo, 'embeddings_table', 'embedding_idx', 'embedding');
202+
203+
$expectedQuery = <<<SQL
204+
SELECT id, VEC_ToText(`embedding`) embedding, metadata, VEC_DISTANCE_EUCLIDEAN(`embedding`, VEC_FromText(:embedding)) AS score
205+
FROM embeddings_table
206+
WHERE VEC_DISTANCE_EUCLIDEAN(`embedding`, VEC_FromText(:embedding)) <= :maxScore
207+
AND (metadata->>'active' = 'true')
208+
ORDER BY score ASC
209+
LIMIT 5
210+
SQL;
211+
212+
$pdo->expects($this->once())
213+
->method('prepare')
214+
->with($this->callback(function ($sql) use ($expectedQuery) {
215+
$this->assertSame($this->normalizeQuery($expectedQuery), $this->normalizeQuery($sql));
216+
217+
return true;
218+
}))
219+
->willReturn($statement);
220+
221+
$statement->expects($this->once())
222+
->method('execute')
223+
->with([
224+
'embedding' => '[0.1,0.2,0.3]',
225+
'maxScore' => 0.5,
226+
]);
227+
228+
$statement->expects($this->once())
229+
->method('fetchAll')
230+
->with(\PDO::FETCH_ASSOC)
231+
->willReturn([]);
232+
233+
$results = $store->query(new Vector([0.1, 0.2, 0.3]), [
234+
'maxScore' => 0.5,
235+
'where' => 'metadata->>\'active\' = \'true\'',
236+
]);
237+
238+
$this->assertCount(0, $results);
239+
}
240+
241+
public function testQueryWithCustomWhereExpressionAndParams()
242+
{
243+
$pdo = $this->createMock(\PDO::class);
244+
$statement = $this->createMock(\PDOStatement::class);
245+
246+
$store = new Store($pdo, 'embeddings_table', 'embedding_idx', 'embedding');
247+
248+
$expectedQuery = <<<SQL
249+
SELECT id, VEC_ToText(`embedding`) embedding, metadata, VEC_DISTANCE_EUCLIDEAN(`embedding`, VEC_FromText(:embedding)) AS score
250+
FROM embeddings_table
251+
WHERE metadata->>'crawlId' = :crawlId
252+
AND id != :currentId
253+
ORDER BY score
254+
ASC LIMIT 5
255+
SQL;
256+
257+
$pdo->expects($this->once())
258+
->method('prepare')
259+
->with($this->callback(function ($sql) use ($expectedQuery) {
260+
$this->assertSame($this->normalizeQuery($expectedQuery), $this->normalizeQuery($sql));
261+
262+
return true;
263+
}))
264+
->willReturn($statement);
265+
266+
$uuid = Uuid::v4();
267+
$crawlId = '396af6fe-0dfd-47ed-b222-3dbcced3f38e';
268+
269+
$statement->expects($this->once())
270+
->method('execute')
271+
->with([
272+
'embedding' => '[0.1,0.2,0.3]',
273+
'crawlId' => $crawlId,
274+
'currentId' => $uuid->toRfc4122(),
275+
]);
276+
277+
$statement->expects($this->once())
278+
->method('fetchAll')
279+
->with(\PDO::FETCH_ASSOC)
280+
->willReturn([
281+
[
282+
'id' => Uuid::v4()->toRfc4122(),
283+
'embedding' => '[0.4,0.5,0.6]',
284+
'metadata' => json_encode(['crawlId' => $crawlId, 'url' => 'https://example.com']),
285+
'score' => 0.85,
286+
],
287+
]);
288+
289+
$results = $store->query(new Vector([0.1, 0.2, 0.3]), [
290+
'where' => 'metadata->>\'crawlId\' = :crawlId AND id != :currentId',
291+
'params' => [
292+
'crawlId' => $crawlId,
293+
'currentId' => $uuid->toRfc4122(),
294+
],
295+
]);
296+
297+
$this->assertCount(1, $results);
298+
$this->assertSame(0.85, $results[0]->score);
299+
$this->assertSame($crawlId, $results[0]->metadata['crawlId']);
300+
$this->assertSame('https://example.com', $results[0]->metadata['url']);
301+
}
302+
158303
public function testItCanDrop()
159304
{
160305
$pdo = $this->createMock(\PDO::class);
@@ -168,4 +313,9 @@ public function testItCanDrop()
168313

169314
$store->drop();
170315
}
316+
317+
private function normalizeQuery(string $query): string
318+
{
319+
return trim(preg_replace('/\s+/', ' ', $query));
320+
}
171321
}

0 commit comments

Comments
 (0)