Skip to content

Commit 3e9256f

Browse files
committed
Fix MariaDBVectorStore Document embedding
- Create an explicit MariaDBDocument to store a document together with the embedding of its content - This is needed because the Spring AI Document no longer holds a reference to its embedding - Handle the test case that stores MariaDB documents without their embeddings - Re-enable the MariaDB ITs
1 parent 957a2ea commit 3e9256f

File tree

3 files changed

+38
-13
lines changed

3 files changed

+38
-13
lines changed

vector-stores/spring-ai-mariadb-store/src/main/java/org/springframework/ai/vectorstore/mariadb/MariaDBVectorStore.java

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
* vector index will be auto-created if not available.
5252
*
5353
* @author Diego Dupin
54+
* @author Ilayaperumal Gopinathan
5455
* @since 1.0.0
5556
*/
5657
public class MariaDBVectorStore extends AbstractObservationVectorStore implements InitializingBean {
@@ -192,21 +193,36 @@ public MariaDBDistanceType getDistanceType() {
192193
@Override
193194
public void doAdd(List<Document> documents) {
194195
// Batch the documents based on the batching strategy
195-
this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy);
196+
List<float[]> embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(),
197+
this.batchingStrategy);
196198

197-
List<List<Document>> batchedDocuments = batchDocuments(documents);
199+
List<List<MariaDBDocument>> batchedDocuments = batchDocuments(documents, embeddings);
198200
batchedDocuments.forEach(this::insertOrUpdateBatch);
199201
}
200202

201-
private List<List<Document>> batchDocuments(List<Document> documents) {
202-
List<List<Document>> batches = new ArrayList<>();
203-
for (int i = 0; i < documents.size(); i += this.maxDocumentBatchSize) {
204-
batches.add(documents.subList(i, Math.min(i + this.maxDocumentBatchSize, documents.size())));
203+
private List<List<MariaDBDocument>> batchDocuments(List<Document> documents, List<float[]> embeddings) {
204+
List<List<MariaDBDocument>> batches = new ArrayList<>();
205+
List<MariaDBDocument> mariaDBDocuments = new ArrayList<>(documents.size());
206+
if (embeddings.size() == documents.size()) {
207+
for (Document document : documents) {
208+
mariaDBDocuments.add(new MariaDBDocument(document.getId(), document.getContent(),
209+
document.getMetadata(), embeddings.get(documents.indexOf(document))));
210+
}
211+
}
212+
else {
213+
for (Document document : documents) {
214+
mariaDBDocuments
215+
.add(new MariaDBDocument(document.getId(), document.getContent(), document.getMetadata(), null));
216+
}
217+
}
218+
219+
for (int i = 0; i < mariaDBDocuments.size(); i += this.maxDocumentBatchSize) {
220+
batches.add(mariaDBDocuments.subList(i, Math.min(i + this.maxDocumentBatchSize, mariaDBDocuments.size())));
205221
}
206222
return batches;
207223
}
208224

209-
private void insertOrUpdateBatch(List<Document> batch) {
225+
private void insertOrUpdateBatch(List<MariaDBDocument> batch) {
210226
String sql = String.format(
211227
"INSERT INTO %s (%s, %s, %s, %s) VALUES (?, ?, ?, ?) "
212228
+ "ON DUPLICATE KEY UPDATE %s = VALUES(%s) , %s = VALUES(%s) , %s = VALUES(%s)",
@@ -219,10 +235,10 @@ private void insertOrUpdateBatch(List<Document> batch) {
219235
@Override
220236
public void setValues(PreparedStatement ps, int i) throws SQLException {
221237
var document = batch.get(i);
222-
ps.setObject(1, document.getId());
223-
ps.setString(2, document.getContent());
224-
ps.setString(3, toJson(document.getMetadata()));
225-
ps.setObject(4, document.getEmbedding());
238+
ps.setObject(1, document.id());
239+
ps.setString(2, document.content());
240+
ps.setString(3, toJson(document.metadata()));
241+
ps.setObject(4, document.embedding());
226242
}
227243

228244
@Override
@@ -556,4 +572,15 @@ public MariaDBVectorStore build() {
556572

557573
}
558574

575+
/**
576+
* The representation of {@link Document} along with its embedding.
577+
*
578+
* @param id The id of the document
579+
* @param content The content of the document
580+
* @param metadata The metadata of the document
581+
* @param embedding The embedding vector representing the content of the document; may be {@code null} when no embedding is available
582+
*/
583+
public record MariaDBDocument(String id, String content, Map<String, Object> metadata, float[] embedding) {
584+
}
585+
559586
}

vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/mariadb/MariaDBStoreIT.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@
6565
*/
6666
@Testcontainers
6767
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
68-
@Disabled("Failing after commit ebd29e0")
6968
public class MariaDBStoreIT {
7069

7170
private static String schemaName = "testdb";

vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/mariadb/MariaDBStoreObservationIT.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,6 @@
6464
*/
6565
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
6666
@Testcontainers
67-
@Disabled("Failing after commit ebd29e0")
6867
public class MariaDBStoreObservationIT {
6968

7069
private static String schemaName = "testdb";

0 commit comments

Comments
 (0)