diff --git a/js/src/graph_trajectory/tests/graph_trajectory_utils.test.ts b/js/src/graph_trajectory/tests/graph_trajectory_utils.test.ts index 8892261..afe270b 100644 --- a/js/src/graph_trajectory/tests/graph_trajectory_utils.test.ts +++ b/js/src/graph_trajectory/tests/graph_trajectory_utils.test.ts @@ -1,10 +1,16 @@ /* eslint-disable no-promise-executor-return */ import { expect, test } from "vitest"; -import { Annotation, StateGraph, MemorySaver } from "@langchain/langgraph"; +import { + Annotation, + StateGraph, + MemorySaver, + Command, + Send, +} from "@langchain/langgraph"; import { extractLangGraphTrajectoryFromThread } from "../utils.js"; -test("trajectory match", async () => { +test("extract trajectory", async () => { const checkpointer = new MemorySaver(); const inner = new StateGraph( @@ -94,3 +100,96 @@ test("trajectory match", async () => { }, }); }); + +test("extract trajectory from graph with Command", async () => { + const checkpointer = new MemorySaver(); + + const graph = new StateGraph( + Annotation.Root({ + items: Annotation({ + reducer: (a, b) => a.concat(b), + default: () => [], + }), + processedCount: Annotation({ + reducer: (_, b) => b, + default: () => 0, + }), + }) + ) + .addNode( + "dispatcher", + (state) => { + // Use Command with Send to route to multiple processing nodes dynamically + const sends = state.items.map( + (item, index) => + new Send(`process_${index % 2}`, { items: [item], index }) + ); + return new Command({ + update: { processedCount: state.items.length }, + goto: sends, + }); + }, + { + ends: ["process_0", "process_1"], + } + ) + .addNode("process_0", (state) => { + return { items: [`processed_0: ${state.items?.join(", ")}`] }; + }) + .addNode("process_1", (state) => { + return { items: [`processed_1: ${state.items?.join(", ")}`] }; + }) + .addNode("aggregator", (state) => { + return { items: [`final count: ${state.processedCount}`] }; + }) + .addEdge("__start__", "dispatcher") + .addEdge(["process_0", "process_1"], "aggregator") + .compile({ checkpointer }); + + const config = { configurable: { thread_id: "3" } }; + + await graph.invoke( + { + items: ["task1", "task2", "task3"], + }, + config + ); + + const trajectory = await extractLangGraphTrajectoryFromThread(graph, config); + + expect(trajectory).toEqual({ + inputs: [ + { + __start__: { + items: ["task1", "task2", "task3"], + }, + }, + ], + outputs: { + results: [ + { + items: [ + "task1", + "task2", + "task3", + "processed_0: task1", + "processed_1: task2", + "processed_0: task3", + "final count: 3", + ], + processedCount: 3, + }, + ], + steps: [ + [ + "__start__", + "dispatcher", + "process_0", + "process_1", + "process_0", + "aggregator", + ], + ], + }, + }); +}); diff --git a/python/agentevals/graph_trajectory/utils.py b/python/agentevals/graph_trajectory/utils.py index 2da7b89..0d2906d 100644 --- a/python/agentevals/graph_trajectory/utils.py +++ b/python/agentevals/graph_trajectory/utils.py @@ -61,8 +61,10 @@ def extract_langgraph_trajectory_from_snapshots( if is_acc_steps: if snapshot.metadata is not None and snapshot.metadata["source"] == "input": inputs.append(snapshot.metadata["writes"]) - elif i + 1 < len(snapshot_list) and any( - t.interrupts for t in snapshot_list[i + 1].tasks + elif ( + i + 1 < len(snapshot_list) + and snapshot_list[i + 1].tasks + and any(t.interrupts for t in snapshot_list[i + 1].tasks) ): inputs.append("__resuming__") # type: ignore inputs.reverse() diff --git a/python/tests/graph_trajectory/test_graph_trajectory_utils.py b/python/tests/graph_trajectory/test_graph_trajectory_utils.py index 9f62867..7a9220f 100644 --- a/python/tests/graph_trajectory/test_graph_trajectory_utils.py +++ b/python/tests/graph_trajectory/test_graph_trajectory_utils.py @@ -11,6 +11,7 @@ import operator import time +from langgraph.types import Command, Send from langgraph.graph import StateGraph from langgraph.checkpoint.memory import MemorySaver @@ -182,3 +183,94 @@ def outer_2(state: State): }, }, )["score"] + + +@pytest.mark.langsmith +def test_trajectory_match_with_command(): + checkpointer = MemorySaver() + + class State(TypedDict): + items: Annotated[list[str], lambda x, y: x + y] + processedCount: Annotated[int, lambda x, y: y] + + def dispatcher(state: State): + # Use Command with Send to route to multiple processing nodes dynamically + sends = [ + Send(f"process_{index % 2}", {"items": [item], "index": index}) + for index, item in enumerate(state["items"]) + ] + return Command( + update={"processedCount": len(state["items"])}, + goto=sends, + ) + + def process_0(state: State): + return {"items": [f"processed_0: {', '.join(state['items'])}"]} + + def process_1(state: State): + return {"items": [f"processed_1: {', '.join(state['items'])}"]} + + def aggregator(state: State): + return {"items": [f"final count: {state['processedCount']}"]} + + graph = StateGraph(State) + graph.add_node("dispatcher", dispatcher) + graph.add_node("process_0", process_0) + graph.add_node("process_1", process_1) + graph.add_node("aggregator", aggregator) + + graph.add_edge("__start__", "dispatcher") + graph.add_edge(["process_0", "process_1"], "aggregator") + + app = graph.compile(checkpointer=checkpointer) + + config = {"configurable": {"thread_id": "3"}} + + app.invoke( + { + "items": ["task1", "task2", "task3"], + }, + config, + ) + + trajectory = extract_langgraph_trajectory_from_thread(app, config) + + assert exact_match( + outputs=trajectory, + reference_outputs={ + "inputs": [ + { + "__start__": { + "items": ["task1", "task2", "task3"], + }, + }, + ], + "outputs": { + "inputs": [], + "results": [ + { + "items": [ + "task1", + "task2", + "task3", + "processed_0: task1", + "processed_1: task2", + "processed_0: task3", + "final count: 3", + ], + "processedCount": 3, + }, + ], + "steps": [ + [ + "__start__", + "dispatcher", + "process_0", + "process_1", + "process_0", + "aggregator", + ], + ], + }, + }, + )["score"]