Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 101 additions & 2 deletions js/src/graph_trajectory/tests/graph_trajectory_utils.test.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
/* eslint-disable no-promise-executor-return */
import { expect, test } from "vitest";
import { Annotation, StateGraph, MemorySaver } from "@langchain/langgraph";
import {
Annotation,
StateGraph,
MemorySaver,
Command,
Send,
} from "@langchain/langgraph";

import { extractLangGraphTrajectoryFromThread } from "../utils.js";

test("trajectory match", async () => {
test("extract trajectory", async () => {
const checkpointer = new MemorySaver();

const inner = new StateGraph(
Expand Down Expand Up @@ -94,3 +100,96 @@ test("trajectory match", async () => {
},
});
});

test("extract trajectory from graph with Command", async () => {
const checkpointer = new MemorySaver();

const graph = new StateGraph(
Annotation.Root({
items: Annotation<string[]>({
reducer: (a, b) => a.concat(b),
default: () => [],
}),
processedCount: Annotation<number>({
reducer: (_, b) => b,
default: () => 0,
}),
})
)
.addNode(
"dispatcher",
(state) => {
// Use Command with Send to route to multiple processing nodes dynamically
const sends = state.items.map(
(item, index) =>
new Send(`process_${index % 2}`, { items: [item], index })
);
return new Command({
update: { processedCount: state.items.length },
goto: sends,
});
},
{
ends: ["process_0", "process_1"],
}
)
.addNode("process_0", (state) => {
return { items: [`processed_0: ${state.items?.join(", ")}`] };
})
.addNode("process_1", (state) => {
return { items: [`processed_1: ${state.items?.join(", ")}`] };
})
.addNode("aggregator", (state) => {
return { items: [`final count: ${state.processedCount}`] };
})
.addEdge("__start__", "dispatcher")
.addEdge(["process_0", "process_1"], "aggregator")
.compile({ checkpointer });

const config = { configurable: { thread_id: "3" } };

await graph.invoke(
{
items: ["task1", "task2", "task3"],
},
config
);

const trajectory = await extractLangGraphTrajectoryFromThread(graph, config);

expect(trajectory).toEqual({
inputs: [
{
__start__: {
items: ["task1", "task2", "task3"],
},
},
],
outputs: {
results: [
{
items: [
"task1",
"task2",
"task3",
"processed_0: task1",
"processed_1: task2",
"processed_0: task3",
"final count: 3",
],
processedCount: 3,
},
],
steps: [
[
"__start__",
"dispatcher",
"process_0",
"process_1",
"process_0",
"aggregator",
],
],
},
});
});
6 changes: 4 additions & 2 deletions python/agentevals/graph_trajectory/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,10 @@ def extract_langgraph_trajectory_from_snapshots(
if is_acc_steps:
if snapshot.metadata is not None and snapshot.metadata["source"] == "input":
inputs.append(snapshot.metadata["writes"])
elif i + 1 < len(snapshot_list) and any(
t.interrupts for t in snapshot_list[i + 1].tasks
elif (
i + 1 < len(snapshot_list)
and snapshot_list[i + 1].tasks
and any(t.interrupts for t in snapshot_list[i + 1].tasks)
):
inputs.append("__resuming__") # type: ignore
inputs.reverse()
Expand Down
92 changes: 92 additions & 0 deletions python/tests/graph_trajectory/test_graph_trajectory_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import operator
import time

from langgraph.types import Command, Send
from langgraph.graph import StateGraph
from langgraph.checkpoint.memory import MemorySaver

Expand Down Expand Up @@ -182,3 +183,94 @@ def outer_2(state: State):
},
},
)["score"]


@pytest.mark.langsmith
def test_trajectory_match_with_command():
checkpointer = MemorySaver()

class State(TypedDict):
items: Annotated[list[str], lambda x, y: x + y]
processedCount: Annotated[int, lambda x, y: y]

def dispatcher(state: State):
# Use Command with Send to route to multiple processing nodes dynamically
sends = [
Send(f"process_{index % 2}", {"items": [item], "index": index})
for index, item in enumerate(state["items"])
]
return Command(
update={"processedCount": len(state["items"])},
goto=sends,
)

def process_0(state: State):
return {"items": [f"processed_0: {', '.join(state['items'])}"]}

def process_1(state: State):
return {"items": [f"processed_1: {', '.join(state['items'])}"]}

def aggregator(state: State):
return {"items": [f"final count: {state['processedCount']}"]}

graph = StateGraph(State)
graph.add_node("dispatcher", dispatcher)
graph.add_node("process_0", process_0)
graph.add_node("process_1", process_1)
graph.add_node("aggregator", aggregator)

graph.add_edge("__start__", "dispatcher")
graph.add_edge(["process_0", "process_1"], "aggregator")

app = graph.compile(checkpointer=checkpointer)

config = {"configurable": {"thread_id": "3"}}

app.invoke(
{
"items": ["task1", "task2", "task3"],
},
config,
)

trajectory = extract_langgraph_trajectory_from_thread(app, config)

assert exact_match(
outputs=trajectory,
reference_outputs={
"inputs": [
{
"__start__": {
"items": ["task1", "task2", "task3"],
},
},
],
"outputs": {
"inputs": [],
"results": [
{
"items": [
"task1",
"task2",
"task3",
"processed_0: task1",
"processed_1: task2",
"processed_0: task3",
"final count: 3",
],
"processedCount": 3,
},
],
"steps": [
[
"__start__",
"dispatcher",
"process_0",
"process_1",
"process_0",
"aggregator",
],
],
},
},
)["score"]