fix: deserialize LangChain messages in async checkpoint operations (#85, #87) #87

Merged

Conversation

aarongrtech
Contributor

Fixes message deserialization when calling aget_state()

@bsbodden bsbodden self-requested a review August 15, 2025 00:12
Contributor

@bsbodden bsbodden left a comment

Thank you for this contribution! Your fix correctly addresses the message deserialization issue in AsyncRedisSaver.aget_tuple(). The addition of self._recursive_deserialize() to both code paths (with and without parent checkpoint) is correct.
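To make the idea concrete, here is a minimal sketch of what a recursive deserialization pass over channel values might look like. It is not the library's `_recursive_deserialize` implementation: `StubMessage` and `revive` are hypothetical names introduced for illustration, and the stand-in class replaces the real LangChain message types so the example runs without extra dependencies. The only detail taken from this PR is the serialized shape (`lc`, `type: "constructor"`, `id`, `kwargs`).

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class StubMessage:
    """Hypothetical stand-in for a LangChain message class."""

    content: str
    type: str
    id: Optional[str] = None


def revive(value: Any) -> Any:
    """Walk nested dicts and lists, rebuilding serialized constructor dicts.

    Illustrative only: a real implementation would resolve the class named
    in the serialized "id" path instead of always building StubMessage.
    """
    if isinstance(value, dict):
        # A dict in the serialized constructor format becomes an object.
        if value.get("lc") == 1 and value.get("type") == "constructor":
            kwargs = value.get("kwargs", {})
            return StubMessage(
                content=kwargs.get("content", ""),
                type=kwargs.get("type", ""),
                id=kwargs.get("id"),
            )
        # Any other dict is kept as a dict, but its values are still visited.
        return {key: revive(item) for key, item in value.items()}
    if isinstance(value, list):
        return [revive(item) for item in value]
    # Scalars (str, int, bool, None, ...) pass through unchanged.
    return value
```

Under these assumptions, plain dicts and scalar values are left untouched, and only dicts in the serialized constructor format are turned back into objects, at any nesting depth.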

I noticed that the alist() method also needs the same deserialization fix. Your PR correctly addresses the main issue in aget_tuple(); could you also apply the fix to alist(), around lines 710-712, where it processes channel_values? That would keep deserialization consistent across all async methods.

The relevant section in alist() that needs updating:

# Current (missing deserialization)
channel_values = checkpoint_dict.get("channel_values", {})

# Should be (with deserialization)
channel_values = self._recursive_deserialize(checkpoint_dict.get("channel_values", {}))
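As a rough illustration of why the listing path needs the same call, the sketch below runs a deserializer over every checkpoint yielded by an async generator. Everything here is a stand-in: `list_checkpoints`, the in-memory `stored` list, and the injected `deserialize` callable are hypothetical and only mirror the shape of the change, not the actual `alist()` code.

```python
import asyncio
from typing import Any, AsyncIterator, Callable, Dict, List


async def list_checkpoints(
    stored: List[Dict[str, Any]],
    deserialize: Callable[[Any], Any],
) -> AsyncIterator[Dict[str, Any]]:
    """Yield stored checkpoints with their channel values deserialized.

    `stored` stands in for whatever the backing store returns; the point is
    that every yielded checkpoint passes through `deserialize`, exactly as
    the single-checkpoint path does.
    """
    for checkpoint_dict in stored:
        channel_values = deserialize(checkpoint_dict.get("channel_values", {}))
        # Copy the checkpoint so the stored record is not mutated.
        yield {**checkpoint_dict, "channel_values": channel_values}
        # Give other tasks a chance to run between items.
        await asyncio.sleep(0)
```

Passing the deserializer in as a parameter is a choice made for this sketch so the function can be exercised with any callable; in the PR the call is made on the saver instance itself.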

I also created a test file you can use; it is attached in a separate comment.

@bsbodden
Contributor

Test file:

"""Tests for PR #87 - Async message deserialization fix.

This test suite validates that AsyncRedisSaver properly deserializes
LangChain messages when loading checkpoints.
"""

from typing import List
from uuid import uuid4

import pytest
from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    HumanMessage,
    SystemMessage,
    ToolMessage,
)
from langgraph.checkpoint.base import (
    CheckpointTuple,
    create_checkpoint,
    empty_checkpoint,
)

from langgraph.checkpoint.redis import AsyncRedisSaver, RedisSaver


@pytest.mark.asyncio
async def test_pr87_fix_approach(redis_url: str):
    """Test if PR #87's fix approach using _recursive_deserialize works."""
    async with AsyncRedisSaver.from_conn_string(redis_url) as saver:
        # Test if _recursive_deserialize is accessible (inherited from BaseRedisSaver)
        assert hasattr(
            saver, "_recursive_deserialize"
        ), "Missing _recursive_deserialize method"

        # Create test data - raw channel values as they come from aget_channel_values
        raw_channel_values = {
            "messages": [
                {
                    "lc": 1,
                    "type": "constructor",
                    "id": ["langchain", "schema", "messages", "HumanMessage"],
                    "kwargs": {
                        "content": "Test message",
                        "type": "human",
                        "id": "test-123",
                    },
                }
            ]
        }

        # Test the PR #87 approach
        deserialized = saver._recursive_deserialize(raw_channel_values)

        # Verify the fix works
        assert isinstance(deserialized["messages"][0], HumanMessage)
        assert deserialized["messages"][0].content == "Test message"
        assert deserialized["messages"][0].id == "test-123"


@pytest.mark.asyncio
async def test_async_deserializes_langchain_messages(redis_url: str):
    """Test that AsyncRedisSaver properly deserializes LangChain message objects.

    This is the core test for PR #87 - verifies that messages are returned as
    proper Message objects, not as serialized dictionaries.
    """
    async with AsyncRedisSaver.from_conn_string(redis_url) as saver:
        thread_id = str(uuid4())

        # Create checkpoint with various message types
        messages = [
            SystemMessage(content="You are a helpful assistant.", id="sys-1"),
            HumanMessage(content="What's the weather like?", id="human-1"),
            AIMessage(content="I'll help you check the weather.", id="ai-1"),
            ToolMessage(
                content="Weather data retrieved",
                tool_call_id="call-1",
                name="weather_tool",
            ),
        ]

        checkpoint = create_checkpoint(
            checkpoint=empty_checkpoint(), channels={"messages": messages}, step=1
        )
        checkpoint["channel_values"]["messages"] = messages

        config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}

        # Save checkpoint
        await saver.aput(
            config, checkpoint, {"source": "test", "step": 1, "writes": {}}, {}
        )

        # Load checkpoint
        loaded_tuple = await saver.aget_tuple(config)

        assert loaded_tuple is not None
        loaded_messages = loaded_tuple.checkpoint["channel_values"]["messages"]

        # Verify all messages are properly deserialized
        assert len(loaded_messages) == 4
        assert isinstance(loaded_messages[0], SystemMessage)
        assert isinstance(loaded_messages[1], HumanMessage)
        assert isinstance(loaded_messages[2], AIMessage)
        assert isinstance(loaded_messages[3], ToolMessage)

        # Verify content and IDs
        assert loaded_messages[0].content == "You are a helpful assistant."
        assert loaded_messages[1].content == "What's the weather like?"
        assert loaded_messages[1].id == "human-1"
        assert loaded_messages[2].content == "I'll help you check the weather."
        assert loaded_messages[2].id == "ai-1"
        assert loaded_messages[3].content == "Weather data retrieved"
        assert loaded_messages[3].tool_call_id == "call-1"


@pytest.mark.asyncio
async def test_async_handles_serialized_langchain_format(redis_url: str):
    """Test that async handles the serialized LangChain format that causes MESSAGE_COERCION_FAILURE.

    This tests the specific format from issue #85:
    {'lc': 1, 'type': 'constructor', 'id': [...], 'kwargs': {...}}
    """
    async with AsyncRedisSaver.from_conn_string(redis_url) as saver:
        thread_id = str(uuid4())

        # This is the format that gets stored in Redis
        serialized_messages = [
            {
                "lc": 1,
                "type": "constructor",
                "id": ["langchain", "schema", "messages", "SystemMessage"],
                "kwargs": {
                    "content": "System prompt",
                    "type": "system",
                },
            },
            {
                "lc": 1,
                "type": "constructor",
                "id": ["langchain", "schema", "messages", "HumanMessage"],
                "kwargs": {"content": "User input", "type": "human", "id": "msg-123"},
            },
            {
                "lc": 1,
                "type": "constructor",
                "id": ["langchain", "schema", "messages", "AIMessage"],
                "kwargs": {
                    "content": "Assistant response",
                    "type": "ai",
                    "id": "msg-456",
                },
            },
        ]

        # Simulate what happens when checkpoint is saved with serialized messages
        checkpoint = create_checkpoint(
            checkpoint=empty_checkpoint(),
            channels={"messages": serialized_messages},
            step=1,
        )
        checkpoint["channel_values"]["messages"] = serialized_messages

        config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}

        # Manually save the checkpoint to simulate the serialized state
        await saver.aput(
            config, checkpoint, {"source": "test", "step": 1, "writes": {}}, {}
        )

        # Load and verify deserialization
        loaded_tuple = await saver.aget_tuple(config)

        assert loaded_tuple is not None
        loaded_messages = loaded_tuple.checkpoint["channel_values"]["messages"]

        # Messages should be deserialized to proper objects
        assert len(loaded_messages) == 3
        assert isinstance(loaded_messages[0], SystemMessage)
        assert isinstance(loaded_messages[1], HumanMessage)
        assert isinstance(loaded_messages[2], AIMessage)

        # Verify content
        assert loaded_messages[0].content == "System prompt"
        assert loaded_messages[1].content == "User input"
        assert loaded_messages[1].id == "msg-123"
        assert loaded_messages[2].content == "Assistant response"
        assert loaded_messages[2].id == "msg-456"


@pytest.mark.asyncio
async def test_async_get_channel_values_directly(redis_url: str):
    """Test aget_channel_values method directly to ensure proper deserialization."""
    async with AsyncRedisSaver.from_conn_string(redis_url) as saver:
        thread_id = str(uuid4())

        # Create checkpoint with messages
        messages = [
            HumanMessage(content="Direct test", id="msg-1"),
            AIMessage(content="Response", id="msg-2"),
        ]

        checkpoint = create_checkpoint(
            checkpoint=empty_checkpoint(), channels={"messages": messages}, step=1
        )
        checkpoint["channel_values"]["messages"] = messages

        config = {
            "configurable": {
                "thread_id": thread_id,
                "checkpoint_ns": "",
                "checkpoint_id": "test-checkpoint",
            }
        }

        saved_config = await saver.aput(
            config, checkpoint, {"source": "test", "step": 1, "writes": {}}, {}
        )

        # Call aget_channel_values directly
        channel_values = await saver.aget_channel_values(
            thread_id=thread_id,
            checkpoint_ns="",
            checkpoint_id=saved_config["configurable"]["checkpoint_id"],
        )

        # Note: Without PR #87, aget_channel_values returns raw serialized data
        # With PR #87, aget_tuple deserializes it, but aget_channel_values itself doesn't
        # This test documents the current behavior
        assert "messages" in channel_values
        loaded_messages = channel_values["messages"]

        # This will fail without additional fixes to aget_channel_values itself
        # The PR #87 fix is in aget_tuple, not aget_channel_values
        # Keeping this test to document the behavior
        if isinstance(loaded_messages[0], dict):
            # Without the fix - raw serialized format
            assert loaded_messages[0].get("lc") == 1
            assert loaded_messages[0].get("type") == "constructor"
        else:
            # With a complete fix
            assert isinstance(loaded_messages[0], HumanMessage)
            assert isinstance(loaded_messages[1], AIMessage)


@pytest.mark.asyncio
async def test_async_sync_parity(redis_url: str):
    """Test that async and sync implementations return the same deserialized messages."""
    thread_id = str(uuid4())

    # Create test messages
    messages = [
        SystemMessage(content="System message"),
        HumanMessage(content="User message", id="user-1"),
        AIMessage(content="AI response", id="ai-1"),
    ]

    checkpoint = create_checkpoint(
        checkpoint=empty_checkpoint(), channels={"messages": messages}, step=1
    )
    checkpoint["channel_values"]["messages"] = messages

    config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}

    # Test with sync saver
    sync_saver = RedisSaver(redis_url)
    sync_saver.setup()

    sync_saver.put(config, checkpoint, {"source": "test", "step": 1, "writes": {}}, {})

    sync_tuple = sync_saver.get_tuple(config)
    sync_messages = sync_tuple.checkpoint["channel_values"]["messages"]

    # Test with async saver using the same checkpoint
    async with AsyncRedisSaver.from_conn_string(redis_url) as async_saver:
        async_tuple = await async_saver.aget_tuple(config)
        async_messages = async_tuple.checkpoint["channel_values"]["messages"]

    # Both should return the same deserialized messages
    assert len(sync_messages) == len(async_messages)

    for sync_msg, async_msg in zip(sync_messages, async_messages):
        assert type(sync_msg) is type(async_msg)
        assert sync_msg.content == async_msg.content
        if hasattr(sync_msg, "id") and sync_msg.id:
            assert sync_msg.id == async_msg.id

        # Both should be proper Message objects, not dicts
        assert isinstance(sync_msg, BaseMessage)
        assert isinstance(async_msg, BaseMessage)


@pytest.mark.asyncio
async def test_async_with_parent_checkpoint(redis_url: str):
    """Test deserialization works when loading checkpoint with parent (pending_sends path)."""
    async with AsyncRedisSaver.from_conn_string(redis_url) as saver:
        thread_id = str(uuid4())

        # Create parent checkpoint
        parent_messages = [HumanMessage(content="First message", id="msg-1")]

        parent_checkpoint = create_checkpoint(
            checkpoint=empty_checkpoint(),
            channels={"messages": parent_messages},
            step=1,
        )
        parent_checkpoint["channel_values"]["messages"] = parent_messages
        parent_checkpoint["id"] = "parent-checkpoint-id"

        parent_config = {
            "configurable": {
                "thread_id": thread_id,
                "checkpoint_ns": "",
                "checkpoint_id": "parent-checkpoint-id",
            }
        }

        await saver.aput(
            parent_config,
            parent_checkpoint,
            {"source": "test", "step": 1, "writes": {}},
            {},
        )

        # Create child checkpoint with parent reference
        child_messages = [
            HumanMessage(content="First message", id="msg-1"),
            AIMessage(content="Response", id="msg-2"),
        ]

        child_checkpoint = create_checkpoint(
            checkpoint=empty_checkpoint(), channels={"messages": child_messages}, step=2
        )
        child_checkpoint["channel_values"]["messages"] = child_messages
        child_checkpoint["id"] = "child-checkpoint-id"
        child_checkpoint["parent_checkpoint_id"] = "parent-checkpoint-id"

        child_config = {
            "configurable": {
                "thread_id": thread_id,
                "checkpoint_ns": "",
                "checkpoint_id": "parent-checkpoint-id",  # This becomes parent during save
            }
        }

        # Save child with parent reference
        saved_config = await saver.aput(
            child_config,
            child_checkpoint,
            {"source": "test", "step": 2, "writes": {}, "parent_config": parent_config},
            {},
        )

        # Load child checkpoint (should trigger the parent checkpoint path)
        loaded_tuple = await saver.aget_tuple(saved_config)

        assert loaded_tuple is not None
        assert loaded_tuple.parent_config is not None  # Has parent

        loaded_messages = loaded_tuple.checkpoint["channel_values"]["messages"]

        # Messages should be properly deserialized even with parent
        assert len(loaded_messages) == 2
        assert isinstance(loaded_messages[0], HumanMessage)
        assert isinstance(loaded_messages[1], AIMessage)
        assert loaded_messages[0].content == "First message"
        assert loaded_messages[1].content == "Response"


@pytest.mark.asyncio
async def test_async_nested_message_structures(redis_url: str):
    """Test deserialization of complex nested message structures."""
    async with AsyncRedisSaver.from_conn_string(redis_url) as saver:
        thread_id = str(uuid4())

        # Create checkpoint with nested structures
        complex_channel_values = {
            "messages": [
                HumanMessage(content="Question", id="q-1"),
                AIMessage(
                    content="Answer",
                    id="a-1",
                    additional_kwargs={"function_call": {"name": "tool", "args": {}}},
                ),
            ],
            "other_data": {
                "nested": {"messages": [SystemMessage(content="Nested system message")]}
            },
            "list_of_lists": [
                [HumanMessage(content="Deep message")],
                [AIMessage(content="Deep response")],
            ],
        }

        checkpoint = create_checkpoint(
            checkpoint=empty_checkpoint(), channels=complex_channel_values, step=1
        )
        checkpoint["channel_values"] = complex_channel_values

        config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}

        await saver.aput(
            config, checkpoint, {"source": "test", "step": 1, "writes": {}}, {}
        )

        # Load and verify nested structures are deserialized
        loaded_tuple = await saver.aget_tuple(config)

        assert loaded_tuple is not None
        loaded_values = loaded_tuple.checkpoint["channel_values"]

        # Top-level messages
        assert isinstance(loaded_values["messages"][0], HumanMessage)
        assert isinstance(loaded_values["messages"][1], AIMessage)
        assert (
            loaded_values["messages"][1].additional_kwargs["function_call"]["name"]
            == "tool"
        )

        # Nested messages
        assert isinstance(
            loaded_values["other_data"]["nested"]["messages"][0], SystemMessage
        )

        # List of lists
        assert isinstance(loaded_values["list_of_lists"][0][0], HumanMessage)
        assert isinstance(loaded_values["list_of_lists"][1][0], AIMessage)


@pytest.mark.asyncio
async def test_async_mixed_content_types(redis_url: str):
    """Test that non-message content is preserved while messages are deserialized."""
    async with AsyncRedisSaver.from_conn_string(redis_url) as saver:
        thread_id = str(uuid4())

        # Mix messages with other data types
        mixed_values = {
            "messages": [
                HumanMessage(content="User input", id="user-1"),
                {"role": "system", "content": "Plain dict"},  # Not a message object
                AIMessage(content="Response", id="ai-1"),
            ],
            "metadata": {"key": "value"},
            "count": 42,
            "flags": [True, False, True],
        }

        checkpoint = create_checkpoint(
            checkpoint=empty_checkpoint(), channels=mixed_values, step=1
        )
        checkpoint["channel_values"] = mixed_values

        config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}

        await saver.aput(
            config, checkpoint, {"source": "test", "step": 1, "writes": {}}, {}
        )

        # Load and verify mixed content
        loaded_tuple = await saver.aget_tuple(config)

        assert loaded_tuple is not None
        loaded_values = loaded_tuple.checkpoint["channel_values"]

        # Messages are deserialized
        assert isinstance(loaded_values["messages"][0], HumanMessage)
        assert isinstance(loaded_values["messages"][2], AIMessage)

        # Plain dict remains a dict
        assert isinstance(loaded_values["messages"][1], dict)
        assert loaded_values["messages"][1]["role"] == "system"

        # Other data types are preserved
        assert loaded_values["metadata"] == {"key": "value"}
        assert loaded_values["count"] == 42
        assert loaded_values["flags"] == [True, False, True]


@pytest.mark.asyncio
@pytest.mark.skip(
    reason="alist() deserialization not addressed in PR #87 - needs separate fix"
)
async def test_async_alist_with_deserialization(redis_url: str):
    """Test that alist() also properly deserializes messages.

    NOTE: This test is skipped because PR #87 only addresses aget_tuple()
    deserialization. The alist() method also needs the same fix applied
    to properly deserialize LangChain messages. This should be addressed
    in a follow-up PR.
    """
    async with AsyncRedisSaver.from_conn_string(redis_url) as saver:
        thread_id = str(uuid4())

        # Create multiple checkpoints
        for i in range(3):
            messages = [HumanMessage(content=f"Message {i}", id=f"msg-{i}")]

            checkpoint = create_checkpoint(
                checkpoint=empty_checkpoint(), channels={"messages": messages}, step=i
            )
            checkpoint["channel_values"]["messages"] = messages

            config = {
                "configurable": {
                    "thread_id": thread_id,
                    "checkpoint_ns": "",
                    "checkpoint_id": f"checkpoint-{i}",
                }
            }

            await saver.aput(
                config, checkpoint, {"source": "test", "step": i, "writes": {}}, {}
            )

        # List all checkpoints
        checkpoints: List[CheckpointTuple] = []
        async for checkpoint_tuple in saver.alist(
            {"configurable": {"thread_id": thread_id}}
        ):
            checkpoints.append(checkpoint_tuple)

        # Verify all checkpoints have deserialized messages
        assert len(checkpoints) == 3

        for checkpoint_tuple in checkpoints:
            messages = checkpoint_tuple.checkpoint["channel_values"]["messages"]
            assert len(messages) == 1
            assert isinstance(messages[0], HumanMessage)
            # Messages should be objects, not dicts
            assert hasattr(messages[0], "content")
            assert not isinstance(messages[0], dict)


@pytest.mark.asyncio
async def test_regression_issue_85(redis_url: str):
    """Regression test for issue #85 - MESSAGE_COERCION_FAILURE.

    This reproduces the exact scenario from the issue where sending a second
    message with the same thread_id causes an error due to improper deserialization.
    """
    async with AsyncRedisSaver.from_conn_string(redis_url) as saver:
        thread_id = "test-thread-85"

        # First message (should work)
        first_messages = [HumanMessage(content="First message", id="msg-1")]

        first_checkpoint = create_checkpoint(
            checkpoint=empty_checkpoint(), channels={"messages": first_messages}, step=1
        )
        first_checkpoint["channel_values"]["messages"] = first_messages

        first_config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}

        await saver.aput(
            first_config,
            first_checkpoint,
            {"source": "user", "step": 1, "writes": {}},
            {},
        )

        # Load first checkpoint - should work
        loaded_first = await saver.aget_tuple(first_config)
        assert loaded_first is not None
        first_loaded_messages = loaded_first.checkpoint["channel_values"]["messages"]
        assert isinstance(first_loaded_messages[0], HumanMessage)

        # Second message with same thread_id (this would fail without the fix)
        second_messages = [
            HumanMessage(content="First message", id="msg-1"),
            AIMessage(content="Response", id="msg-2"),
            HumanMessage(content="Second message", id="msg-3"),
        ]

        second_checkpoint = create_checkpoint(
            checkpoint=empty_checkpoint(),
            channels={"messages": second_messages},
            step=2,
        )
        second_checkpoint["channel_values"]["messages"] = second_messages

        second_config = {
            "configurable": {
                "thread_id": thread_id,  # Same thread_id
                "checkpoint_ns": "",
            }
        }

        # This should NOT raise MESSAGE_COERCION_FAILURE
        await saver.aput(
            second_config,
            second_checkpoint,
            {"source": "user", "step": 2, "writes": {}},
            {},
        )

        # Load second checkpoint - should work with proper deserialization
        loaded_second = await saver.aget_tuple(second_config)
        assert loaded_second is not None
        second_loaded_messages = loaded_second.checkpoint["channel_values"]["messages"]

        # All messages should be properly deserialized
        assert len(second_loaded_messages) == 3
        assert isinstance(second_loaded_messages[0], HumanMessage)
        assert isinstance(second_loaded_messages[1], AIMessage)
        assert isinstance(second_loaded_messages[2], HumanMessage)

        # Content should be correct
        assert second_loaded_messages[0].content == "First message"
        assert second_loaded_messages[1].content == "Response"
        assert second_loaded_messages[2].content == "Second message"

@bsbodden bsbodden self-assigned this Aug 15, 2025
@aarongrtech aarongrtech requested a review from bsbodden August 15, 2025 02:33
@aarongrtech
Contributor Author

Changes applied.
Not a problem, sir, happy to help!

Contributor

@bsbodden bsbodden left a comment

LGTM! Thanks!

@bsbodden bsbodden changed the title from "Fixes asynchronous message deserialization" to "fix: deserialize LangChain messages in async checkpoint operations (#85, #87)" Aug 15, 2025
@bsbodden bsbodden merged commit cf6a202 into redis-developer:main Aug 15, 2025
20 checks passed
@aarongrtech aarongrtech deleted the asynchronous-message-deserialization branch August 16, 2025 07:04

2 participants