diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SimpleVectorStore.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SimpleVectorStore.java
index f2e558eb6a8..dce8ddb3773 100644
--- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SimpleVectorStore.java
+++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SimpleVectorStore.java
@@ -26,17 +26,14 @@
import java.nio.file.FileAlreadyExistsException;
import java.nio.file.Files;
import java.util.Comparator;
-import java.util.HashMap;
import java.util.List;
import java.util.Map;
-import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
-import com.fasterxml.jackson.databind.ObjectWriter;
import com.fasterxml.jackson.databind.json.JsonMapper;
import io.micrometer.observation.ObservationRegistry;
import org.slf4j.Logger;
@@ -50,14 +47,16 @@
import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention;
+import org.springframework.core.io.FileSystemResource;
import org.springframework.core.io.Resource;
+import org.springframework.util.Assert;
/**
- * SimpleVectorStore is a simple implementation of the VectorStore interface.
- *
+ * Simple, in-memory implementation of the {@link VectorStore} interface.
+ *
* It also provides methods to save the current state of the vectors to a file, and to
* load vectors from a file.
- *
+ *
* For a deeper understanding of the mathematical concepts and computations involved in
* calculating similarity scores among vectors, refer to this
* [resource](https://docs.spring.io/spring-ai/reference/api/vectordbs.html#_understanding_vectors).
@@ -67,6 +66,8 @@
* @author Mark Pollack
* @author Christian Tzolov
* @author Sebastien Deleuze
+ * @author John Blum
+ * @see VectorStore
*/
public class SimpleVectorStore extends AbstractObservationVectorStore {
@@ -87,7 +88,8 @@ public SimpleVectorStore(EmbeddingModel embeddingModel, ObservationRegistry obse
super(observationRegistry, customObservationConvention);
- Objects.requireNonNull(embeddingModel, "EmbeddingModel must not be null");
+ Assert.notNull(embeddingModel, "EmbeddingModel must not be null");
+
this.embeddingModel = embeddingModel;
this.objectMapper = JsonMapper.builder().addModules(JacksonUtils.instantiateAvailableModules()).build();
}
@@ -95,38 +97,55 @@ public SimpleVectorStore(EmbeddingModel embeddingModel, ObservationRegistry obse
@Override
public void doAdd(List documents) {
for (Document document : documents) {
- logger.info("Calling EmbeddingModel for document id = {}", document.getId());
- float[] embedding = this.embeddingModel.embed(document);
- document.setEmbedding(embedding);
+ logger.info("Calling EmbeddingModel for Document id = {}", document.getId());
+ document = embed(document);
this.store.put(document.getId(), document);
}
}
+ protected Document embed(Document document) {
+ float[] documentEmbedding = this.embeddingModel.embed(document);
+ document.setEmbedding(documentEmbedding);
+ return document;
+ }
+
@Override
public Optional doDelete(List idList) {
- for (String id : idList) {
- this.store.remove(id);
- }
+ idList.forEach(this.store::remove);
return Optional.of(true);
}
@Override
public List doSimilaritySearch(SearchRequest request) {
+
if (request.getFilterExpression() != null) {
throw new UnsupportedOperationException(
- "The [" + this.getClass() + "] doesn't support metadata filtering!");
+ "[%s] doesn't support metadata filtering".formatted(getClass().getName()));
}
- float[] userQueryEmbedding = getUserQueryEmbedding(request.getQuery());
- return this.store.values()
- .stream()
- .map(entry -> new Similarity(entry.getId(),
- EmbeddingMath.cosineSimilarity(userQueryEmbedding, entry.getEmbedding())))
- .filter(s -> s.score >= request.getSimilarityThreshold())
- .sorted(Comparator.comparingDouble(s -> s.score).reversed())
+ // @formatter:off
+ return this.store.values().stream()
+ .map(document -> computeSimilarity(request, document))
+ .filter(similarity -> similarity.score >= request.getSimilarityThreshold())
+ .sorted(Comparator.comparingDouble(similarity -> similarity.score).reversed())
.limit(request.getTopK())
- .map(s -> this.store.get(s.key))
+ .map(similarity -> this.store.get(similarity.key))
.toList();
+ // @formatter:on
+ }
+
+ protected Similarity computeSimilarity(SearchRequest request, Document document) {
+
+ float[] userQueryEmbedding = getUserQueryEmbedding(request);
+ float[] documentEmbedding = document.getEmbedding();
+
+ double score = computeCosineSimilarity(userQueryEmbedding, documentEmbedding);
+
+ return new Similarity(document.getId(), score);
+ }
+
+ protected double computeCosineSimilarity(float[] userQueryEmbedding, float[] storedDocumentEmbedding) {
+ return EmbeddingMath.cosineSimilarity(userQueryEmbedding, storedDocumentEmbedding);
}
/**
@@ -134,7 +153,7 @@ public List doSimilaritySearch(SearchRequest request) {
* @param file the file to save the vector store content
*/
public void save(File file) {
- String json = getVectorDbAsJson();
+
try {
if (!file.exists()) {
logger.info("Creating new vector store file: {}", file);
@@ -145,28 +164,22 @@ public void save(File file) {
throw new RuntimeException("File already exists: " + file, e);
}
catch (IOException e) {
- throw new RuntimeException("Failed to create new file: " + file + ". Reason: " + e.getMessage(), e);
+ throw new RuntimeException("Failed to create new file: " + file + "; Reason: " + e.getMessage(), e);
}
}
else {
logger.info("Overwriting existing vector store file: {}", file);
}
+
try (OutputStream stream = new FileOutputStream(file);
Writer writer = new OutputStreamWriter(stream, StandardCharsets.UTF_8)) {
+ String json = getVectorDbAsJson();
writer.write(json);
writer.flush();
}
}
- catch (IOException ex) {
- logger.error("IOException occurred while saving vector store file.", ex);
- throw new RuntimeException(ex);
- }
- catch (SecurityException ex) {
- logger.error("SecurityException occurred while saving vector store file.", ex);
- throw new RuntimeException(ex);
- }
- catch (NullPointerException ex) {
- logger.error("NullPointerException occurred while saving vector store file.", ex);
+ catch (IOException | NullPointerException | SecurityException ex) {
+ logger.error("%s occurred while saving vector store file".formatted(ex.getClass().getSimpleName()), ex);
throw new RuntimeException(ex);
}
}
@@ -176,16 +189,7 @@ public void save(File file) {
* @param file the file to load the vector store content
*/
public void load(File file) {
- TypeReference> typeRef = new TypeReference<>() {
-
- };
- try {
- Map deserializedMap = this.objectMapper.readValue(file, typeRef);
- this.store = deserializedMap;
- }
- catch (IOException ex) {
- throw new RuntimeException(ex);
- }
+ load(new FileSystemResource(file));
}
/**
@@ -193,28 +197,32 @@ public void load(File file) {
* @param resource the resource to load the vector store content
*/
public void load(Resource resource) {
- TypeReference> typeRef = new TypeReference<>() {
- };
try {
- Map deserializedMap = this.objectMapper.readValue(resource.getInputStream(), typeRef);
- this.store = deserializedMap;
+ this.store = this.objectMapper.readValue(resource.getInputStream(), documentMapTypeRef());
}
catch (IOException ex) {
throw new RuntimeException(ex);
}
}
+ private TypeReference