diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SimpleVectorStore.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SimpleVectorStore.java index f2e558eb6a8..dce8ddb3773 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SimpleVectorStore.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SimpleVectorStore.java @@ -26,17 +26,14 @@ import java.nio.file.FileAlreadyExistsException; import java.nio.file.Files; import java.util.Comparator; -import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Objects; import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.databind.ObjectWriter; import com.fasterxml.jackson.databind.json.JsonMapper; import io.micrometer.observation.ObservationRegistry; import org.slf4j.Logger; @@ -50,14 +47,16 @@ import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore; import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext; import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention; +import org.springframework.core.io.FileSystemResource; import org.springframework.core.io.Resource; +import org.springframework.util.Assert; /** - * SimpleVectorStore is a simple implementation of the VectorStore interface. - * + * Simple, in-memory implementation of the {@link VectorStore} interface. + *

* It also provides methods to save the current state of the vectors to a file, and to * load vectors from a file. - * + *

* For a deeper understanding of the mathematical concepts and computations involved in * calculating similarity scores among vectors, refer to this * [resource](https://docs.spring.io/spring-ai/reference/api/vectordbs.html#_understanding_vectors). @@ -67,6 +66,8 @@ * @author Mark Pollack * @author Christian Tzolov * @author Sebastien Deleuze + * @author John Blum + * @see VectorStore */ public class SimpleVectorStore extends AbstractObservationVectorStore { @@ -87,7 +88,8 @@ public SimpleVectorStore(EmbeddingModel embeddingModel, ObservationRegistry obse super(observationRegistry, customObservationConvention); - Objects.requireNonNull(embeddingModel, "EmbeddingModel must not be null"); + Assert.notNull(embeddingModel, "EmbeddingModel must not be null"); + this.embeddingModel = embeddingModel; this.objectMapper = JsonMapper.builder().addModules(JacksonUtils.instantiateAvailableModules()).build(); } @@ -95,38 +97,55 @@ public SimpleVectorStore(EmbeddingModel embeddingModel, ObservationRegistry obse @Override public void doAdd(List documents) { for (Document document : documents) { - logger.info("Calling EmbeddingModel for document id = {}", document.getId()); - float[] embedding = this.embeddingModel.embed(document); - document.setEmbedding(embedding); + logger.info("Calling EmbeddingModel for Document id = {}", document.getId()); + document = embed(document); this.store.put(document.getId(), document); } } + protected Document embed(Document document) { + float[] documentEmbedding = this.embeddingModel.embed(document); + document.setEmbedding(documentEmbedding); + return document; + } + @Override public Optional doDelete(List idList) { - for (String id : idList) { - this.store.remove(id); - } + idList.forEach(this.store::remove); return Optional.of(true); } @Override public List doSimilaritySearch(SearchRequest request) { + if (request.getFilterExpression() != null) { throw new UnsupportedOperationException( - "The [" + this.getClass() + "] doesn't support metadata filtering!"); + "[%s] doesn't support metadata filtering".formatted(getClass().getName())); } - float[] userQueryEmbedding = getUserQueryEmbedding(request.getQuery()); - return this.store.values() - .stream() - .map(entry -> new Similarity(entry.getId(), - EmbeddingMath.cosineSimilarity(userQueryEmbedding, entry.getEmbedding()))) - .filter(s -> s.score >= request.getSimilarityThreshold()) - .sorted(Comparator.comparingDouble(s -> s.score).reversed()) + // @formatter:off + return this.store.values().stream() + .map(document -> computeSimilarity(request, document)) + .filter(similarity -> similarity.score >= request.getSimilarityThreshold()) + .sorted(Comparator.comparingDouble(similarity -> similarity.score).reversed()) .limit(request.getTopK()) - .map(s -> this.store.get(s.key)) + .map(similarity -> this.store.get(similarity.key)) .toList(); + // @formatter:on + } + + protected Similarity computeSimilarity(SearchRequest request, Document document) { + + float[] userQueryEmbedding = getUserQueryEmbedding(request); + float[] documentEmbedding = document.getEmbedding(); + + double score = computeCosineSimilarity(userQueryEmbedding, documentEmbedding); + + return new Similarity(document.getId(), score); + } + + protected double computeCosineSimilarity(float[] userQueryEmbedding, float[] storedDocumentEmbedding) { + return EmbeddingMath.cosineSimilarity(userQueryEmbedding, storedDocumentEmbedding); } /** @@ -134,7 +153,7 @@ public List doSimilaritySearch(SearchRequest request) { * @param file the file to save the vector store content */ public void save(File file) { - String json = getVectorDbAsJson(); + try { if (!file.exists()) { logger.info("Creating new vector store file: {}", file); @@ -145,28 +164,22 @@ public void save(File file) { throw new RuntimeException("File already exists: " + file, e); } catch (IOException e) { - throw new RuntimeException("Failed to create new file: " + file + ". Reason: " + e.getMessage(), e); + throw new RuntimeException("Failed to create new file: " + file + "; Reason: " + e.getMessage(), e); } } else { logger.info("Overwriting existing vector store file: {}", file); } + try (OutputStream stream = new FileOutputStream(file); Writer writer = new OutputStreamWriter(stream, StandardCharsets.UTF_8)) { + String json = getVectorDbAsJson(); writer.write(json); writer.flush(); } } - catch (IOException ex) { - logger.error("IOException occurred while saving vector store file.", ex); - throw new RuntimeException(ex); - } - catch (SecurityException ex) { - logger.error("SecurityException occurred while saving vector store file.", ex); - throw new RuntimeException(ex); - } - catch (NullPointerException ex) { - logger.error("NullPointerException occurred while saving vector store file.", ex); + catch (IOException | NullPointerException | SecurityException ex) { + logger.error("%s occurred while saving vector store file".formatted(ex.getClass().getSimpleName()), ex); throw new RuntimeException(ex); } } @@ -176,16 +189,7 @@ public void save(File file) { * @param file the file to load the vector store content */ public void load(File file) { - TypeReference> typeRef = new TypeReference<>() { - - }; - try { - Map deserializedMap = this.objectMapper.readValue(file, typeRef); - this.store = deserializedMap; - } - catch (IOException ex) { - throw new RuntimeException(ex); - } + load(new FileSystemResource(file)); } /** @@ -193,28 +197,32 @@ public void load(File file) { * @param resource the resource to load the vector store content */ public void load(Resource resource) { - TypeReference> typeRef = new TypeReference<>() { - }; try { - Map deserializedMap = this.objectMapper.readValue(resource.getInputStream(), typeRef); - this.store = deserializedMap; + this.store = this.objectMapper.readValue(resource.getInputStream(), documentMapTypeRef()); } catch (IOException ex) { throw new RuntimeException(ex); } } + private TypeReference> documentMapTypeRef() { + return new TypeReference<>() { + }; + } + private String getVectorDbAsJson() { - ObjectWriter objectWriter = this.objectMapper.writerWithDefaultPrettyPrinter(); - String json; + try { - json = objectWriter.writeValueAsString(this.store); + return this.objectMapper.writerWithDefaultPrettyPrinter().writeValueAsString(this.store); } catch (JsonProcessingException e) { - throw new RuntimeException("Error serializing documentMap to JSON.", e); + throw new RuntimeException("Error serializing Map of Documents to JSON", e); } - return json; + } + + private float[] getUserQueryEmbedding(SearchRequest request) { + return getUserQueryEmbedding(request.getQuery()); } private float[] getUserQueryEmbedding(String query) { @@ -232,9 +240,9 @@ public VectorStoreObservationContext.Builder createObservationContextBuilder(Str public static class Similarity { - private String key; + private final String key; - private double score; + private final double score; public Similarity(String key, double score) { this.key = key; @@ -243,16 +251,18 @@ public Similarity(String key, double score) { } - public final class EmbeddingMath { + public static final class EmbeddingMath { private EmbeddingMath() { throw new UnsupportedOperationException("This is a utility class and cannot be instantiated"); } public static double cosineSimilarity(float[] vectorX, float[] vectorY) { + if (vectorX == null || vectorY == null) { - throw new RuntimeException("Vectors must not be null"); + throw new IllegalArgumentException("Vectors must not be null"); } + if (vectorX.length != vectorY.length) { throw new IllegalArgumentException("Vectors lengths must be equal"); } @@ -268,20 +278,22 @@ public static double cosineSimilarity(float[] vectorX, float[] vectorY) { return dotProduct / (Math.sqrt(normX) * Math.sqrt(normY)); } - public static float dotProduct(float[] vectorX, float[] vectorY) { + private static float dotProduct(float[] vectorX, float[] vectorY) { + if (vectorX.length != vectorY.length) { throw new IllegalArgumentException("Vectors lengths must be equal"); } float result = 0; - for (int i = 0; i < vectorX.length; ++i) { - result += vectorX[i] * vectorY[i]; + + for (int index = 0; index < vectorX.length; ++index) { + result += vectorX[index] * vectorY[index]; } return result; } - public static float norm(float[] vector) { + private static float norm(float[] vector) { return dotProduct(vector, vector); }