|
1 | 1 | import base64
|
2 | 2 | import binascii
|
| 3 | +import logging |
3 | 4 | import random
|
4 | 5 | from abc import abstractmethod
|
5 | 6 | from typing import Any, Dict, Generic, List, Optional, Sequence, Tuple, Union, cast
|
|
27 | 28 | from .jsonplus_redis import JsonPlusRedisSerializer
|
28 | 29 | from .types import IndexType, RedisClientType
|
29 | 30 |
|
| 31 | +logger = logging.getLogger(__name__) |
| 32 | + |
30 | 33 | REDIS_KEY_SEPARATOR = ":"
|
31 | 34 | CHECKPOINT_PREFIX = "checkpoint"
|
32 | 35 | CHECKPOINT_BLOB_PREFIX = "checkpoint_blob"
|
@@ -305,25 +308,78 @@ def _deserialize_channel_values(
|
305 | 308 |
|
306 | 309 | When channel values are stored inline in the checkpoint, they're in their
|
307 | 310 | serialized form. This method deserializes them back to their original types.
|
| 311 | +
|
| 312 | + This specifically handles LangChain message objects that may be stored in their |
| 313 | + serialized format: {'lc': 1, 'type': 'constructor', 'id': [...], 'kwargs': {...}} |
| 314 | + and ensures they are properly reconstructed as message objects. |
308 | 315 | """
|
309 | 316 | if not channel_values:
|
310 | 317 | return {}
|
311 | 318 |
|
312 |
| - # Apply recursive deserialization to handle nested structures and LangChain objects |
313 |
| - return self._recursive_deserialize(channel_values) |
| 319 | + try: |
| 320 | + # Apply recursive deserialization to handle nested structures and LangChain objects |
| 321 | + return self._recursive_deserialize(channel_values) |
| 322 | + except Exception as e: |
| 323 | + logger.warning( |
| 324 | + f"Error deserializing channel values, attempting recovery: {e}" |
| 325 | + ) |
| 326 | + # Attempt to recover by processing each channel individually |
| 327 | + recovered = {} |
| 328 | + for key, value in channel_values.items(): |
| 329 | + try: |
| 330 | + recovered[key] = self._recursive_deserialize(value) |
| 331 | + except Exception as inner_e: |
| 332 | + logger.error( |
| 333 | + f"Failed to deserialize channel '{key}': {inner_e}. " |
| 334 | + f"Value will be returned as-is." |
| 335 | + ) |
| 336 | + recovered[key] = value |
| 337 | + return recovered |
314 | 338 |
|
315 | 339 | def _recursive_deserialize(self, obj: Any) -> Any:
|
316 |
| - """Recursively deserialize LangChain objects and nested structures.""" |
| 340 | + """Recursively deserialize LangChain objects and nested structures. |
| 341 | +
|
| 342 | + This method specifically handles the deserialization of LangChain message objects |
| 343 | + that may be stored in their serialized format to prevent MESSAGE_COERCION_FAILURE. |
| 344 | +
|
| 345 | + Args: |
| 346 | + obj: The object to deserialize, which may be a dict, list, or primitive. |
| 347 | +
|
| 348 | + Returns: |
| 349 | + The deserialized object, with LangChain objects properly reconstructed. |
| 350 | + """ |
317 | 351 | if isinstance(obj, dict):
|
318 | 352 | # Check if this is a LangChain serialized object
|
319 | 353 | if obj.get("lc") in (1, 2) and obj.get("type") == "constructor":
|
320 |
| - # Use the serde's reviver to reconstruct the object |
321 |
| - if hasattr(self.serde, "_reviver"): |
322 |
| - return self.serde._reviver(obj) |
323 |
| - elif hasattr(self.serde, "_revive_if_needed"): |
324 |
| - return self.serde._revive_if_needed(obj) |
325 |
| - else: |
326 |
| - # Fallback: return as-is if serde doesn't have reviver |
| 354 | + try: |
| 355 | + # Use the serde's reviver to reconstruct the object |
| 356 | + if hasattr(self.serde, "_reviver"): |
| 357 | + return self.serde._reviver(obj) |
| 358 | + elif hasattr(self.serde, "_revive_if_needed"): |
| 359 | + return self.serde._revive_if_needed(obj) |
| 360 | + else: |
| 361 | + # Log warning if serde doesn't have reviver |
| 362 | + logger.warning( |
| 363 | + "Serializer does not have a reviver method. " |
| 364 | + "LangChain object may not be properly deserialized. " |
| 365 | + f"Object ID: {obj.get('id')}" |
| 366 | + ) |
| 367 | + return obj |
| 368 | + except Exception as e: |
| 369 | + # Provide detailed error message for debugging |
| 370 | + obj_id = obj.get("id", "unknown") |
| 371 | + obj_type = ( |
| 372 | + obj.get("id", ["unknown"])[-1] |
| 373 | + if isinstance(obj.get("id"), list) |
| 374 | + else "unknown" |
| 375 | + ) |
| 376 | + logger.error( |
| 377 | + f"Failed to deserialize LangChain object of type '{obj_type}'. " |
| 378 | + f"This may cause MESSAGE_COERCION_FAILURE. Error: {e}. " |
| 379 | + f"Object structure: lc={obj.get('lc')}, type={obj.get('type')}, " |
| 380 | + f"id={obj_id}" |
| 381 | + ) |
| 382 | + # Return the object as-is to prevent complete failure |
327 | 383 | return obj
|
328 | 384 | # Recursively process nested dicts
|
329 | 385 | return {k: self._recursive_deserialize(v) for k, v in obj.items()}
|
|
0 commit comments