Skip to content

Commit 6131aed

Browse files
committed
Refactor ID handling for different IdType formats
- Add handling for the UUID, TEXT, INTEGER, SERIAL and BIGSERIAL ID formats in the `convertIdToPgType` method.
- Implement type-conversion logic based on the configured IdType value (UUID, TEXT, INTEGER, SERIAL, BIGSERIAL).
- Add integration tests to validate correct conversion for UUID and non-UUID IdType formats:
  - `testToPgTypeWithUuidIdType`: validates UUID handling.
  - `testToPgTypeWithNonUuidIdType`: validates handling for non-UUID IdTypes.

Signed-off-by: jitokim <[email protected]>
1 parent 61409f0 commit 6131aed

File tree

2 files changed

+96
-4
lines changed

2 files changed

+96
-4
lines changed

vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/pgvector/PgVectorStore.java

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,8 @@ public class PgVectorStore extends AbstractObservationVectorStore implements Ini
166166

167167
public static final String DEFAULT_TABLE_NAME = "vector_store";
168168

169+
public static final PgIdType DEFAULT_ID_TYPE = PgIdType.UUID;
170+
169171
public static final String DEFAULT_VECTOR_INDEX_NAME = "spring_ai_vector_index";
170172

171173
public static final String DEFAULT_SCHEMA_NAME = "public";
@@ -191,6 +193,8 @@ public class PgVectorStore extends AbstractObservationVectorStore implements Ini
191193

192194
private final String schemaName;
193195

196+
private final PgIdType idType;
197+
194198
private final boolean schemaValidation;
195199

196200
private final boolean initializeSchema;
@@ -231,12 +235,22 @@ public PgVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, i
231235
createIndexMethod, initializeSchema);
232236
}
233237

238+
/**
 * Creates a store without an explicit ID type.
 * <p>
 * All arguments are forwarded unchanged to the overload that additionally takes a
 * {@link PgIdType}, with {@link #DEFAULT_ID_TYPE} ({@link PgIdType#UUID}) supplied
 * as the ID type.
 * @param vectorTableName forwarded as-is
 * @param jdbcTemplate forwarded as-is
 * @param embeddingModel forwarded as-is
 * @param dimensions forwarded as-is
 * @param distanceType forwarded as-is
 * @param removeExistingVectorStoreTable forwarded as-is
 * @param createIndexMethod forwarded as-is
 * @param initializeSchema forwarded as-is
 * @deprecated marked for removal since 1.0.0-M5. NOTE(review): the overload this
 * delegates to configures the store through {@code builder(jdbcTemplate,
 * embeddingModel)}, which is presumably the intended replacement - confirm.
 */
@Deprecated(forRemoval = true, since = "1.0.0-M5")
239+
public PgVectorStore(String vectorTableName, JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, int dimensions,
240+
PgDistanceType distanceType, boolean removeExistingVectorStoreTable, PgIndexType createIndexMethod,
241+
boolean initializeSchema) {
242+
243+
// Keeps the pre-existing 8-argument signature working by defaulting the ID type.
this(vectorTableName, jdbcTemplate, embeddingModel, dimensions, distanceType, removeExistingVectorStoreTable,
244+
createIndexMethod, initializeSchema, DEFAULT_ID_TYPE);
245+
}
246+
234247
@Deprecated(forRemoval = true, since = "1.0.0-M5")
235248
public PgVectorStore(String vectorTableName, JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel,
236249
int dimensions, PgDistanceType distanceType, boolean removeExistingVectorStoreTable,
237-
PgIndexType createIndexMethod, boolean initializeSchema) {
250+
PgIndexType createIndexMethod, boolean initializeSchema, PgIdType pgIdType) {
238251

239252
this(builder(jdbcTemplate, embeddingModel).schemaName(DEFAULT_SCHEMA_NAME)
253+
.idType(pgIdType)
240254
.vectorTableName(vectorTableName)
241255
.vectorTableValidationsEnabled(DEFAULT_SCHEMA_VALIDATION)
242256
.dimensions(dimensions)
@@ -265,6 +279,7 @@ protected PgVectorStore(PgVectorStoreBuilder builder) {
265279
: this.vectorTableName + "_index";
266280

267281
this.schemaName = builder.schemaName;
282+
this.idType = builder.idType;
268283
this.schemaValidation = builder.vectorTableValidationsEnabled;
269284

270285
this.jdbcTemplate = builder.jdbcTemplate;
@@ -314,7 +329,7 @@ private void insertOrUpdateBatch(List<Document> batch, List<Document> documents,
314329
public void setValues(PreparedStatement ps, int i) throws SQLException {
315330

316331
var document = batch.get(i);
317-
var id = document.getId();
332+
var id = convertIdToPgType(document.getId());
318333
var content = document.getText();
319334
var json = toJson(document.getMetadata());
320335
var embedding = embeddings.get(documents.indexOf(document));
@@ -345,6 +360,19 @@ private String toJson(Map<String, Object> map) {
345360
}
346361
}
347362

363+
private Object convertIdToPgType(String id) {
364+
if (this.initializeSchema) {
365+
return UUID.fromString(id);
366+
}
367+
368+
return switch (getIdType()) {
369+
case UUID -> UUID.fromString(id);
370+
case TEXT -> id;
371+
case INTEGER, SERIAL -> Integer.valueOf(id);
372+
case BIGSERIAL -> Long.valueOf(id);
373+
};
374+
}
375+
348376
@Override
349377
public Optional<Boolean> doDelete(List<String> idList) {
350378
int updateCount = 0;
@@ -454,6 +482,10 @@ private String getFullyQualifiedTableName() {
454482
return this.schemaName + "." + this.vectorTableName;
455483
}
456484

485+
/**
 * Returns the ID type this store was configured with (assigned from the builder
 * in the constructor).
 * @return the configured {@link PgIdType}
 */
private PgIdType getIdType() {
486+
return this.idType;
487+
}
488+
457489
/**
 * Returns the table name held by this store, without the schema prefix.
 * @return the value of the {@code vectorTableName} field
 */
private String getVectorTableName() {
458490
return this.vectorTableName;
459491
}
@@ -531,6 +563,14 @@ public enum PgIndexType {
531563

532564
}
533565

566+
/**
 * Selects how a document's string ID is converted before it is written to the
 * table (see {@code convertIdToPgType}):
 * <ul>
 * <li>{@code UUID} - parsed with {@code java.util.UUID.fromString}</li>
 * <li>{@code TEXT} - passed through unchanged as a {@code String}</li>
 * <li>{@code INTEGER}, {@code SERIAL} - converted with {@code Integer.valueOf}</li>
 * <li>{@code BIGSERIAL} - converted with {@code Long.valueOf}</li>
 * </ul>
 * When the store is created with {@code initializeSchema} enabled, the conversion
 * treats every ID as a UUID regardless of this setting.
 */
public enum PgIdType {
567+
UUID,
568+
TEXT,
569+
INTEGER,
570+
SERIAL,
571+
BIGSERIAL
572+
}
573+
534574
/**
535575
* Defaults to CosineDistance. But if vectors are normalized to length 1 (like OpenAI
536576
* embeddings), use inner product (NegativeInnerProduct) for best performance.
@@ -626,6 +666,8 @@ public static class PgVectorStoreBuilder extends AbstractVectorStoreBuilder<PgVe
626666

627667
private String vectorTableName = PgVectorStore.DEFAULT_TABLE_NAME;
628668

669+
private PgIdType idType = PgVectorStore.DEFAULT_ID_TYPE;
670+
629671
private boolean vectorTableValidationsEnabled = PgVectorStore.DEFAULT_SCHEMA_VALIDATION;
630672

631673
private int dimensions = PgVectorStore.INVALID_EMBEDDING_DIMENSION;
@@ -658,6 +700,11 @@ public PgVectorStoreBuilder vectorTableName(String vectorTableName) {
658700
return this;
659701
}
660702

703+
/**
 * Sets the ID type the store should use. If this method is not called, the
 * builder keeps its initial value, {@link PgVectorStore#DEFAULT_ID_TYPE}.
 * @param idType the ID type to use; stored as given, without validation
 * @return this builder, for call chaining
 */
public PgVectorStoreBuilder idType(PgIdType idType) {
704+
this.idType = idType;
705+
return this;
706+
}
707+
661708
public PgVectorStoreBuilder vectorTableValidationsEnabled(boolean vectorTableValidationsEnabled) {
662709
this.vectorTableValidationsEnabled = vectorTableValidationsEnabled;
663710
return this;

vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/pgvector/PgVectorStoreIT.java

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,11 @@
4141
import org.testcontainers.junit.jupiter.Testcontainers;
4242

4343
import org.springframework.ai.document.Document;
44+
import org.springframework.ai.document.id.RandomIdGenerator;
4445
import org.springframework.ai.embedding.EmbeddingModel;
4546
import org.springframework.ai.openai.OpenAiEmbeddingModel;
4647
import org.springframework.ai.openai.api.OpenAiApi;
48+
import org.springframework.ai.vectorstore.pgvector.PgVectorStore.PgIdType;
4749
import org.springframework.ai.vectorstore.pgvector.PgVectorStore.PgIndexType;
4850
import org.springframework.ai.vectorstore.SearchRequest;
4951
import org.springframework.ai.vectorstore.VectorStore;
@@ -105,6 +107,26 @@ public static String getText(String uri) {
105107
}
106108
}
107109

110+
/**
 * Creates the database objects by hand for tests that run the store with schema
 * initialization disabled: the required extensions, the default schema, and the
 * default vector table with a {@code text} primary key.
 * <p>
 * Every statement uses {@code IF NOT EXISTS}, so calling this more than once is
 * harmless.
 * @param context the test application context supplying the {@link PgVectorStore}
 * and {@link JdbcTemplate} beans
 */
private static void initSchema(ApplicationContext context) {
111+
PgVectorStore vectorStore = context.getBean(PgVectorStore.class);
112+
JdbcTemplate jdbcTemplate = context.getBean(JdbcTemplate.class);
113+
// Install the extensions: vector (pgvector), hstore and uuid-ossp.
114+
jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS vector");
115+
jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS hstore");
116+
jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS \"uuid-ossp\"");
117+
118+
jdbcTemplate.execute(String.format("CREATE SCHEMA IF NOT EXISTS %s", PgVectorStore.DEFAULT_SCHEMA_NAME));
119+
120+
// The id column is text (not uuid) so that arbitrary string IDs can be stored;
// the vector width comes from the store's embedding dimensions.
jdbcTemplate.execute(String.format("""
121+
CREATE TABLE IF NOT EXISTS %s.%s (
122+
id text PRIMARY KEY,
123+
content text,
124+
metadata json,
125+
embedding vector(%d)
126+
)
127+
""", PgVectorStore.DEFAULT_SCHEMA_NAME, PgVectorStore.DEFAULT_TABLE_NAME, vectorStore.embeddingDimensions()));
128+
}
129+
108130
private static void dropTable(ApplicationContext context) {
109131
JdbcTemplate jdbcTemplate = context.getBean(JdbcTemplate.class);
110132
jdbcTemplate.execute("DROP TABLE IF EXISTS vector_store");
@@ -169,11 +191,27 @@ public void addAndSearch(String distanceType) {
169191
}
170192

171193
@Test
172-
public void shouldAllowNonUuidFormat() {
194+
// Adds one document whose ID comes from RandomIdGenerator (presumably
// UUID-formatted - confirm) using the default store configuration.
public void testToPgTypeWithUuidIdType() {
195+
this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.pgvector.distanceType=" + "COSINE_DISTANCE")
196+
.run(context -> {
197+
198+
VectorStore vectorStore = context.getBean(VectorStore.class);
199+
200+
// Smoke check only: there are no assertions, so the test passes as long as
// add() completes without throwing.
vectorStore.add(List.of(new Document(new RandomIdGenerator().generateId(), "TEXT")));
201+
202+
// Clean up so later tests start without the table.
dropTable(context);
203+
});
204+
}
205+
206+
@Test
207+
public void testToPgTypeWithNonUuidIdType() {
173208
this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.pgvector.distanceType=" + "COSINE_DISTANCE")
209+
.withPropertyValues("test.spring.ai.vectorstore.pgvector.initializeSchema=" + false)
210+
.withPropertyValues("test.spring.ai.vectorstore.pgvector.idType=" + "TEXT")
174211
.run(context -> {
175212

176213
VectorStore vectorStore = context.getBean(VectorStore.class);
214+
initSchema(context);
177215

178216
vectorStore.add(List.of(new Document("NOT_UUID", "TEXT")));
179217

@@ -386,12 +424,19 @@ public static class TestApplication {
386424
@Value("${test.spring.ai.vectorstore.pgvector.distanceType}")
387425
PgVectorStore.PgDistanceType distanceType;
388426

427+
@Value("${test.spring.ai.vectorstore.pgvector.initializeSchema:true}")
428+
boolean initializeSchema;
429+
430+
@Value("${test.spring.ai.vectorstore.pgvector.idType:UUID}")
431+
PgIdType idType;
432+
389433
@Bean
390434
public VectorStore vectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) {
391435
return PgVectorStore.builder(jdbcTemplate, embeddingModel)
392436
.dimensions(PgVectorStore.INVALID_EMBEDDING_DIMENSION)
437+
.idType(idType)
393438
.distanceType(this.distanceType)
394-
.initializeSchema(true)
439+
.initializeSchema(initializeSchema)
395440
.indexType(PgIndexType.HNSW)
396441
.removeExistingVectorStoreTable(true)
397442
.build();

0 commit comments

Comments
 (0)