Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,20 @@ public void deleteByConversationId(String conversationId) {
saveAll(conversationId, List.of());
}

@Override
public void refresh(String conversationId, List<Message> deletes, List<Message> adds) {
Assert.hasText(conversationId, "conversationId cannot be null or empty");
Assert.notNull(deletes, "deletes cannot be null");
Assert.notNull(adds, "adds cannot be null");

// RMW (Read-Modify-Write) is the only way with the current schema.
// This is not efficient, but it is correct.
List<Message> currentMessages = new ArrayList<>(this.findByConversationId(conversationId));
currentMessages.removeAll(deletes);
currentMessages.addAll(adds);
this.saveAll(conversationId, currentMessages);
}

private PreparedStatement prepareAddStmt() {
RegularInsert stmt = null;
InsertInto stmtStart = QueryBuilder.insertInto(this.conf.schema.keyspace(), this.conf.schema.table());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,9 @@ public String getSelectConversationIdsSql() {
return "SELECT DISTINCT conversation_id FROM SPRING_AI_CHAT_MEMORY";
}

@Override
public String getDeleteMessageSql() {
return "DELETE FROM SPRING_AI_CHAT_MEMORY WHERE conversation_id = ? AND content = ? AND type = ?";
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,51 @@ public void deleteByConversationId(String conversationId) {
this.jdbcTemplate.update(this.dialect.getDeleteMessagesSql(), conversationId);
}

@Override
public void refresh(String conversationId, List<Message> deletes, List<Message> adds) {
Assert.hasText(conversationId, "conversationId cannot be null or empty");
Assert.notNull(deletes, "deletes cannot be null");
Assert.notNull(adds, "adds cannot be null");

this.transactionTemplate.execute(status -> {
if (!deletes.isEmpty()) {
// This is a simplification. In a real implementation, we would need a
// stable
// way to identify messages to delete, perhaps by adding a message_id
// column.
// For now, we delete based on content and type, which is not robust.
this.jdbcTemplate.batchUpdate(this.dialect.getDeleteMessageSql(),
new DeleteBatchPreparedStatement(conversationId, deletes));
}
if (!adds.isEmpty()) {
this.jdbcTemplate.batchUpdate(this.dialect.getInsertMessageSql(),
new AddBatchPreparedStatement(conversationId, adds));
}
return null;
});
}

public static Builder builder() {
return new Builder();
}

private record DeleteBatchPreparedStatement(String conversationId,
List<Message> messages) implements BatchPreparedStatementSetter {

@Override
public void setValues(PreparedStatement ps, int i) throws SQLException {
var message = this.messages.get(i);
ps.setString(1, this.conversationId);
ps.setString(2, message.getText());
ps.setString(3, message.getMessageType().name());
}

@Override
public int getBatchSize() {
return this.messages.size();
}
}

private record AddBatchPreparedStatement(String conversationId, List<Message> messages,
AtomicLong instantSeq) implements BatchPreparedStatementSetter {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ public interface JdbcChatMemoryRepositoryDialect {
*/
String getDeleteMessagesSql();

/**
* Returns the SQL to delete a single message for a conversation.
*/
String getDeleteMessageSql();

/**
* Optionally, dialect can provide more advanced SQL as needed.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,9 @@ public String getDeleteMessagesSql() {
return "DELETE FROM SPRING_AI_CHAT_MEMORY WHERE conversation_id = ?";
}

@Override
public String getDeleteMessageSql() {
return "DELETE FROM SPRING_AI_CHAT_MEMORY WHERE conversation_id = ? AND content = ? AND type = ?";
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,9 @@ public String getDeleteMessagesSql() {
return "DELETE FROM SPRING_AI_CHAT_MEMORY WHERE conversation_id = ?";
}

@Override
public String getDeleteMessageSql() {
return "DELETE FROM SPRING_AI_CHAT_MEMORY WHERE conversation_id = ? AND content = ? AND type = ?";
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,9 @@ public String getDeleteMessagesSql() {
return "DELETE FROM SPRING_AI_CHAT_MEMORY WHERE conversation_id = ?";
}

@Override
public String getDeleteMessageSql() {
return "DELETE FROM SPRING_AI_CHAT_MEMORY WHERE conversation_id = ? AND content = ? AND type = ?";
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,31 @@ void testMessageOrder() {
"4-Fourth message");
}

@Test
void refreshConversation() {
var conversationId = UUID.randomUUID().toString();
List<Message> initialMessages = List.of(new UserMessage("Hello"), new AssistantMessage("Hi there"),
new UserMessage("How are you?"));
this.chatMemoryRepository.saveAll(conversationId, initialMessages);

assertThat(this.chatMemoryRepository.findByConversationId(conversationId)).hasSize(3);

// Define changes
List<Message> toDelete = List.of(new UserMessage("How are you?"));
List<Message> toAdd = List.of(new AssistantMessage("I am fine, thank you."));

// Apply changes
this.chatMemoryRepository.refresh(conversationId, toDelete, toAdd);

// Verify final state
List<Message> finalMessages = this.chatMemoryRepository.findByConversationId(conversationId);
assertThat(finalMessages).hasSize(3);
assertThat(finalMessages).contains(new UserMessage("Hello"));
assertThat(finalMessages).contains(new AssistantMessage("Hi there"));
assertThat(finalMessages).contains(new AssistantMessage("I am fine, thank you."));
assertThat(finalMessages).doesNotContain(new UserMessage("How are you?"));
}

/**
* Base configuration for all integration tests.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,34 @@ OPTIONAL MATCH (m)-[:HAS_TOOL_CALL]->(tc:%s)
}
}

@Override
public void refresh(String conversationId, List<Message> deletes, List<Message> adds) {
try (Session s = this.config.getDriver().session()) {
s.executeWriteWithoutResult(tx -> {
if (!deletes.isEmpty()) {
List<String> messageIds = deletes.stream().map(m -> (String) m.getMetadata().get("id")).toList();

String deleteStatement = """
MATCH (m:%s) WHERE m.id IN $messageIds
OPTIONAL MATCH (m)-[:HAS_METADATA]->(metadata:%s)
OPTIONAL MATCH (m)-[:HAS_MEDIA]->(media:%s)
OPTIONAL MATCH (m)-[:HAS_TOOL_RESPONSE]-(tr:%s)
OPTIONAL MATCH (m)-[:HAS_TOOL_CALL]->(tc:%s)
DETACH DELETE m, metadata, media, tr, tc
""".formatted(this.config.getMessageLabel(), this.config.getMetadataLabel(),
this.config.getMediaLabel(), this.config.getToolResponseLabel(),
this.config.getToolCallLabel());
tx.run(deleteStatement, Map.of("messageIds", messageIds));
}
if (!adds.isEmpty()) {
for (Message m : adds) {
addMessageToTransaction(tx, conversationId, m);
}
}
});
}
}

public Neo4jChatMemoryRepositoryConfig getConfig() {
return this.config;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,45 @@ void saveAndFindMessagesWithEmptyContentOrMetadata() {
// messageType
}

@Test
void refreshConversation() {
var conversationId = UUID.randomUUID().toString();

// 1. Save initial messages
List<Message> initialMessages = List.of(new UserMessage("Hello"), new AssistantMessage("Hi there"),
new UserMessage("How are you?"));
this.chatMemoryRepository.saveAll(conversationId, initialMessages);

// Retrieve to get metadata (especially the generated message IDs)
List<Message> savedMessages = this.chatMemoryRepository.findByConversationId(conversationId);
assertThat(savedMessages).hasSize(3);

// 2. Define changes
var messageToDelete = savedMessages.stream().filter(m -> m.getText().equals("How are you?")).findFirst().get();
var toDelete = List.of(messageToDelete);
List<Message> toAdd = List.of(new AssistantMessage("I am fine, thank you."));

// 3. Apply changes
this.chatMemoryRepository.refresh(conversationId, toDelete, toAdd);

// 4. Verify final state
List<Message> finalMessages = this.chatMemoryRepository.findByConversationId(conversationId);
assertThat(finalMessages).hasSize(3);

List<String> finalContents = finalMessages.stream().map(Message::getText).toList();
assertThat(finalContents).contains("Hello", "Hi there", "I am fine, thank you.");
assertThat(finalContents).doesNotContain("How are you?");

// Verify directly in the database
try (Session session = this.driver.session()) {
var result = session.run(
"MATCH (s:%s {id:$conversationId})-[:HAS_MESSAGE]->(m:%s) RETURN count(m) as count"
.formatted(this.config.getSessionLabel(), this.config.getMessageLabel()),
Map.of("conversationId", conversationId));
assertThat(result.single().get("count").asLong()).isEqualTo(3);
}
}

private Message createMessageByType(String content, MessageType messageType) {
return switch (messageType) {
case ASSISTANT -> new AssistantMessage(content);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,15 @@ public interface ChatMemoryRepository {

void deleteByConversationId(String conversationId);

/**
* Atomically removes the messages in {@code deletes} and adds the messages in {@code adds}
* for the given conversation ID. This provides a more efficient way to update
* the memory than reading the entire history and overwriting it.
*
* @param conversationId The ID of the conversation to update.
* @param deletes A list of messages to be removed from the memory.
* @param adds A list of new messages to be added to the memory.
*/
void refresh(String conversationId, List<Message> deletes, List<Message> adds);

}
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,17 @@ public void deleteByConversationId(String conversationId) {
this.chatMemoryStore.remove(conversationId);
}

@Override
public void refresh(String conversationId, List<Message> deletes, List<Message> adds) {
this.chatMemoryStore.compute(conversationId, (key, currentMessages) -> {
if (currentMessages == null) {
return new ArrayList<>(adds);
}
List<Message> updatedMessages = new ArrayList<>(currentMessages);
updatedMessages.removeAll(deletes);
updatedMessages.addAll(adds);
return updatedMessages;
});
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import java.util.ArrayList;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

Expand Down Expand Up @@ -61,8 +62,10 @@ public void add(String conversationId, List<Message> messages) {
Assert.noNullElements(messages, "messages cannot contain null elements");

List<Message> memoryMessages = this.chatMemoryRepository.findByConversationId(conversationId);
List<Message> processedMessages = process(memoryMessages, messages);
this.chatMemoryRepository.saveAll(conversationId, processedMessages);
MessageChanges changes = process(memoryMessages, messages);
if (!changes.toDelete.isEmpty() || !changes.toAdd.isEmpty()) {
this.chatMemoryRepository.refresh(conversationId, changes.toDelete, changes.toAdd);
}
}

@Override
Expand All @@ -77,38 +80,58 @@ public void clear(String conversationId) {
this.chatMemoryRepository.deleteByConversationId(conversationId);
}

private List<Message> process(List<Message> memoryMessages, List<Message> newMessages) {
List<Message> processedMessages = new ArrayList<>();
private MessageChanges process(List<Message> memoryMessages, List<Message> newMessages) {
Set<Message> originalMessageSet = new LinkedHashSet<>(memoryMessages);
List<Message> uniqueNewMessages = newMessages.stream()
.filter(msg -> !originalMessageSet.contains(msg))
.toList();
boolean hasNewSystemMessage = uniqueNewMessages.stream().anyMatch(SystemMessage.class::isInstance);

List<Message> finalMessages = new ArrayList<>();
if (hasNewSystemMessage) {
memoryMessages.stream().filter(msg -> !(msg instanceof SystemMessage)).forEach(finalMessages::add);
finalMessages.addAll(uniqueNewMessages);
}
else {
finalMessages.addAll(memoryMessages);
finalMessages.addAll(uniqueNewMessages);
}

Set<Message> memoryMessagesSet = new HashSet<>(memoryMessages);
boolean hasNewSystemMessage = newMessages.stream()
.filter(SystemMessage.class::isInstance)
.anyMatch(message -> !memoryMessagesSet.contains(message));
if (finalMessages.size() > this.maxMessages) {
List<Message> trimmedMessages = new ArrayList<>();
int messagesToRemove = finalMessages.size() - this.maxMessages;
int removed = 0;
for (Message message : finalMessages) {
if (message instanceof SystemMessage || removed >= messagesToRemove) {
trimmedMessages.add(message);
}
else {
removed++;
}
}
finalMessages = trimmedMessages;
}

memoryMessages.stream()
.filter(message -> !(hasNewSystemMessage && message instanceof SystemMessage))
.forEach(processedMessages::add);
Set<Message> finalMessageSet = new LinkedHashSet<>(finalMessages);

processedMessages.addAll(newMessages);
List<Message> toDelete = originalMessageSet.stream().filter(m -> !finalMessageSet.contains(m)).toList();

if (processedMessages.size() <= this.maxMessages) {
return processedMessages;
}
List<Message> toAdd = finalMessageSet.stream().filter(m -> !originalMessageSet.contains(m)).toList();

int messagesToRemove = processedMessages.size() - this.maxMessages;
return new MessageChanges(toDelete, toAdd);
}

List<Message> trimmedMessages = new ArrayList<>();
int removed = 0;
for (Message message : processedMessages) {
if (message instanceof SystemMessage || removed >= messagesToRemove) {
trimmedMessages.add(message);
}
else {
removed++;
}
private static class MessageChanges {

final List<Message> toDelete;

final List<Message> toAdd;

MessageChanges(List<Message> toDelete, List<Message> toAdd) {
this.toDelete = toDelete;
this.toAdd = toAdd;
}

return trimmedMessages;
}

public static Builder builder() {
Expand Down
Loading