|
20 | 20 | import org.slf4j.Logger; |
21 | 21 | import org.slf4j.LoggerFactory; |
22 | 22 | import org.springframework.ai.chat.memory.ChatMemory; |
| 23 | +import org.springframework.ai.chat.memory.ChatMemoryRepository; |
23 | 24 | import org.springframework.ai.chat.messages.AssistantMessage; |
24 | 25 | import org.springframework.ai.chat.messages.Message; |
25 | 26 | import org.springframework.ai.chat.messages.MessageType; |
26 | 27 | import org.springframework.ai.chat.messages.UserMessage; |
| 28 | +import org.springframework.ai.content.Media; |
| 29 | +import org.springframework.ai.content.MediaContent; |
27 | 30 | import org.springframework.util.Assert; |
28 | 31 | import redis.clients.jedis.JedisPooled; |
29 | 32 | import redis.clients.jedis.Pipeline; |
30 | 33 | import redis.clients.jedis.json.Path2; |
31 | 34 | import redis.clients.jedis.search.*; |
| 35 | +import redis.clients.jedis.search.aggr.AggregationBuilder; |
| 36 | +import redis.clients.jedis.search.aggr.AggregationResult; |
| 37 | +import redis.clients.jedis.search.aggr.Reducers; |
32 | 38 | import redis.clients.jedis.search.schemafields.NumericField; |
33 | 39 | import redis.clients.jedis.search.schemafields.SchemaField; |
34 | 40 | import redis.clients.jedis.search.schemafields.TagField; |
|
37 | 43 | import java.time.Duration; |
38 | 44 | import java.time.Instant; |
39 | 45 | import java.util.ArrayList; |
| 46 | +import java.util.HashMap; |
| 47 | +import java.util.HashSet; |
40 | 48 | import java.util.List; |
41 | 49 | import java.util.Map; |
| 50 | +import java.util.Set; |
42 | 51 | import java.util.concurrent.atomic.AtomicLong; |
43 | 52 |
|
44 | 53 | /** |
45 | | - * Redis implementation of {@link ChatMemory} using Redis Stack (RedisJSON + RediSearch). |
46 | | - * Stores chat messages as JSON documents and uses RediSearch for querying. |
| 54 | + * Redis implementation of {@link ChatMemory} using Redis (JSON + Query Engine). Stores |
| 55 | + * chat messages as JSON documents and uses the Redis Query Engine for querying. |
47 | 56 | * |
48 | 57 | * @author Brian Sam-Bodden |
49 | 58 | */ |
50 | | -public final class RedisChatMemory implements ChatMemory { |
| 59 | +public final class RedisChatMemory implements ChatMemory, ChatMemoryRepository { |
51 | 60 |
|
52 | 61 | private static final Logger logger = LoggerFactory.getLogger(RedisChatMemory.class); |
53 | 62 |
|
@@ -113,26 +122,79 @@ public List<Message> get(String conversationId, int lastN) { |
113 | 122 | Assert.isTrue(lastN > 0, "LastN must be greater than 0"); |
114 | 123 |
|
115 | 124 | String queryStr = String.format("@conversation_id:{%s}", RediSearchUtil.escape(conversationId)); |
| 125 | + // Use ascending order (oldest first) to match test expectations |
116 | 126 | Query query = new Query(queryStr).setSortBy("timestamp", true).limit(0, lastN); |
117 | 127 |
|
118 | 128 | SearchResult result = jedis.ftSearch(config.getIndexName(), query); |
119 | 129 |
|
| 130 | + if (logger.isDebugEnabled()) { |
| 131 | + logger.debug("Redis search for conversation {} returned {} results", conversationId, |
| 132 | + result.getDocuments().size()); |
| 133 | + result.getDocuments().forEach(doc -> { |
| 134 | + if (doc.get("$") != null) { |
| 135 | + JsonObject json = gson.fromJson(doc.getString("$"), JsonObject.class); |
| 136 | + logger.debug("Document: {}", json); |
| 137 | + } |
| 138 | + }); |
| 139 | + } |
| 140 | + |
120 | 141 | List<Message> messages = new ArrayList<>(); |
121 | 142 | result.getDocuments().forEach(doc -> { |
122 | 143 | if (doc.get("$") != null) { |
123 | 144 | JsonObject json = gson.fromJson(doc.getString("$"), JsonObject.class); |
124 | 145 | String type = json.get("type").getAsString(); |
125 | 146 | String content = json.get("content").getAsString(); |
126 | 147 |
|
| 148 | + // Convert metadata from JSON to Map if present |
| 149 | + Map<String, Object> metadata = new HashMap<>(); |
| 150 | + if (json.has("metadata") && json.get("metadata").isJsonObject()) { |
| 151 | + JsonObject metadataJson = json.getAsJsonObject("metadata"); |
| 152 | + metadataJson.entrySet().forEach(entry -> { |
| 153 | + metadata.put(entry.getKey(), gson.fromJson(entry.getValue(), Object.class)); |
| 154 | + }); |
| 155 | + } |
| 156 | + |
127 | 157 | if (MessageType.ASSISTANT.toString().equals(type)) { |
128 | | - messages.add(new AssistantMessage(content)); |
| 158 | + // Handle tool calls if present |
| 159 | + List<AssistantMessage.ToolCall> toolCalls = new ArrayList<>(); |
| 160 | + if (json.has("toolCalls") && json.get("toolCalls").isJsonArray()) { |
| 161 | + json.getAsJsonArray("toolCalls").forEach(element -> { |
| 162 | + JsonObject toolCallJson = element.getAsJsonObject(); |
| 163 | + toolCalls.add(new AssistantMessage.ToolCall( |
| 164 | + toolCallJson.has("id") ? toolCallJson.get("id").getAsString() : "", |
| 165 | + toolCallJson.has("type") ? toolCallJson.get("type").getAsString() : "", |
| 166 | + toolCallJson.has("name") ? toolCallJson.get("name").getAsString() : "", |
| 167 | + toolCallJson.has("arguments") ? toolCallJson.get("arguments").getAsString() : "")); |
| 168 | + }); |
| 169 | + } |
| 170 | + |
| 171 | + // Handle media if present |
| 172 | + List<Media> media = new ArrayList<>(); |
| 173 | + if (json.has("media") && json.get("media").isJsonArray()) { |
| 174 | + // Media deserialization would go here if needed |
| 175 | + // Left as empty list for simplicity |
| 176 | + } |
| 177 | + |
| 178 | + messages.add(new AssistantMessage(content, metadata, toolCalls, media)); |
129 | 179 | } |
130 | 180 | else if (MessageType.USER.toString().equals(type)) { |
131 | | - messages.add(new UserMessage(content)); |
| 181 | + // Create a UserMessage with the builder to properly set metadata |
| 182 | + List<Media> userMedia = new ArrayList<>(); |
| 183 | + if (json.has("media") && json.get("media").isJsonArray()) { |
| 184 | + // Media deserialization would go here if needed |
| 185 | + } |
| 186 | + messages.add(UserMessage.builder().text(content).metadata(metadata).media(userMedia).build()); |
132 | 187 | } |
| 188 | + // Add handling for other message types if needed |
133 | 189 | } |
134 | 190 | }); |
135 | 191 |
|
| 192 | + if (logger.isDebugEnabled()) { |
| 193 | + logger.debug("Returning {} messages for conversation {}", messages.size(), conversationId); |
| 194 | + messages.forEach(message -> logger.debug("Message type: {}, content: {}", message.getMessageType(), |
| 195 | + message.getText())); |
| 196 | + } |
| 197 | + |
136 | 198 | return messages; |
137 | 199 | } |
138 | 200 |
|
@@ -179,14 +241,133 @@ private String createKey(String conversationId, long timestamp) { |
179 | 241 | } |
180 | 242 |
|
181 | 243 | private Map<String, Object> createMessageDocument(String conversationId, Message message) { |
182 | | - return Map.of("type", message.getMessageType().toString(), "content", message.getText(), "conversation_id", |
183 | | - conversationId, "timestamp", Instant.now().toEpochMilli()); |
| 244 | + Map<String, Object> documentMap = new HashMap<>(); |
| 245 | + documentMap.put("type", message.getMessageType().toString()); |
| 246 | + documentMap.put("content", message.getText()); |
| 247 | + documentMap.put("conversation_id", conversationId); |
| 248 | + documentMap.put("timestamp", Instant.now().toEpochMilli()); |
| 249 | + |
| 250 | + // Store metadata/properties |
| 251 | + if (message.getMetadata() != null && !message.getMetadata().isEmpty()) { |
| 252 | + documentMap.put("metadata", message.getMetadata()); |
| 253 | + } |
| 254 | + |
| 255 | + // Handle tool calls for AssistantMessage |
| 256 | + if (message instanceof AssistantMessage assistantMessage && assistantMessage.hasToolCalls()) { |
| 257 | + documentMap.put("toolCalls", assistantMessage.getToolCalls()); |
| 258 | + } |
| 259 | + |
| 260 | + // Handle media content |
| 261 | + if (message instanceof MediaContent mediaContent && !mediaContent.getMedia().isEmpty()) { |
| 262 | + documentMap.put("media", mediaContent.getMedia()); |
| 263 | + } |
| 264 | + |
| 265 | + return documentMap; |
184 | 266 | } |
185 | 267 |
|
186 | 268 | private String escapeKey(String key) { |
187 | 269 | return key.replace(":", "\\:"); |
188 | 270 | } |
189 | 271 |
|
| 272 | + // ChatMemoryRepository implementation |
| 273 | + |
| 274 | + /** |
| 275 | + * Finds all unique conversation IDs using Redis aggregation. This method is optimized |
| 276 | + * to perform the deduplication on the Redis server side. |
| 277 | + * @return a list of unique conversation IDs |
| 278 | + */ |
| 279 | + @Override |
| 280 | + public List<String> findConversationIds() { |
| 281 | + try { |
| 282 | + // Use Redis aggregation to get distinct conversation_ids |
| 283 | + AggregationBuilder aggregation = new AggregationBuilder("*") |
| 284 | + .groupBy("@conversation_id", Reducers.count().as("count")) |
| 285 | + .limit(0, config.getMaxConversationIds()); // Use configured limit |
| 286 | + |
| 287 | + AggregationResult result = jedis.ftAggregate(config.getIndexName(), aggregation); |
| 288 | + |
| 289 | + List<String> conversationIds = new ArrayList<>(); |
| 290 | + result.getResults().forEach(row -> { |
| 291 | + String conversationId = (String) row.get("conversation_id"); |
| 292 | + if (conversationId != null) { |
| 293 | + conversationIds.add(conversationId); |
| 294 | + } |
| 295 | + }); |
| 296 | + |
| 297 | + if (logger.isDebugEnabled()) { |
| 298 | + logger.debug("Found {} unique conversation IDs using Redis aggregation", conversationIds.size()); |
| 299 | + conversationIds.forEach(id -> logger.debug("Conversation ID: {}", id)); |
| 300 | + } |
| 301 | + |
| 302 | + return conversationIds; |
| 303 | + } |
| 304 | + catch (Exception e) { |
| 305 | + logger.warn("Error executing Redis aggregation for conversation IDs, falling back to client-side approach", |
| 306 | + e); |
| 307 | + return findConversationIdsLegacy(); |
| 308 | + } |
| 309 | + } |
| 310 | + |
| 311 | + /** |
| 312 | + * Fallback method to find conversation IDs if aggregation fails. This is less |
| 313 | + * efficient as it requires fetching all documents and deduplicating on the client |
| 314 | + * side. |
| 315 | + * @return a list of unique conversation IDs |
| 316 | + */ |
| 317 | + private List<String> findConversationIdsLegacy() { |
| 318 | + // Keep the current implementation as a fallback |
| 319 | + String queryStr = "*"; // Match all documents |
| 320 | + Query query = new Query(queryStr); |
| 321 | + query.limit(0, config.getMaxConversationIds()); // Use configured limit |
| 322 | + |
| 323 | + SearchResult result = jedis.ftSearch(config.getIndexName(), query); |
| 324 | + |
| 325 | + // Use a Set to deduplicate conversation IDs |
| 326 | + Set<String> conversationIds = new HashSet<>(); |
| 327 | + |
| 328 | + result.getDocuments().forEach(doc -> { |
| 329 | + if (doc.get("$") != null) { |
| 330 | + JsonObject json = gson.fromJson(doc.getString("$"), JsonObject.class); |
| 331 | + if (json.has("conversation_id")) { |
| 332 | + conversationIds.add(json.get("conversation_id").getAsString()); |
| 333 | + } |
| 334 | + } |
| 335 | + }); |
| 336 | + |
| 337 | + if (logger.isDebugEnabled()) { |
| 338 | + logger.debug("Found {} unique conversation IDs using legacy method", conversationIds.size()); |
| 339 | + } |
| 340 | + |
| 341 | + return new ArrayList<>(conversationIds); |
| 342 | + } |
| 343 | + |
| 344 | + /** |
| 345 | + * Finds all messages for a given conversation ID. Uses the configured maximum |
| 346 | + * messages per conversation limit to avoid exceeding Redis limits. |
| 347 | + * @param conversationId the conversation ID to find messages for |
| 348 | + * @return a list of messages for the conversation |
| 349 | + */ |
| 350 | + @Override |
| 351 | + public List<Message> findByConversationId(String conversationId) { |
| 352 | + // Reuse existing get method with the configured limit |
| 353 | + return get(conversationId, config.getMaxMessagesPerConversation()); |
| 354 | + } |
| 355 | + |
| 356 | + @Override |
| 357 | + public void saveAll(String conversationId, List<Message> messages) { |
| 358 | + // First clear any existing messages for this conversation |
| 359 | + clear(conversationId); |
| 360 | + |
| 361 | + // Then add all the new messages |
| 362 | + add(conversationId, messages); |
| 363 | + } |
| 364 | + |
| 365 | + @Override |
| 366 | + public void deleteByConversationId(String conversationId) { |
| 367 | + // Reuse existing clear method |
| 368 | + clear(conversationId); |
| 369 | + } |
| 370 | + |
190 | 371 | /** |
191 | 372 | * Builder for RedisChatMemory configuration. |
192 | 373 | */ |
|
0 commit comments