diff --git a/docs/docs/api/optimizers/GEPA/GEPA_Advanced.md b/docs/docs/api/optimizers/GEPA/GEPA_Advanced.md index 624e580ad1..8ac2c31c42 100644 --- a/docs/docs/api/optimizers/GEPA/GEPA_Advanced.md +++ b/docs/docs/api/optimizers/GEPA/GEPA_Advanced.md @@ -443,3 +443,373 @@ gepa = dspy.GEPA( auto="medium" ) ``` + +## ReAct Component Optimization + +### What is optimize_react_components? + +Enable `optimize_react_components=True` to apply specialized optimization to `dspy.ReAct` modules while using default optimization for other modules. + +A [`dspy.ReAct`](../../learn/programming/tools.md#approach-1-using-dspyreact-fully-managed) module has three parts: a **react predictor** (iteratively reasons and selects tools), an **extract predictor** (extracts final answers from trajectories), and **tools** with their schemas. + +**What gets optimized for ReAct modules:** + +GEPA can improve textual components across all parts: +- **React instruction** - Guides reasoning and tool selection (always optimized) +- **Extract instruction** - Guides answer extraction from trajectories (optional) +- **Tool descriptions** - Describes what each tool does (optional) +- **Tool argument descriptions** - Describes tool parameters (optional) + +The reflection LM decides which optional components to improve based on observed failures. Non-ReAct modules in your program are optimized using GEPA's default signature optimization. + +**Why this matters:** + +Unlike optimizing signature instructions alone (which improves individual predictors), ReAct optimization improves the **entire agent workflow** - from initial reasoning through tool execution to final answer extraction. + +ReAct agents often fail when their components contradict each other. A clear tool description doesn't help if the react instruction never considers using that tool. GEPA analyzes execution traces to learn how all components should work together. + +### ReAct Optimization Prompt + +GEPA uses a specialized prompt to jointly optimize all ReAct components. The prompt receives complete ReAct trajectories and current component texts: + +```python +class GenerateImprovedReActDescriptionsFromFeedback(dspy.Signature): + """Improve a ReAct agent based on execution examples and feedback. + + These components are progressively optimized - refine what needs improvement. + Analyze the trajectories to identify successful patterns and failure causes. + Generate improved texts to help the agent succeed on similar tasks. + Place improved texts at their appropriate level of abstraction and/or specificity. 
+ """ + + current_react_instruction = dspy.InputField( + desc="Current ReAct module instruction guiding the ReAct agent's reasoning and tool selection" + ) + current_extract_instruction = dspy.InputField( + desc="Current Extract module instruction for extracting final answers from trajectories" + ) + current_tools = dspy.InputField( + annotation=list[dspy.Tool], + desc="Available tools with their complete schemas" + ) + examples_with_feedback = dspy.InputField( + desc="Execution examples with feedback showing successes and failures" + ) + + improved_react_instruction: str | None = dspy.OutputField( + desc="ReAct instruction for reasoning and tool selection", + default=None + ) + improved_extract_instruction: str | None = dspy.OutputField( + desc="Extract instruction for answer extraction", + default=None + ) + # Note: Tool descriptions and arg descriptions are added dynamically via signature.append() + # with field descriptions like "Purpose of tool" and "Usage of parameter" +``` + +The reflection LM receives all current components and execution traces, then decides which components to improve. Tool-specific fields (`improved_tool_{name}_desc`, `improved_tool_{name}_arg_{param}_desc`) are generated dynamically for each tool and parameter. + +**Writing Metrics for ReAct Optimization** + +GEPA optimizes ReAct modules more effectively when metrics provide feedback about the agent's execution. Here's how to write metrics that help: + +```python +def react_metric(example, pred, trace=None, pred_name=None, pred_trace=None): + """Evaluate ReAct agent performance with trajectory feedback.""" + # Check if the answer is correct + answer_match = pred.answer == example.answer + score = 1.0 if answer_match else 0.0 + + # Provide feedback to help GEPA understand what happened + feedback = "Correct answer" if answer_match else "Incorrect answer" + + return dspy.Prediction(score=score, feedback=feedback) +``` + +You can make feedback more informative by examining the trajectory: + +```python +def react_metric_with_trajectory(example, pred, trace=None, pred_name=None, pred_trace=None): + """Evaluate with trajectory analysis.""" + # Check if the answer is correct + answer_match = pred.answer == example.answer + score = 1.0 if answer_match else 0.0 + + # Access the ReAct trajectory to understand agent behavior + trajectory = getattr(pred, 'trajectory', {}) + + # Extract tool names from trajectory (excluding 'finish') + tools_used = [] + for key in trajectory: + if key.startswith('tool_name_'): + tool_name = trajectory[key] + if tool_name != 'finish': + tools_used.append(tool_name) + + # Build feedback message + if answer_match: + feedback = "Correct answer" + else: + feedback = "Incorrect answer" + + if tools_used: + feedback += f". Tools: {', '.join(tools_used)}" + + return dspy.Prediction(score=score, feedback=feedback) +``` + +The trajectory contains the agent's step-by-step execution. Use it to provide feedback about: + +- **Tool selection**: Were appropriate tools chosen? +- **Reasoning quality**: Did the agent think through the problem? +- **Efficiency**: Were there unnecessary steps? + +The reflection LM uses your feedback to jointly improve react instructions, tool descriptions, and extraction logic. + +### How It Works + +When `optimize_react_components=True`, GEPA: + +1. **Discovers ReAct modules** - Finds all `dspy.ReAct` instances in your program (including nested modules) +2. **Extracts components** - Collects react instructions, extract instructions, and tool schemas from each ReAct module +3. 
**Routes to proposers** - Separates components by type and routes them appropriately: + - **With custom `instruction_proposer`**: Your custom proposer receives all components (both regular instructions and ReAct components) and handles the optimization logic + - **With default proposer**: Regular instructions use default instruction proposer, ReAct components use specialized `ReActModuleProposer` +4. **Optimizes jointly** - ReAct proposer improves all four components together based on execution feedback +5. **Applies updates** - Updates your ReAct modules with improved instructions and tool descriptions + +Non-ReAct modules (like `dspy.Predict` or `dspy.ChainOfThought`) continue using standard GEPA optimization. + +### When to Use optimize_react_components + +Enable `optimize_react_components=True` when you use `dspy.ReAct` in your program and need better agent performance. GEPA jointly optimizes all ReAct components (react instruction, extract instruction, tool descriptions, tool argument descriptions) based on execution feedback. Common scenarios: + +1. **Agent loops with repeated tool calls** - Agent keeps calling `web_search` multiple times with similar queries instead of synthesizing information. GEPA improves react instruction to encourage synthesis and tool descriptions to clarify when searches are sufficient. + +2. **Wrong tool selection** - Agent with `search` and `calculator` tools keeps searching when it should calculate, or vice versa. GEPA refines react instruction and tool descriptions to clarify "use search for factual queries, calculator for numerical analysis." + +3. **Agent gives up without trying tools** - Agent responds "I don't know" without using available tools that could answer the question. GEPA improves react instruction to be more proactive about tool usage. + +4. **Extraction failures** - Agent executes tools correctly but fails to extract the final answer from the trajectory. GEPA improves extract instruction to better identify and format answers from tool outputs. + +5. **Multi-agent delegation issues** - Parent agent has delegation tools to specialized sub-agents but doesn't understand when to use each. GEPA optimizes all ReAct components across both parent and sub-agent modules for coherent delegation. + +See the usage examples below for basic ReAct agents and multi-agent systems. + +### Usage Examples + +#### Basic ReAct Agent + +```python +import dspy + +def search_web(query: str) -> str: + return f"Search results for: {query}" + +def calculate(expression: str) -> float: + return eval(expression) + +# Create ReAct agent with tools (poor initial descriptions) +search_tool = dspy.Tool(search_web, name="search", desc="Finds things") +calc_tool = dspy.Tool(calculate, name="calculator", desc="Does calculations") + +agent = dspy.ReAct("question -> answer", tools=[search_tool, calc_tool]) + +# Enable tool optimization +gepa = dspy.GEPA( + metric=my_metric, + reflection_lm=dspy.LM(model="gpt-5-mini"), + optimize_react_components=True, + component_selector="all", # Optimize all components together + auto="medium" +) + +optimized_agent = gepa.compile(agent, trainset=train_examples, valset=val_examples) + +# View optimized tool descriptions +print("Optimized search tool:", optimized_agent.tools["search"].desc) +print("Optimized calculator tool:", optimized_agent.tools["calculator"].desc) +``` + +**Example output after optimization:** +``` +Optimized search tool: Use when you need to find current information, facts, or data + from external sources. 
Provide specific search queries to get relevant results. + +Optimized calculator tool: Use for arithmetic operations and mathematical expressions. + Accepts Python-compatible expressions with numbers and operators (+, -, *, /, **). + Do not use for date calculations or string manipulations. +``` + +#### Multi-Agent System + +GEPA automatically discovers and optimizes tools in nested agents: + +```python +import dspy + +def search_web(query: str) -> str: + return f"Search results for: {query}" + +def calculate(expression: str) -> float: + return eval(expression) + +search_tool = dspy.Tool(search_web, name="search", desc="Searches") +calc_tool = dspy.Tool(calculate, name="calculator", desc="Computes") + +class ResearchAssistant(dspy.Module): + def __init__(self): + super().__init__() + self.researcher = dspy.ReAct("query -> findings", tools=[search_tool]) + + def delegate_research(query: str) -> str: + return self.researcher(query=query).findings + + research_tool = dspy.Tool(delegate_research, name="research", desc="Helps with questions") + self.assistant = dspy.ReAct("question -> answer", tools=[research_tool, calc_tool]) + + def forward(self, question): + return self.assistant(question=question) + +# Optimizes ALL tools: calculator, research, search +gepa = dspy.GEPA( + metric=my_metric, + reflection_lm=dspy.LM(model="gpt-5-mini"), + optimize_react_components=True, + component_selector="all", + auto="medium" +) + +optimized_system = gepa.compile(ResearchAssistant(), trainset=train, valset=val) + +# View optimized nested tool descriptions +print(optimized_system.researcher.tools["search"].desc) +print(optimized_system.assistant.tools["research"].desc) +print(optimized_system.assistant.tools["calculator"].desc) +``` + +### Inspecting Optimized ReAct Components + +After optimization, all ReAct components are automatically updated in your program. Access them directly: + +```python +optimized_agent = gepa.compile(agent, trainset=train, valset=val) + +# ReAct instruction (guides reasoning and tool selection) +print("React instruction:", optimized_agent.react.signature.instructions) + +# Extract instruction (guides answer extraction from trajectory) +print("Extract instruction:", optimized_agent.extract.predict.signature.instructions) + +# Tool descriptions +for tool_name, tool in optimized_agent.tools.items(): + if tool_name != 'finish': # Skip the built-in finish tool + print(f"Tool '{tool_name}' description:", tool.desc) + # Tool argument descriptions + print(f" Argument descriptions:", tool.arg_desc) +``` + +### Custom Instruction Proposers and ReAct Optimization + +**Important:** When you provide a custom `instruction_proposer`, it receives ALL components (regular predictors AND ReAct modules). You must set `optimize_react_components=True` to enable ReAct module discovery and serialization, then handle the optimization logic yourself. + +**How it works internally:** + +1. **Component Discovery** - GEPA discovers components in your program: + - Regular predictors → keys like `"predict"`, `"chain_of_thought"` + - ReAct modules → keys like `"react_module"` or `"react_module:agent_name"` + +2. **ReAct Serialization** - When `optimize_react_components=True`, GEPA serializes ReAct modules as JSON: + ```json + { + "react": "instruction for reasoning and tool selection", + "extract": "instruction for answer extraction", + "tools": { + "tool_name": { + "desc": "what the tool does", + "args": {"param": {"type": "string"}}, + "arg_desc": {"param": "description of param"} + } + } + } + ``` + +3. 
**Custom Proposer Receives**:
+    - `candidate: dict[str, str]` - **All values are strings**
+      - Regular component: `candidate["predict"]` → `"Your instruction here"`
+      - ReAct component: `candidate["react_module"]` → `'{"react": "...", "extract": "...", "tools": {...}}'` (the JSON config serialized as a string)
+    - `reflective_dataset: dict[str, list[ReflectiveExample]]` - **GEPA provides this**
+      - Contains execution traces: inputs, outputs (including the full ReAct trajectory), and your metric's feedback
+      - For ReAct: `Generated_Outputs` includes the entire trajectory with all tool calls and reasoning
+      - Use this to understand what went wrong and to guide your improvements
+    - `components_to_update: list[str]` - Component keys to optimize this round
+
+4. **Your Responsibility**:
+    - For ReAct components: use `json.loads()` to parse, improve all four parts, and use `json.dumps()` to return the result
+    - For regular components: improve the instruction string directly
+    - Return a `dict[str, str]` with the same keys
+
+In short: your custom proposer receives ALL components (regular signatures and ReAct modules alike), with ReAct components passed under keys like `"react_module"` or `"react_module:agent_name"`. GEPA still handles discovery and JSON serialization; you handle the optimization logic.
+
+#### Implementing a Custom Proposer for ReAct
+
+If you need custom optimization logic beyond the default, you can build your own proposer. The best way to start is the reference implementation: [`ReActModuleProposer`](https://github.com/stanfordnlp/dspy/blob/main/dspy/teleprompt/gepa/instruction_proposal.py).
+
+**Understanding ReAct component structure**
+
+When GEPA optimizes ReAct modules, it serializes them as JSON strings containing all the pieces you can improve:
+
+```json
+{
+  "react": "instruction for reasoning and tool selection",
+  "extract": "instruction for answer extraction",
+  "tools": {
+    "search": {
+      "desc": "Search the web for information",
+      "args": {"query": {"type": "string"}},
+      "arg_desc": {"query": "The search query to execute"}
+    }
+  }
+}
+```
+
+**What you can improve:**
+- **`react`** - How the agent reasons and decides which tools to use
+- **`extract`** - How the agent extracts the final answer from execution results
+- **`tools[*].desc`** - What each tool does and when to use it
+- **`tools[*].arg_desc`** - What each parameter means and how to use it
+
+**What to preserve:**
+- **`tools[*].args`** - The tool's parameter schema (types, required fields, etc.)
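+
+**Round-trip example**
+
+Before looking at the interface below, it helps to see the parse → edit → serialize cycle in one place. This is a minimal illustrative sketch, not the reference implementation: the function name and replacement texts are placeholders. The point is that the editable fields are rewritten while `tools[*].args` passes through untouched.
+
+```python
+import json
+
+def improve_react_config(config_json: str) -> str:
+    """Illustrative only: edit the editable fields, preserve the schema."""
+    config = json.loads(config_json)
+
+    # Editable: the two instructions...
+    config["react"] = "Reason step by step and choose exactly one tool per step."
+    config["extract"] = "Return only the final answer supported by the trajectory."
+
+    # ...and each tool's desc / arg_desc.
+    for tool in config.get("tools", {}).values():
+        tool["desc"] = tool.get("desc", "").strip()
+        tool.setdefault("arg_desc", {})
+        # tool["args"] (the parameter schema) is deliberately left untouched.
+
+    return json.dumps(config)
+```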
+ +**Your proposer's interface** + +Your custom proposer is a callable that receives component instructions and execution feedback, then returns improved versions: + +```python +def your_custom_proposer( + candidate: dict[str, str], # Current instructions for all components + reflective_dataset: dict[str, list], # Execution examples with feedback + components_to_update: list[str], # Which components to optimize this round +) -> dict[str, str]: # Return improved instructions + """ + For ReAct components: + - Use json.loads() to parse the JSON string + - Improve what needs fixing based on the feedback + - Use json.dumps() to serialize back + + For regular components: + - Just return the improved instruction string + """ + # Your optimization logic here + pass +``` + +**The reference shows how to:** +- Parse and rebuild the JSON structure +- Generate dynamic fields for tools/parameters +- Use execution feedback to guide improvements diff --git a/docs/docs/api/optimizers/GEPA/overview.md b/docs/docs/api/optimizers/GEPA/overview.md index 0125702bea..c36065b6aa 100644 --- a/docs/docs/api/optimizers/GEPA/overview.md +++ b/docs/docs/api/optimizers/GEPA/overview.md @@ -117,6 +117,12 @@ Practical Recipe for GEPA-Friendly Feedback: - **Multi-Objective Tasks** (e.g., PUPA): Decompose aggregate scores to reveal contributions from each objective, highlighting tradeoffs (e.g., quality vs. privacy). - **Stacked Pipelines** (e.g., code generation: parse → compile → run → profile → evaluate): Expose stage-specific failures; natural-language traces often suffice for LLM self-correction. +## ReAct Component Optimization + +GEPA can optimize ReAct modules holistically. When `optimize_react_components=True`, GEPA jointly optimizes all four components of ReAct modules: react instructions, extract instructions, tool descriptions, and tool argument descriptions. This helps agents make better decisions by learning from execution traces how all components work together. + +For details on how ReAct optimization works, when to use it, and usage examples, see [ReAct Component Optimization](GEPA_Advanced.md#react-component-optimization) in the Advanced Features guide. + ## Custom Instruction Proposal For advanced customization of GEPA's instruction proposal mechanism, including custom instruction proposers and component selectors, see [Advanced Features](GEPA_Advanced.md). diff --git a/dspy/teleprompt/gepa/gepa.py b/dspy/teleprompt/gepa/gepa.py index 87cbbf80a5..2b4302145e 100644 --- a/dspy/teleprompt/gepa/gepa.py +++ b/dspy/teleprompt/gepa/gepa.py @@ -1,4 +1,5 @@ import inspect +import json import logging import random from dataclasses import dataclass @@ -9,8 +10,15 @@ from gepa.proposer.reflective_mutation.base import ReflectionComponentSelector from dspy.clients.lm import LM +from dspy.predict.react import ReAct from dspy.primitives import Example, Module, Prediction -from dspy.teleprompt.gepa.gepa_utils import DspyAdapter, DSPyTrace, PredictorFeedbackFn, ScoreWithFeedback +from dspy.teleprompt.gepa.gepa_utils import ( + REACT_MODULE_PREFIX, + DspyAdapter, + DSPyTrace, + PredictorFeedbackFn, + ScoreWithFeedback, +) from dspy.teleprompt.teleprompt import Teleprompter from dspy.utils.annotation import experimental @@ -36,18 +44,18 @@ def __call__( - gold: The gold example. - pred: The predicted output. - trace: Optional. The trace of the program's execution. - - pred_name: Optional. The name of the target predictor currently being optimized by GEPA, for which + - pred_name: Optional. 
The name of the target predictor currently being optimized by GEPA, for which the feedback is being requested. - pred_trace: Optional. The trace of the target predictor's execution GEPA is seeking feedback for. Note the `pred_name` and `pred_trace` arguments. During optimization, GEPA will call the metric to obtain feedback for individual predictors being optimized. GEPA provides the name of the predictor in `pred_name` and the sub-trace (of the trace) corresponding to the predictor in `pred_trace`. - If available at the predictor level, the metric should return dspy.Prediction(score: float, feedback: str) corresponding + If available at the predictor level, the metric should return dspy.Prediction(score: float, feedback: str) corresponding to the predictor. If not available at the predictor level, the metric can also return a text feedback at the program level (using just the gold, pred and trace). - If no feedback is returned, GEPA will use a simple text feedback consisting of just the score: + If no feedback is returned, GEPA will use a simple text feedback consisting of just the score: f"This trajectory got a score of {score}." """ ... @@ -172,18 +180,18 @@ def metric( - gold: The gold example. - pred: The predicted output. - trace: Optional. The trace of the program's execution. - - pred_name: Optional. The name of the target predictor currently being optimized by GEPA, for which + - pred_name: Optional. The name of the target predictor currently being optimized by GEPA, for which the feedback is being requested. - pred_trace: Optional. The trace of the target predictor's execution GEPA is seeking feedback for. Note the `pred_name` and `pred_trace` arguments. During optimization, GEPA will call the metric to obtain feedback for individual predictors being optimized. GEPA provides the name of the predictor in `pred_name` and the sub-trace (of the trace) corresponding to the predictor in `pred_trace`. - If available at the predictor level, the metric should return {'score': float, 'feedback': str} corresponding + If available at the predictor level, the metric should return {'score': float, 'feedback': str} corresponding to the predictor. If not available at the predictor level, the metric can also return a text feedback at the program level (using just the gold, pred and trace). - If no feedback is returned, GEPA will use a simple text feedback consisting of just the score: + If no feedback is returned, GEPA will use a simple text feedback consisting of just the score: f"This trajectory got a score of {score}." \""" ... @@ -207,94 +215,99 @@ def metric( max_full_evals: The maximum number of full evaluations to perform. max_metric_calls: The maximum number of metric calls to perform. reflection_minibatch_size: The number of examples to use for reflection in a single GEPA step. Default is 3. - candidate_selection_strategy: The strategy to use for candidate selection. Default is "pareto", - which stochastically selects candidates from the Pareto frontier of all validation scores. + candidate_selection_strategy: The strategy to use for candidate selection. Default is "pareto", + which stochastically selects candidates from the Pareto frontier of all validation scores. Options: "pareto", "current_best". - reflection_lm: The language model to use for reflection. Required parameter. GEPA benefits from - a strong reflection model. Consider using `dspy.LM(model='gpt-5', temperature=1.0, max_tokens=32000)` + reflection_lm: The language model to use for reflection. Required parameter. 
GEPA benefits from + a strong reflection model. Consider using `dspy.LM(model='gpt-5', temperature=1.0, max_tokens=32000)` for optimal performance. skip_perfect_score: Whether to skip examples with perfect scores during reflection. Default is True. instruction_proposer: Optional custom instruction proposer implementing GEPA's ProposalFn protocol. - **Default: None (recommended for most users)** - Uses GEPA's proven instruction proposer from - the [GEPA library](https://github.com/gepa-ai/gepa), which implements the - [`ProposalFn`](https://github.com/gepa-ai/gepa/blob/main/src/gepa/core/adapter.py). This default - proposer is highly capable and was validated across diverse experiments reported in the GEPA + **Default: None (recommended for most users)** - Uses GEPA's proven instruction proposer from + the [GEPA library](https://github.com/gepa-ai/gepa), which implements the + [`ProposalFn`](https://github.com/gepa-ai/gepa/blob/main/src/gepa/core/adapter.py). This default + proposer is highly capable and was validated across diverse experiments reported in the GEPA paper and tutorials. - See documentation on custom instruction proposers + See documentation on custom instruction proposers [here](https://dspy.ai/api/optimizers/GEPA/GEPA_Advanced/#custom-instruction-proposers). - + **Advanced Feature**: Only needed for specialized scenarios: - **Multi-modal handling**: Processing dspy.Image inputs alongside textual information - - **Nuanced control over constraints**: Fine-grained control over instruction length, format, + - **Nuanced control over constraints**: Fine-grained control over instruction length, format, and structural requirements beyond standard feedback mechanisms - - **Domain-specific knowledge injection**: Specialized terminology or context that cannot be + - **Domain-specific knowledge injection**: Specialized terminology or context that cannot be provided through feedback_func alone - - **Provider-specific prompting**: Optimizations for specific LLM providers (OpenAI, Anthropic) + - **Provider-specific prompting**: Optimizations for specific LLM providers (OpenAI, Anthropic) with unique formatting preferences - - **Coupled component updates**: Coordinated updates of multiple components together rather + - **Coupled component updates**: Coordinated updates of multiple components together rather than independent optimization - **External knowledge integration**: Runtime access to databases, APIs, or knowledge bases - - The default proposer handles the vast majority of use cases effectively. Use - MultiModalInstructionProposer() from dspy.teleprompt.gepa.instruction_proposal for visual + + The default proposer handles the vast majority of use cases effectively. Use + MultiModalInstructionProposer() from dspy.teleprompt.gepa.instruction_proposal for visual content or implement custom ProposalFn for highly specialized requirements. - - Note: When both instruction_proposer and reflection_lm are set, the instruction_proposer is called - in the reflection_lm context. However, reflection_lm is optional when using a custom instruction_proposer. + + Note: When both instruction_proposer and reflection_lm are set, the instruction_proposer is called + in the reflection_lm context. However, reflection_lm is optional when using a custom instruction_proposer. Custom instruction proposers can invoke their own LLMs if needed. component_selector: Custom component selector implementing the ReflectionComponentSelector protocol, - or a string specifying a built-in selector strategy. 
Controls which components (predictors) are selected - for optimization at each iteration. Defaults to 'round_robin' strategy which cycles through components - one at a time. Available string options: 'round_robin' (cycles through components sequentially), - 'all' (selects all components for simultaneous optimization). Custom selectors can implement strategies - using LLM-driven selection logic based on optimization state and trajectories. - See [gepa component selectors](https://github.com/gepa-ai/gepa/blob/main/src/gepa/strategies/component_selector.py) + or a string specifying a built-in selector strategy. Controls which components (predictors) are selected + for optimization at each iteration. Defaults to 'round_robin' strategy which cycles through components + one at a time. Available string options: 'round_robin' (cycles through components sequentially), + 'all' (selects all components for simultaneous optimization). Custom selectors can implement strategies + using LLM-driven selection logic based on optimization state and trajectories. + See [gepa component selectors](https://github.com/gepa-ai/gepa/blob/main/src/gepa/strategies/component_selector.py) for available built-in selectors and the ReflectionComponentSelector protocol for implementing custom selectors. add_format_failure_as_feedback: Whether to add format failures as feedback. Default is False. use_merge: Whether to use merge-based optimization. Default is True. max_merge_invocations: The maximum number of merge invocations to perform. Default is 5. num_threads: The number of threads to use for evaluation with `Evaluate`. Optional. failure_score: The score to assign to failed examples. Default is 0.0. - perfect_score: The maximum score achievable by the metric. Default is 1.0. Used by GEPA + perfect_score: The maximum score achievable by the metric. Default is 1.0. Used by GEPA to determine if all examples in a minibatch are perfect. - log_dir: The directory to save the logs. GEPA saves elaborate logs, along with all candidate - programs, in this directory. Running GEPA with the same `log_dir` will resume the run + log_dir: The directory to save the logs. GEPA saves elaborate logs, along with all candidate + programs, in this directory. Running GEPA with the same `log_dir` will resume the run from the last checkpoint. - track_stats: Whether to return detailed results and all proposed programs in the `detailed_results` + track_stats: Whether to return detailed results and all proposed programs in the `detailed_results` attribute of the optimized program. Default is False. use_wandb: Whether to use wandb for logging. Default is False. - wandb_api_key: The API key to use for wandb. If not provided, wandb will use the API key + wandb_api_key: The API key to use for wandb. If not provided, wandb will use the API key from the environment variable `WANDB_API_KEY`. wandb_init_kwargs: Additional keyword arguments to pass to `wandb.init`. - track_best_outputs: Whether to track the best outputs on the validation set. track_stats must - be True if track_best_outputs is True. The optimized program's `detailed_results.best_outputs_valset` + track_best_outputs: Whether to track the best outputs on the validation set. track_stats must + be True if track_best_outputs is True. The optimized program's `detailed_results.best_outputs_valset` will contain the best outputs for each task in the validation set. 
- warn_on_score_mismatch: GEPA (currently) expects the metric to return the same module-level score when - called with and without the pred_name. This flag (defaults to True) determines whether a warning is + warn_on_score_mismatch: GEPA (currently) expects the metric to return the same module-level score when + called with and without the pred_name. This flag (defaults to True) determines whether a warning is raised if a mismatch in module-level and predictor-level score is detected. + optimize_react_components: Whether to optimize ReAct module components including react + instructions, extract instructions, tool descriptions, and tool argument descriptions. + When enabled, GEPA jointly optimizes all four components of ReAct modules. See the + [ReAct Component Optimization guide](https://dspy.ai/api/optimizers/GEPA/GEPA_Advanced/#react-component-optimization) + for details on when to use this feature and how it works. Default is False. seed: The random seed to use for reproducibility. Default is 0. gepa_kwargs: (Optional) provide additional kwargs to be passed to [gepa.optimize](https://github.com/gepa-ai/gepa/blob/main/src/gepa/api.py) method - + Note: Budget Configuration: Exactly one of `auto`, `max_full_evals`, or `max_metric_calls` must be provided. The `auto` parameter provides preset configurations: "light" for quick experimentation, "medium" for balanced optimization, and "heavy" for thorough optimization. - + Reflection Configuration: The `reflection_lm` parameter is required and should be a strong language model. GEPA performs best with models like `dspy.LM(model='gpt-5', temperature=1.0, max_tokens=32000)`. The reflection process analyzes failed examples to generate feedback for program improvement. - + Merge Configuration: GEPA can merge successful program variants using `use_merge=True`. The `max_merge_invocations` parameter controls how many merge attempts are made during optimization. - - Evaluation Configuration: Use `num_threads` to parallelize evaluation. The `failure_score` and + + Evaluation Configuration: Use `num_threads` to parallelize evaluation. The `failure_score` and `perfect_score` parameters help GEPA understand your metric's range and optimize accordingly. - + Logging Configuration: Set `log_dir` to save detailed logs and enable checkpoint resuming. Use `track_stats=True` to access detailed optimization results via the `detailed_results` attribute. Enable `use_wandb=True` for experiment tracking and visualization. - + Reproducibility: Set `seed` to ensure consistent results across runs with the same configuration. 
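+
+        Example: A minimal sketch of enabling ReAct component optimization
+        (illustrative; assumes `metric`, `program`, `trainset`, and `valset`
+        are defined as described above):
+
+            gepa = dspy.GEPA(
+                metric=metric,
+                reflection_lm=dspy.LM(model="gpt-5", temperature=1.0, max_tokens=32000),
+                auto="light",
+                optimize_react_components=True,
+            )
+            optimized = gepa.compile(program, trainset=trainset, valset=valset)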
""" def __init__( @@ -328,6 +341,7 @@ def __init__( wandb_init_kwargs: dict[str, Any] | None = None, track_best_outputs: bool = False, warn_on_score_mismatch: bool = True, + optimize_react_components: bool = False, use_mlflow: bool = False, # Reproducibility seed: int | None = 0, @@ -390,6 +404,7 @@ def __init__( self.wandb_api_key = wandb_api_key self.wandb_init_kwargs = wandb_init_kwargs self.warn_on_score_mismatch = warn_on_score_mismatch + self.optimize_react_components = optimize_react_components self.use_mlflow = use_mlflow if track_best_outputs: @@ -518,11 +533,65 @@ def feedback_fn( rng=rng, reflection_lm=self.reflection_lm, custom_instruction_proposer=self.custom_instruction_proposer, - warn_on_score_mismatch=self.warn_on_score_mismatch + warn_on_score_mismatch=self.warn_on_score_mismatch, + optimize_react_components=self.optimize_react_components, ) # Instantiate GEPA with the simpler adapter-based API base_program = {name: pred.signature.instructions for name, pred in student.named_predictors()} + + # Always traverse to detect ReAct modules + for module_path, module in student.named_sub_modules(): + # Only process ReAct modules + if not isinstance(module, ReAct): + continue + + if self.optimize_react_components: + normalized_path = module_path.removeprefix("self.") if module_path != "self" else "" + + # Get first predictor name as module identifier + for pred_name, _ in module.named_predictors(): + comp_name = pred_name if not normalized_path else f"{normalized_path}.{pred_name}" + # Use full normalized path to avoid collapsing nested modules + # e.g., "multi_agent.coordinator" not "multi_agent" + module_key = f"{REACT_MODULE_PREFIX}:{normalized_path}" if normalized_path else REACT_MODULE_PREFIX + + # Build JSON config with tool args for reflection + config = { + "react": module.react.signature.instructions, + "extract": module.extract.predict.signature.instructions, + "tools": { + tool_name: { + "desc": tool.desc, + "args": tool.args, + "arg_desc": tool.arg_desc or {} + } + for tool_name, tool in module.tools.items() + if tool_name != "finish" + } + } + + # Replace predictor keys with module key and extract key to prevent duplicates + base_program.pop(comp_name, None) + extract_key = f"{normalized_path}.extract.predict" if normalized_path else "extract.predict" + base_program.pop(extract_key, None) + base_program[module_key] = json.dumps(config, indent=2) + break + else: + logger.warning( + f"Detected ReAct module at '{module_path}'. Consider using " + "`optimize_react_components=True` to jointly optimize react instructions, " + "extract instructions, tool descriptions, and tool argument descriptions." 
+ ) + + # Log base_program keys for debugging + logger.info(f"Initialized base_program with {len(base_program)} components:") + for key in sorted(base_program.keys()): + if key.startswith(REACT_MODULE_PREFIX): + logger.info(f" {key}: ") + else: + logger.info(f" {key}: ") + gepa_result: GEPAResult = optimize( seed_candidate=base_program, trainset=trainset, diff --git a/dspy/teleprompt/gepa/gepa_utils.py b/dspy/teleprompt/gepa/gepa_utils.py index 844afe8b00..a1989606b7 100644 --- a/dspy/teleprompt/gepa/gepa_utils.py +++ b/dspy/teleprompt/gepa/gepa_utils.py @@ -1,3 +1,4 @@ +import json import logging import random from typing import Any, Callable, Protocol, TypedDict @@ -10,11 +11,17 @@ from dspy.adapters.types import History from dspy.adapters.types.base_type import Type from dspy.evaluate import Evaluate +from dspy.predict.react import ReAct from dspy.primitives import Example, Prediction from dspy.teleprompt.bootstrap_trace import TraceData logger = logging.getLogger(__name__) + +# Constants for ReAct module optimization +REACT_MODULE_PREFIX = "react_module" + + class LoggerAdapter: def __init__(self, logger: logging.Logger): self.logger = logger @@ -22,6 +29,7 @@ def __init__(self, logger: logging.Logger): def log(self, x: str): self.logger.info(x) + DSPyTrace = list[tuple[Any, dict[str, Any], Prediction]] @@ -31,15 +39,17 @@ class ReflectiveExample(TypedDict): Each example contains the predictor inputs, generated outputs, and feedback from evaluation. """ - Inputs: dict[str, Any] # Predictor inputs (may include str, dspy.Image, etc.) - Generated_Outputs: dict[str, Any] | str # Success: dict with output fields, Failure: error message string - Feedback: str # Always a string - from metric function or parsing error message + + Inputs: dict[str, Any] # Predictor inputs (may include str, dspy.Image, etc.) + Generated_Outputs: dict[str, Any] | str # Success: dict with output fields, Failure: error message string + Feedback: str # Always a string - from metric function or parsing error message class ScoreWithFeedback(Prediction): score: float feedback: str + class PredictorFeedbackFn(Protocol): def __call__( predictor_output: dict[str, Any], @@ -64,6 +74,7 @@ def __call__( """ ... + class DspyAdapter(GEPAAdapter[Example, TraceData, Prediction]): def __init__( self, @@ -76,7 +87,8 @@ def __init__( rng: random.Random | None = None, reflection_lm=None, custom_instruction_proposer: "ProposalFn | None" = None, - warn_on_score_mismatch: bool = True + warn_on_score_mismatch: bool = True, + optimize_react_components: bool = False, ): self.student = student_module self.metric_fn = metric_fn @@ -88,42 +100,187 @@ def __init__( self.reflection_lm = reflection_lm self.custom_instruction_proposer = custom_instruction_proposer self.warn_on_score_mismatch = warn_on_score_mismatch - - if self.custom_instruction_proposer is not None: - # We are only overriding the propose_new_texts method when a custom - # instruction proposer is provided. Otherwise, we use the GEPA - # default propose_new_texts. 
- - def custom_propose_new_texts( + self.optimize_react_components = optimize_react_components + + def build_propose_new_texts(): + instruction_proposer = None + + # Init instruction proposer (custom or default) + if self.custom_instruction_proposer is not None: + instruction_proposer = self.custom_instruction_proposer + else: + from gepa.strategies.instruction_proposal import InstructionProposalSignature + + def default_instruction_proposer( + candidate: dict[str, str], + reflective_dataset: dict[str, list[dict[str, Any]]], + components_to_update: list[str], + ) -> dict[str, str]: + lm = self.reflection_lm if self.reflection_lm is not None else dspy.settings.lm + updated_components: dict[str, str] = {} + for name in components_to_update: + base_instruction = candidate[name] + dataset_with_feedback = reflective_dataset[name] + updated_components[name] = InstructionProposalSignature.run( + lm=(lambda x: lm(x)[0]), + input_dict={ + "current_instruction_doc": base_instruction, + "dataset_with_feedback": dataset_with_feedback, + }, + )["new_instruction"] + return updated_components + + instruction_proposer = default_instruction_proposer + + # Init ReAct module proposer if tool optimization is enabled + react_module_proposer = None + if self.optimize_react_components: + from .instruction_proposal import ReActModuleProposer + + react_module_proposer = ReActModuleProposer() + + def propose_component_texts( candidate: dict[str, str], reflective_dataset: dict[str, list[dict[str, Any]]], - components_to_update: list[str] + components_to_update: list[str], ) -> dict[str, str]: + # If custom proposer provided, override everything with custom proposer + if self.custom_instruction_proposer: + if self.reflection_lm is not None: + with dspy.context(lm=self.reflection_lm): + return instruction_proposer( + candidate=candidate, + reflective_dataset=reflective_dataset, + components_to_update=components_to_update, + ) + else: + return instruction_proposer( + candidate=candidate, + reflective_dataset=reflective_dataset, + components_to_update=components_to_update, + ) + + # Otherwise, route to appropriate proposers + # Separate react_module components from regular instruction components + react_module_components = [c for c in components_to_update if c.startswith(REACT_MODULE_PREFIX)] + instruction_components = [c for c in components_to_update if not c.startswith(REACT_MODULE_PREFIX)] + + results: dict[str, str] = {} + + # Handle regular instruction components + logger.debug(f"Routing {len(instruction_components)} instruction components to instruction_proposer") if self.reflection_lm is not None: with dspy.context(lm=self.reflection_lm): - return self.custom_instruction_proposer( + results.update( + instruction_proposer( + candidate=candidate, + reflective_dataset=reflective_dataset, + components_to_update=instruction_components, + ) + ) + else: + results.update( + instruction_proposer( candidate=candidate, reflective_dataset=reflective_dataset, - components_to_update=components_to_update + components_to_update=instruction_components, ) - else: - return self.custom_instruction_proposer( - candidate=candidate, - reflective_dataset=reflective_dataset, - components_to_update=components_to_update ) - self.propose_new_texts = custom_propose_new_texts + # Handle ReAct module components + if react_module_components: + logger.debug(f"Routing {len(react_module_components)} react_module components to react_module_proposer") + if self.reflection_lm is not None: + with dspy.context(lm=self.reflection_lm): + results.update( 
+ react_module_proposer( + candidate=candidate, + reflective_dataset=reflective_dataset, + components_to_update=react_module_components, + ) + ) + else: + results.update( + react_module_proposer( + candidate=candidate, + reflective_dataset=reflective_dataset, + components_to_update=react_module_components, + ) + ) + + return results + + return propose_component_texts + + self.propose_new_texts = build_propose_new_texts() # Cache predictor names/signatures self.named_predictors = list(self.student.named_predictors()) - def build_program(self, candidate: dict[str, str]): new_prog = self.student.deepcopy() + + # Apply regular predictor instructions for name, pred in new_prog.named_predictors(): if name in candidate: pred.signature = pred.signature.with_instructions(candidate[name]) + + # Apply ReAct module updates (JSON configs for ReAct modules: react, extract, tools) + if self.optimize_react_components: + + for module_path, module in new_prog.named_sub_modules(): + # Only process ReAct modules + if not isinstance(module, ReAct): + continue + + # Build module key + normalized_path = module_path.removeprefix("self.") if module_path != "self" else "" + module_key = REACT_MODULE_PREFIX if normalized_path == "" else f"{REACT_MODULE_PREFIX}:{normalized_path}" + + # Check if this module was optimized + if module_key not in candidate: + continue + + # Deserialize JSON containing optimized module configuration + try: + module_config = json.loads(candidate[module_key]) + logger.debug(f"Applying optimized module config to {module_key}") + + # Apply react instruction + if "react" in module_config: + module.react.signature = module.react.signature.with_instructions(module_config["react"]) + logger.debug(" Updated react instruction") + + # Apply extract instruction + if "extract" in module_config: + module.extract.predict.signature = module.extract.predict.signature.with_instructions(module_config["extract"]) + logger.debug(" Updated extract instruction") + + # Apply tool descriptions + if "tools" in module_config: + for tool_name, tool_config in module_config["tools"].items(): + tool = module.tools[tool_name] + + # Update tool description + if tool_config.get("desc"): + tool.desc = tool_config["desc"] + logger.debug(f" Updated tool '{tool_name}' description") + + # Update tool arg descriptions + arg_desc = tool_config.get("arg_desc") + if arg_desc: + tool.arg_desc = tool.arg_desc or {} + tool.arg_desc.update(arg_desc) + # Propagate to tool.args + for arg_name, description in arg_desc.items(): + if arg_name in tool.args: + tool.args[arg_name]["description"] = description + logger.debug(f" Updated tool '{tool_name}' arg descriptions: {list(arg_desc.keys())}") + + except json.JSONDecodeError as e: + logger.error(f"Failed to parse JSON config for {module_key}: {e}") + raise + return new_prog def evaluate(self, batch, candidate, capture_traces=False): @@ -165,7 +322,7 @@ def evaluate(self, batch, candidate, capture_traces=False): return_all_scores=True, failure_score=self.failure_score, provide_traceback=True, - max_errors=len(batch) * 100 + max_errors=len(batch) * 100, ) res = evaluator(program) outputs = [r[1] for r in res.results] @@ -173,18 +330,54 @@ def evaluate(self, batch, candidate, capture_traces=False): scores = [s["score"] if hasattr(s, "score") else s for s in scores] return EvaluationBatch(outputs=outputs, scores=scores, trajectories=None) - def make_reflective_dataset(self, candidate, eval_batch, components_to_update) -> dict[str, list[ReflectiveExample]]: + def make_reflective_dataset( + 
self, candidate, eval_batch, components_to_update + ) -> dict[str, list[ReflectiveExample]]: from dspy.teleprompt.bootstrap_trace import FailedPrediction + program = self.build_program(candidate) ret_d: dict[str, list[ReflectiveExample]] = {} + + # Debug: Log what components we're trying to update + logger.info(f"make_reflective_dataset called with components_to_update: {components_to_update}") + for pred_name in components_to_update: - module = None - for name, m in program.named_predictors(): - if name == pred_name: - module = m - break - assert module is not None + logger.info(f"Processing component: {pred_name}") + + # Handle ReAct module components - use extract predictor for final outputs + if pred_name.startswith(REACT_MODULE_PREFIX): + # Extract the target path from the key + target_path = pred_name.removeprefix(f"{REACT_MODULE_PREFIX}:") if ":" in pred_name else "" + + # Find the ReAct module by traversing program structure (same as regular predictors) + react_module = None + for module_path, m in program.named_sub_modules(): + if not isinstance(m, ReAct): + continue + + # Normalize path (same pattern as build_program) + normalized_path = module_path.removeprefix("self.") if module_path != "self" else "" + if normalized_path == target_path: + react_module = m + break + + if react_module is None: + logger.warning(f"ReAct module not found for key: {pred_name}") + continue + + module = react_module.extract.predict + logger.debug(f" ReAct module detected: using {target_path or 'top-level'}.extract for final outputs") + + # Regular predictor - find by name + else: + module = None + for name, m in program.named_predictors(): + if name == pred_name: + module = m + break + assert module is not None + logger.debug(f" Regular predictor: {pred_name}") items: list[ReflectiveExample] = [] for data in eval_batch.trajectories or []: @@ -195,22 +388,34 @@ def make_reflective_dataset(self, candidate, eval_batch, components_to_update) - if hasattr(module_score, "score"): module_score = module_score["score"] + logger.debug(f" Processing trace with {len(trace)} entries for example: {example}") trace_instances = [t for t in trace if t[0].signature.equals(module.signature)] + logger.debug(f" Found {len(trace_instances)} matching trace instances for signature: {module.signature}") if not self.add_format_failure_as_feedback: trace_instances = [t for t in trace_instances if not isinstance(t[2], FailedPrediction)] + logger.debug(f" After filtering FailedPrediction: {len(trace_instances)} instances") if len(trace_instances) == 0: + logger.debug(" Skipping example - no matching trace instances") continue - selected = None - for t in trace_instances: - if isinstance(t[2], FailedPrediction): - selected = t - break + # For ReAct modules, use LAST extract invocation (has trajectory + final outputs) + if pred_name.startswith(REACT_MODULE_PREFIX): + selected = trace_instances[-1] + logger.debug(f" Using LAST extract call ({len(trace_instances)} total) with trajectory + final outputs") + if "trajectory" in selected[1]: + traj_preview = str(selected[1]["trajectory"])[:100] + logger.debug(f" Trajectory preview: {traj_preview}...") + else: + selected = None + for t in trace_instances: + if isinstance(t[2], FailedPrediction): + selected = t + break - if selected is None: - if isinstance(prediction, FailedPrediction): - continue - selected = self.rng.choice(trace_instances) + if selected is None: + if isinstance(prediction, FailedPrediction): + continue + selected = self.rng.choice(trace_instances) inputs = selected[1] 
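+            # selected is one (predictor, inputs, prediction) trace entry (see DSPyTrace); index 2 below yields the Prediction.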
outputs = selected[2] @@ -262,7 +467,14 @@ def make_reflective_dataset(self, candidate, eval_batch, components_to_update) - d["Feedback"] = "Your output failed to parse. Follow this structure:\n" + structure_instruction # d['score'] = self.failure_score else: - feedback_fn = self.feedback_map[pred_name] + # Map react_module component keys to their react predictor names for feedback lookup + if pred_name.startswith(REACT_MODULE_PREFIX): + # "react_module" → "react", "react_module:salary_agent" → "salary_agent.react" + actual_pred_name = pred_name.split(":", 1)[1] + ".react" if ":" in pred_name else "react" + else: + actual_pred_name = pred_name + + feedback_fn = self.feedback_map[actual_pred_name] fb = feedback_fn( predictor_output=outputs, predictor_inputs=inputs, @@ -279,10 +491,23 @@ def make_reflective_dataset(self, candidate, eval_batch, components_to_update) - items.append(d) + # Log exact reflective example that reflection LM will see + if pred_name.startswith(REACT_MODULE_PREFIX) and len(items) == 1: + logger.info(f" First reflective example for {pred_name}:") + logger.info(f" Inputs: {list(d['Inputs'].keys())}") + if "trajectory" in d["Inputs"]: + traj = d["Inputs"]["trajectory"] + logger.info(f" Trajectory length: {len(traj)} chars") + logger.info(f" Trajectory sample:\n{traj[:300]}...") + logger.info(f" Outputs: {list(d['Generated Outputs'].keys()) if isinstance(d['Generated Outputs'], dict) else ''}") + logger.info(f" Feedback: {d['Feedback'][:100]}...") + if len(items) == 0: - # raise Exception(f"No valid predictions found for module {module.signature}.") + logger.warning(f" No valid reflective examples found for {pred_name}") continue + ret_d[pred_name] = items + logger.info(f" Created {len(items)} reflective examples for {pred_name}") if len(ret_d) == 0: raise Exception("No valid predictions found for any module.") diff --git a/dspy/teleprompt/gepa/instruction_proposal.py b/dspy/teleprompt/gepa/instruction_proposal.py index 23810b9a02..2b8ae1e590 100644 --- a/dspy/teleprompt/gepa/instruction_proposal.py +++ b/dspy/teleprompt/gepa/instruction_proposal.py @@ -1,3 +1,5 @@ +import json +import logging from typing import Any from gepa.core.adapter import ProposalFn @@ -6,6 +8,11 @@ from dspy.adapters.types.base_type import Type from dspy.teleprompt.gepa.gepa_utils import ReflectiveExample +logger = logging.getLogger(__name__) + +# Constants for ReAct module optimization +REACT_MODULE_PREFIX = "react_module" + class GenerateEnhancedMultimodalInstructionFromFeedback(dspy.Signature): """I provided an assistant with instructions to perform a task involving visual content, but the assistant's performance needs improvement based on the examples and feedback below. @@ -310,3 +317,243 @@ def __call__( updated_components[component_name] = new_instruction return updated_components + +class GenerateImprovedReActDescriptionsFromFeedback(dspy.Signature): + """Improve a ReAct agent based on execution examples and feedback. + + These components are progressively optimized - refine what needs improvement. + Analyze the trajectories to identify successful patterns and failure causes. + Generate improved texts to help the agent succeed on similar tasks. + Place improved texts at their appropriate level of abstraction and/or specificity. 
+ """ + + current_react_instruction = dspy.InputField( + desc="Current ReAct module instruction guiding the ReAct agent's reasoning and tool selection" + ) + current_extract_instruction = dspy.InputField( + desc="Current Extract module instruction for extracting final answers from trajectories" + ) + current_tools = dspy.InputField( + annotation=list[dspy.Tool], + desc="Available tools with their complete schemas" + ) + examples_with_feedback = dspy.InputField( + desc="Execution examples with feedback showing successes and failures" + ) + + improved_react_instruction: str | None = dspy.OutputField( + desc="ReAct instruction for reasoning and tool selection", + default=None + ) + improved_extract_instruction: str | None = dspy.OutputField( + desc="Extract instruction for answer extraction", + default=None + ) + + + + + +class ReActModuleProposer(ProposalFn): + """Proposer for optimizing ReAct module configurations. + + Jointly optimizes three components of a ReAct module: the react instruction that guides + reasoning and tool selection, the extract instruction for answer extraction from trajectories, + and tool descriptions with their parameters. Uses dynamic signature generation to create + output fields for each tool and parameter, enabling the reflection LM to optimize all parts + cohesively based on execution feedback. + + This joint optimization approach allows the LM to see how instructions and tool descriptions + work together, leading to more coherent improvements than optimizing each component separately. + """ + + def __init__(self): + """Initialize the ReAct module proposer.""" + pass + + def __call__( + self, + candidate: dict[str, str], + reflective_dataset: dict[str, list[ReflectiveExample]], + components_to_update: list[str], + ) -> dict[str, str]: + """Optimize ReAct module components. + + Args: + candidate: Current component name -> JSON config mapping + reflective_dataset: Component name -> list of reflective examples + components_to_update: List of react_module component names to update + + Returns: + dict: Mapping of component names to improved JSON configs + """ + + logger.info("\n=== ReActModuleProposer Called ===") + logger.info(f"components_to_update: {components_to_update}") + logger.info(f"candidate keys: {list(candidate.keys())}") + logger.info(f"reflective_dataset keys: {list(reflective_dataset.keys())}") + + updated_components = {} + + for module_key in components_to_update: + # Only handle react_module components + if not module_key.startswith(REACT_MODULE_PREFIX): + logger.debug(f"Skipping non-react_module component: {module_key}") + continue + + if module_key not in candidate or module_key not in reflective_dataset: + logger.warning(f"Skipping {module_key}: not in candidate={module_key not in candidate}, not in reflective_dataset={module_key not in reflective_dataset}") + continue + + logger.info(f"\nProcessing react_module: {module_key}") + + # Deserialize react module config + try: + current_react_config = json.loads(candidate[module_key]) + logger.debug(f"Deserialized config keys: {list(current_react_config.keys())}") + except json.JSONDecodeError as e: + logger.error(f"Failed to deserialize config for {module_key}: {e}") + continue + + # Reconstruct Tool objects from JSON metadata so the adapter can format them for the reflection LM. + # Tool.func cannot be serialized in JSON, so we use a placeholder (never executed). 
+ current_tools_dict = current_react_config.get("tools", {}) + logger.info(f"Found {len(current_tools_dict)} tools: {list(current_tools_dict.keys())}") + tools_list = [] + for tool_name, tool_info in current_tools_dict.items(): + tool = dspy.Tool( + func=lambda: None, # Placeholder - Tool requires Callable, but only schema is used + name=tool_name, + desc=tool_info.get("desc", ""), + ) + tool.args = tool_info.get("args", {}) + tool.arg_desc = tool_info.get("arg_desc", {}) + tools_list.append(tool) + + # Build dynamic signature by extending base signature + signature = GenerateImprovedReActDescriptionsFromFeedback + + logger.debug(f"Building dynamic signature with {len(tools_list)} tools...") + + # Add dynamic tool description and arg descriptions output fields + for tool in tools_list: + tool_name = tool.name + tool_info = current_tools_dict[tool_name] + + signature = signature.append( + f"improved_tool_{tool_name}_desc", + dspy.OutputField( + desc=f"Purpose of tool '{tool_name}'", + default=None + ) + ) + + if tool_info.get("args"): + for arg_name in tool_info["args"].keys(): + signature = signature.append( + f"improved_tool_{tool_name}_arg_{arg_name}_desc", + dspy.OutputField( + desc=f"Usage of parameter '{arg_name}'", + default=None + ) + ) + + # Format examples + formatted_examples = self._format_examples(reflective_dataset[module_key]) + logger.info(f"Formatted {len(reflective_dataset[module_key])} reflective examples") + logger.debug(f"Examples preview: {formatted_examples[:200]}...") + + logger.info("Calling reflection LM with dynamic signature...") + propose_descriptions = dspy.Predict(signature) + result = propose_descriptions( + current_react_instruction=current_react_config.get("react", ""), + current_extract_instruction=current_react_config.get("extract", ""), + current_tools=tools_list, # List of Tool objects for adapter formatting + examples_with_feedback=formatted_examples, + ) + + # Build improved config from reflection LM suggestions + # Reflection LM returns None for components it doesn't want to change, or text for improvements + logger.info("Building improved config from reflection LM response...") + improved_react_config = {} + + # Update react instruction if reflection LM suggested improvement + if result.improved_react_instruction is not None: + improved_react_config["react"] = result.improved_react_instruction + logger.debug(f"React instruction: {len(result.improved_react_instruction)} chars") + else: + logger.debug("React instruction: reflection LM suggests keeping original") + + # Update extract instruction if reflection LM suggested improvement + if result.improved_extract_instruction is not None: + improved_react_config["extract"] = result.improved_extract_instruction + logger.debug(f"Extract instruction: {len(result.improved_extract_instruction)} chars") + else: + logger.debug("Extract instruction: reflection LM suggests keeping original)") + + # Update tool descriptions if reflection LM suggested improvements + improved_react_config["tools"] = {} + for tool_name, tool_info in current_tools_dict.items(): + # Check if reflection LM suggested improving this tool's description + improved_desc = getattr(result, f"improved_tool_{tool_name}_desc", None) + + # Skip if reflection LM suggests keeping original + if improved_desc is None: + logger.debug(f" Tool '{tool_name}': reflection LM suggests keeping original") + continue + + improved_tool_info = { + "desc": improved_desc, + "arg_desc": {} + } + + # Update parameter descriptions if reflection LM suggested 
improvements + if tool_info.get("args"): + for arg_name in tool_info["args"].keys(): + field_name = f"improved_tool_{tool_name}_arg_{arg_name}_desc" + arg_desc = getattr(result, field_name, None) + if arg_desc is not None: # Reflection LM suggested improvement + improved_tool_info["arg_desc"][arg_name] = arg_desc + + improved_react_config["tools"][tool_name] = improved_tool_info + logger.debug(f" Tool '{tool_name}': desc={len(improved_desc)} chars, params={len(improved_tool_info['arg_desc'])}") + + # Serialize back to JSON + updated_components[module_key] = json.dumps(improved_react_config, indent=2) + logger.info(f"Successfully optimized {module_key}") + logger.debug(f"Serialized config length: {len(updated_components[module_key])} chars") + + logger.info(f"\nReActModuleProposer returning {len(updated_components)} components: {list(updated_components.keys())}") + return updated_components + + def _format_examples(self, reflective_dataset: list[ReflectiveExample]) -> str: + """Format reflective examples using GEPA's markdown structure.""" + + def render_value(value, level=3): + if isinstance(value, dict): + s = "" + for key, val in value.items(): + s += f"{'#' * level} {key}\n" + s += render_value(val, min(level + 1, 6)) + if not value: + s += "\n" + return s + if isinstance(value, (list, tuple)): + s = "" + for index, item in enumerate(value): + s += f"{'#' * level} Item {index + 1}\n" + s += render_value(item, min(level + 1, 6)) + if not value: + s += "\n" + return s + return f"{str(value).strip()}\n\n" + + def convert_sample_to_markdown(sample, example_num): + s = f"# Example {example_num}\n" + for key, val in sample.items(): + s += f"## {key}\n" + s += render_value(val, level=3) + return s + + formatted_parts = [convert_sample_to_markdown(example, i + 1) for i, example in enumerate(reflective_dataset)] + return "\n\n".join(formatted_parts) diff --git a/tests/teleprompt/test_gepa_react_optimization.py b/tests/teleprompt/test_gepa_react_optimization.py new file mode 100644 index 0000000000..68e1512b8b --- /dev/null +++ b/tests/teleprompt/test_gepa_react_optimization.py @@ -0,0 +1,849 @@ +"""Tests for GEPA's unified ReAct module optimization with full path preservation. + +Tests the critical bug fix where ReAct module paths must be preserved in full +(e.g., "multi_agent.orchestrator") instead of being truncated (e.g., "multi_agent"). +This ensures correct module identification in multi-agent systems. + +What we test: +1. Detection: GEPA correctly identifies ReAct modules with full paths +2. Reconstruction: build_program applies optimizations using full paths +3. 
Reflective dataset: make_reflective_dataset captures complete trajectories + +Bug fixed: Path truncation in gepa.py and gepa_utils.py caused: +- Wrong module detection in nested structures +- Incorrect trajectory capture in multi-agent systems +- Optimization applied to wrong modules +""" + +import json + +import dspy +from dspy import Example +from dspy.utils.dummies import DummyLM + + +def setup_capture_for_base_program(monkeypatch): + """Capture base_program passed to gepa.optimize.""" + captured_base_program = {} + + from gepa import optimize as original_optimize + + def capture_optimize(seed_candidate, **kwargs): + captured_base_program.update(seed_candidate) + return original_optimize(seed_candidate=seed_candidate, **kwargs) + + import gepa + monkeypatch.setattr(gepa, "optimize", capture_optimize) + + return captured_base_program + + +def simple_metric_for_detection(example, pred, trace=None, pred_name=None, pred_trace=None): + """Simple metric for GEPA detection tests.""" + return dspy.Prediction(score=0.5, feedback="ok") + + +def simple_metric_for_reconstruction(example, pred, trace=None): + """Simple metric for adapter reconstruction tests.""" + return 0.5 + + +def simple_feedback(*args, **kwargs): + """Generic feedback function for reflective dataset tests.""" + return {"score": 1.0, "feedback": "Good"} + + +def create_gepa_optimizer_for_detection(): + """Create GEPA optimizer with standard test configuration.""" + task_lm = DummyLM([{"answer": "test"}] * 10) + reflection_lm = DummyLM([{"improved_instruction": "optimized"}] * 10) + dspy.settings.configure(lm=task_lm) + + optimizer = dspy.GEPA( + metric=simple_metric_for_detection, + reflection_lm=reflection_lm, + max_metric_calls=2, + optimize_react_components=True, + ) + + trainset = [Example(question="test", answer="test").with_inputs("question")] + + return optimizer, trainset + + +def assert_react_module_detected(captured_base_program, module_path, expected_tools): + """Assert that a ReAct module was detected with all components.""" + from dspy.teleprompt.gepa.gepa_utils import REACT_MODULE_PREFIX + + module_key = REACT_MODULE_PREFIX if module_path == "" else f"{REACT_MODULE_PREFIX}:{module_path}" + + assert module_key in captured_base_program, f"Expected '{module_key}' to be detected" + + config = json.loads(captured_base_program[module_key]) + + assert "react" in config, f"{module_key} should have react instruction" + assert "extract" in config, f"{module_key} should have extract instruction" + assert "tools" in config, f"{module_key} should have tools" + + for tool_name, expected_desc in expected_tools.items(): + assert tool_name in config["tools"], f"{module_key} should have '{tool_name}' tool" + tool = config["tools"][tool_name] + assert "desc" in tool, f"{tool_name} should have desc" + assert tool["desc"] == expected_desc, f"{tool_name} desc should match" + assert "arg_desc" in tool, f"{tool_name} should have arg_desc" + + return config + + +def assert_regular_module_detected(captured_base_program, module_key): + """Assert that a non-ReAct module was detected.""" + assert module_key in captured_base_program, f"Expected '{module_key}' to be detected" + instruction = captured_base_program[module_key] + assert isinstance(instruction, str), f"{module_key} should be string instruction, not JSON" + return instruction + + +def assert_react_module_updated(react_module, expected_react_instruction, expected_extract_instruction, expected_tool_descriptions): + """Assert that a ReAct module was properly updated with optimized 
instructions. + + Args: + react_module: The ReAct module instance to check + expected_react_instruction: Expected react instruction text + expected_extract_instruction: Expected extract instruction text + expected_tool_descriptions: Dict of {tool_name: {"desc": desc, "arg_desc": {arg: desc}}} + """ + assert react_module.react.signature.instructions == expected_react_instruction, \ + f"React instruction mismatch: got {react_module.react.signature.instructions}" + + assert react_module.extract.predict.signature.instructions == expected_extract_instruction, \ + f"Extract instruction mismatch: got {react_module.extract.predict.signature.instructions}" + + for tool_name, tool_desc in expected_tool_descriptions.items(): + tool = react_module.tools[tool_name] + + if "desc" in tool_desc: + assert tool.desc == tool_desc["desc"], \ + f"Tool '{tool_name}' desc mismatch: got {tool.desc}" + + if "arg_desc" in tool_desc: + for arg_name, expected_arg_desc in tool_desc["arg_desc"].items(): + # Verify arg_desc propagated to tool.args (rendered in prompts) + assert arg_name in tool.args, \ + f"Tool '{tool_name}' arg_desc has '{arg_name}' but args schema doesn't" + assert tool.args[arg_name].get("description") == expected_arg_desc, \ + f"Tool '{tool_name}' args['{arg_name}']['description'] should match arg_desc (got {tool.args[arg_name].get('description')!r}, expected {expected_arg_desc!r})" + + +def assert_regular_module_updated(predictor, expected_instruction): + """Assert that a regular (non-ReAct) predictor was updated with optimized instruction.""" + assert predictor.signature.instructions == expected_instruction, \ + f"Instruction mismatch: expected '{expected_instruction}', got '{predictor.signature.instructions}'" + + +def mock_optimized_react_module(optimized_candidate, module_path, react_instruction, extract_instruction, tool_descriptions): + """Helper to mock an optimized ReAct module in the candidate dict. 
+ + Args: + optimized_candidate: The candidate dict to modify + module_path: Module path (e.g., "multi_agent.orchestrator" or "" for top-level) + react_instruction: New react instruction + extract_instruction: New extract instruction + tool_descriptions: Dict of {tool_name: {"desc": desc, "arg_desc": {arg: desc}}} + """ + from dspy.teleprompt.gepa.gepa_utils import REACT_MODULE_PREFIX + + module_key = REACT_MODULE_PREFIX if module_path == "" else f"{REACT_MODULE_PREFIX}:{module_path}" + config = json.loads(optimized_candidate[module_key]) + config["react"] = react_instruction + config["extract"] = extract_instruction + + for tool_name, tool_desc in tool_descriptions.items(): + if "desc" in tool_desc: + config["tools"][tool_name]["desc"] = tool_desc["desc"] + if "arg_desc" in tool_desc: + config["tools"][tool_name]["arg_desc"] = tool_desc["arg_desc"] + + optimized_candidate[module_key] = json.dumps(config) + + +def create_single_react_program(): + """Create a simple single ReAct module program.""" + def search_tool(query: str) -> str: + """Search for information.""" + return f"Results for: {query}" + + def calculate_tool(expr: str) -> str: + """Calculate math expression.""" + return "42" + + return dspy.ReAct( + "question -> answer", + tools=[ + dspy.Tool(search_tool, name="search", desc="Search the web"), + dspy.Tool(calculate_tool, name="calc", desc="Calculate math"), + ], + max_iters=3 + ) + + +def create_multi_react_workflow_program(): + """Create a mixed workflow program with 2 ReAct + 1 ChainOfThought.""" + class ResearchWorkflow(dspy.Module): + def __init__(self): + super().__init__() + + def search_papers(query: str) -> str: + return f"Papers: {query}" + + def analyze_data(data: str) -> str: + return f"Analysis: {data}" + + self.coordinator = dspy.ReAct( + "task -> plan", + tools=[dspy.Tool(search_papers, name="search", desc="Search tool")], + max_iters=2 + ) + + self.researcher = dspy.ReAct( + "plan -> findings", + tools=[dspy.Tool(analyze_data, name="analyze", desc="Analysis tool")], + max_iters=2 + ) + + self.summarizer = dspy.ChainOfThought("findings -> summary") + + def forward(self, question): + plan = self.coordinator(task=question) + findings = self.researcher(plan=plan.plan) + summary = self.summarizer(findings=findings.findings) + return dspy.Prediction(answer=summary.summary) + + class MixedWorkflowSystem(dspy.Module): + def __init__(self): + super().__init__() + self.workflow = ResearchWorkflow() + + def forward(self, question): + return self.workflow(question=question) + + return MixedWorkflowSystem() + + +def create_orchestrator_with_workers_program(): + """Create orchestrator with 2 worker ReAct modules as tools.""" + class OrchestratorWorkerSystem(dspy.Module): + def __init__(self): + super().__init__() + + def search_web(query: str) -> str: + return f"Search results: {query}" + + def analyze_data(data: str) -> str: + return f"Analysis: {data}" + + def research_topic(topic: str) -> str: + return f"Research: {topic}" + + self.analyst = dspy.ReAct( + "data -> analysis", + tools=[dspy.Tool(analyze_data, name="analyze", desc="Analyze data")], + max_iters=2 + ) + + self.researcher = dspy.ReAct( + "topic -> findings", + tools=[dspy.Tool(research_topic, name="research", desc="Research topic")], + max_iters=2 + ) + + def use_analyst(data: str) -> str: + result = self.analyst(data=data) + return str(result.analysis) if hasattr(result, "analysis") else str(result) + + def use_researcher(topic: str) -> str: + result = self.researcher(topic=topic) + return str(result.findings) 
if hasattr(result, "findings") else str(result) + + self.orchestrator = dspy.ReAct( + "question -> answer", + tools=[ + dspy.Tool(search_web, name="search", desc="Search tool"), + dspy.Tool(use_analyst, name="analyst", desc="Use analyst"), + dspy.Tool(use_researcher, name="researcher", desc="Use researcher"), + ], + max_iters=3 + ) + + def forward(self, question): + result = self.orchestrator(question=question) + return dspy.Prediction(answer=result.answer) + + class MultiAgentSystem(dspy.Module): + def __init__(self): + super().__init__() + self.multi_agent = OrchestratorWorkerSystem() + + def forward(self, question): + return self.multi_agent(question=question) + + return MultiAgentSystem() + + +def test_single_react_module_detection(monkeypatch): + """Test GEPA detects a single top-level ReAct module with all components. + + Tests: + - ReAct module detected as REACT_MODULE_PREFIX (no path suffix) + - react instruction captured + - extract instruction captured + - All tools with descriptions captured + """ + from dspy.teleprompt.gepa.gepa_utils import REACT_MODULE_PREFIX + + captured_base_program = setup_capture_for_base_program(monkeypatch) + program = create_single_react_program() + + optimizer, trainset = create_gepa_optimizer_for_detection() + + try: + optimizer.compile(program, trainset=trainset, valset=trainset) + except Exception: + pass + + module_key = REACT_MODULE_PREFIX + assert module_key in captured_base_program, f"Expected '{module_key}' to be detected" + + assert_react_module_detected( + captured_base_program=captured_base_program, + module_path="", + expected_tools={"search": "Search the web", "calc": "Calculate math"} + ) + + +def test_multi_react_workflow_detection(monkeypatch): + """Test GEPA detects multiple ReAct modules with FULL paths preserved. + + PRIMARY BUG FIX TEST: Validates paths are NOT truncated. 
+ + Tests: + - workflow.coordinator detected as "react_module:workflow.coordinator" (NOT "react_module:workflow") + - workflow.researcher detected as "react_module:workflow.researcher" (NOT "react_module:workflow") + - Both ReAct modules detected separately (not merged) + - Non-ReAct module (summarizer) detected correctly + + Before fix: Paths truncated at first dot → wrong module matching + After fix: Full paths preserved → correct module identification + """ + from dspy.teleprompt.gepa.gepa_utils import REACT_MODULE_PREFIX + + captured_base_program = setup_capture_for_base_program(monkeypatch) + program = create_multi_react_workflow_program() + + optimizer, trainset = create_gepa_optimizer_for_detection() + + try: + optimizer.compile(program, trainset=trainset, valset=trainset) + except Exception: + pass + + assert f"{REACT_MODULE_PREFIX}:workflow.coordinator" in captured_base_program + assert f"{REACT_MODULE_PREFIX}:workflow.researcher" in captured_base_program + + react_modules = [k for k in captured_base_program.keys() if k.startswith(REACT_MODULE_PREFIX)] + assert len(react_modules) == 2, f"Expected 2 ReAct modules, got {len(react_modules)}" + + assert_react_module_detected( + captured_base_program=captured_base_program, + module_path="workflow.coordinator", + expected_tools={"search": "Search tool"} + ) + assert_react_module_detected( + captured_base_program=captured_base_program, + module_path="workflow.researcher", + expected_tools={"analyze": "Analysis tool"} + ) + assert_regular_module_detected( + captured_base_program=captured_base_program, + module_key="workflow.summarizer.predict" + ) + + +def test_nested_react_orchestrator_worker_detection(monkeypatch): + """Test GEPA detects nested multi-agent system with 3 separate ReAct modules. + + Tests complex nested structure: + - Orchestrator: multi_agent.orchestrator (has analyst + researcher as tools) + - Analyst worker: multi_agent.analyst (wrapped as tool for orchestrator) + - Researcher worker: multi_agent.researcher (wrapped as tool for orchestrator) + + Validates: + - All 3 ReAct modules detected with FULL paths + - Each module has its own tools detected + - No path truncation causes module merging + """ + from dspy.teleprompt.gepa.gepa_utils import REACT_MODULE_PREFIX + + captured_base_program = setup_capture_for_base_program(monkeypatch) + program = create_orchestrator_with_workers_program() + + optimizer, trainset = create_gepa_optimizer_for_detection() + + try: + optimizer.compile(program, trainset=trainset, valset=trainset) + except Exception: + pass + + assert f"{REACT_MODULE_PREFIX}:multi_agent.orchestrator" in captured_base_program + assert f"{REACT_MODULE_PREFIX}:multi_agent.analyst" in captured_base_program + assert f"{REACT_MODULE_PREFIX}:multi_agent.researcher" in captured_base_program + + react_modules = [k for k in captured_base_program.keys() if k.startswith(REACT_MODULE_PREFIX)] + assert len(react_modules) == 3, f"Expected 3 ReAct modules, got {len(react_modules)}" + + assert_react_module_detected( + captured_base_program=captured_base_program, + module_path="multi_agent.orchestrator", + expected_tools={"search": "Search tool", "analyst": "Use analyst", "researcher": "Use researcher"} + ) + assert_react_module_detected( + captured_base_program=captured_base_program, + module_path="multi_agent.analyst", + expected_tools={"analyze": "Analyze data"} + ) + assert_react_module_detected( + captured_base_program=captured_base_program, + module_path="multi_agent.researcher", + expected_tools={"research": "Research 
topic"} + ) + + +def test_build_program_single_react(monkeypatch): + """Test build_program applies optimizations to single top-level ReAct module.""" + from dspy.teleprompt.gepa.gepa_utils import DspyAdapter + + captured_base_program = setup_capture_for_base_program(monkeypatch) + program = create_single_react_program() + + optimizer, trainset = create_gepa_optimizer_for_detection() + + try: + optimizer.compile(program, trainset=trainset, valset=trainset) + except Exception: + pass + + # Mock optimized candidate + optimized_candidate = dict(captured_base_program) + mock_optimized_react_module( + optimized_candidate=optimized_candidate, + module_path="", + react_instruction="OPTIMIZED: React instruction", + extract_instruction="OPTIMIZED: Extract instruction", + tool_descriptions={ + "search": { + "desc": "OPTIMIZED: Search description", + "arg_desc": {"query": "OPTIMIZED: Search query param"} + }, + "calc": { + "desc": "OPTIMIZED: Calc description", + "arg_desc": {"expr": "OPTIMIZED: Math expression param"} + } + } + ) + + # Build program + adapter = DspyAdapter( + student_module=program, + metric_fn=simple_metric_for_reconstruction, + feedback_map={}, + optimize_react_components=True + ) + rebuilt_program = adapter.build_program(optimized_candidate) + + # Assert updates applied + assert_react_module_updated( + react_module=rebuilt_program, + expected_react_instruction="OPTIMIZED: React instruction", + expected_extract_instruction="OPTIMIZED: Extract instruction", + expected_tool_descriptions={ + "search": { + "desc": "OPTIMIZED: Search description", + "arg_desc": {"query": "OPTIMIZED: Search query param"} + }, + "calc": { + "desc": "OPTIMIZED: Calc description", + "arg_desc": {"expr": "OPTIMIZED: Math expression param"} + } + } + ) + + # Verify original unchanged + assert program.react.signature.instructions != "OPTIMIZED: React instruction" + + +def test_build_program_multi_react_workflow(monkeypatch): + """Test build_program applies optimizations to mixed ReAct + non-ReAct workflow.""" + from dspy.teleprompt.gepa.gepa_utils import DspyAdapter + + captured_base_program = setup_capture_for_base_program(monkeypatch) + program = create_multi_react_workflow_program() + + optimizer, trainset = create_gepa_optimizer_for_detection() + + try: + optimizer.compile(program, trainset=trainset, valset=trainset) + except Exception: + pass + + # Mock optimized candidate + optimized_candidate = dict(captured_base_program) + + mock_optimized_react_module( + optimized_candidate=optimized_candidate, + module_path="workflow.coordinator", + react_instruction="OPTIMIZED: Coordinator react", + extract_instruction="OPTIMIZED: Coordinator extract", + tool_descriptions={ + "search": { + "desc": "OPTIMIZED: Search tool", + "arg_desc": {"query": "OPTIMIZED: Coordinator search query"} + } + } + ) + + mock_optimized_react_module( + optimized_candidate=optimized_candidate, + module_path="workflow.researcher", + react_instruction="OPTIMIZED: Researcher react", + extract_instruction="OPTIMIZED: Researcher extract", + tool_descriptions={ + "analyze": { + "desc": "OPTIMIZED: Analyze tool", + "arg_desc": {"data": "OPTIMIZED: Data to analyze"} + } + } + ) + + # Optimize summarizer (non-ReAct ChainOfThought) + optimized_candidate["workflow.summarizer.predict"] = "OPTIMIZED: Summarizer instruction" + + # Build program + adapter = DspyAdapter( + student_module=program, + metric_fn=simple_metric_for_reconstruction, + feedback_map={}, + optimize_react_components=True + ) + rebuilt_program = 
adapter.build_program(optimized_candidate) + + # Assert ReAct modules updated + assert_react_module_updated( + react_module=rebuilt_program.workflow.coordinator, + expected_react_instruction="OPTIMIZED: Coordinator react", + expected_extract_instruction="OPTIMIZED: Coordinator extract", + expected_tool_descriptions={ + "search": { + "desc": "OPTIMIZED: Search tool", + "arg_desc": {"query": "OPTIMIZED: Coordinator search query"} + } + } + ) + + assert_react_module_updated( + react_module=rebuilt_program.workflow.researcher, + expected_react_instruction="OPTIMIZED: Researcher react", + expected_extract_instruction="OPTIMIZED: Researcher extract", + expected_tool_descriptions={ + "analyze": { + "desc": "OPTIMIZED: Analyze tool", + "arg_desc": {"data": "OPTIMIZED: Data to analyze"} + } + } + ) + + # Assert non-ReAct module updated + assert_regular_module_updated( + predictor=rebuilt_program.workflow.summarizer.predict, + expected_instruction="OPTIMIZED: Summarizer instruction" + ) + + # Verify original unchanged + assert program.workflow.coordinator.react.signature.instructions != "OPTIMIZED: Coordinator react" + + +def test_build_program_orchestrator_with_workers(monkeypatch): + """Test build_program applies optimizations to orchestrator with worker ReAct modules.""" + from dspy.teleprompt.gepa.gepa_utils import DspyAdapter + + captured_base_program = setup_capture_for_base_program(monkeypatch) + program = create_orchestrator_with_workers_program() + + optimizer, trainset = create_gepa_optimizer_for_detection() + + try: + optimizer.compile(program, trainset=trainset, valset=trainset) + except Exception: + pass + + # Mock optimized candidate + optimized_candidate = dict(captured_base_program) + + mock_optimized_react_module( + optimized_candidate=optimized_candidate, + module_path="multi_agent.orchestrator", + react_instruction="OPTIMIZED: Orchestrator react", + extract_instruction="OPTIMIZED: Orchestrator extract", + tool_descriptions={ + "search": { + "desc": "OPTIMIZED: Search tool", + "arg_desc": {"query": "OPTIMIZED: Query param"} + } + } + ) + + mock_optimized_react_module( + optimized_candidate=optimized_candidate, + module_path="multi_agent.analyst", + react_instruction="OPTIMIZED: Analyst react", + extract_instruction="OPTIMIZED: Analyst extract", + tool_descriptions={"analyze": {"desc": "OPTIMIZED: Analyze tool"}} + ) + + mock_optimized_react_module( + optimized_candidate=optimized_candidate, + module_path="multi_agent.researcher", + react_instruction="OPTIMIZED: Researcher react", + extract_instruction="OPTIMIZED: Researcher extract", + tool_descriptions={"research": {"desc": "OPTIMIZED: Research tool"}} + ) + + # Build program + adapter = DspyAdapter( + student_module=program, + metric_fn=simple_metric_for_reconstruction, + feedback_map={}, + optimize_react_components=True + ) + rebuilt_program = adapter.build_program(optimized_candidate) + + # Assert all modules updated + assert_react_module_updated( + react_module=rebuilt_program.multi_agent.orchestrator, + expected_react_instruction="OPTIMIZED: Orchestrator react", + expected_extract_instruction="OPTIMIZED: Orchestrator extract", + expected_tool_descriptions={ + "search": { + "desc": "OPTIMIZED: Search tool", + "arg_desc": {"query": "OPTIMIZED: Query param"} + } + } + ) + + assert_react_module_updated( + react_module=rebuilt_program.multi_agent.analyst, + expected_react_instruction="OPTIMIZED: Analyst react", + expected_extract_instruction="OPTIMIZED: Analyst extract", + expected_tool_descriptions={"analyze": {"desc": 
"OPTIMIZED: Analyze tool"}} + ) + + assert_react_module_updated( + react_module=rebuilt_program.multi_agent.researcher, + expected_react_instruction="OPTIMIZED: Researcher react", + expected_extract_instruction="OPTIMIZED: Researcher extract", + expected_tool_descriptions={"research": {"desc": "OPTIMIZED: Research tool"}} + ) + + # Verify original unchanged + assert program.multi_agent.orchestrator.react.signature.instructions != "OPTIMIZED: Orchestrator react" + + +def assert_reflective_example_has_trajectory(actual_example, expected_iterations, answer): + """Assert reflective dataset captured complete trajectory without duplicates. + + Validates: + - All iterations present (thought_0, thought_1, ..., thought_N) + - No duplicate/extra iterations (no thought_(N+1)) + - Expected answer in outputs + - Works for any signature (question→answer, data→analysis, etc.) + + Catches bugs: + - Wrong predictor used (react vs extract.predict) → incomplete trajectory + - Path truncation → wrong module's trajectory captured + """ + # Should have the three main sections + assert "Inputs" in actual_example + assert "Generated Outputs" in actual_example + assert "Feedback" in actual_example + + # Validate Inputs + inputs = actual_example["Inputs"] + # Don't assume "question" - could be "data", "topic", etc depending on module signature + # Just check trajectory exists + assert "trajectory" in inputs + + # Validate trajectory has expected structure and values + trajectory_str = inputs["trajectory"] + num_iterations = len(expected_iterations) + + # Check all expected thoughts are present + for i, (thought, _tool_name, _tool_args) in enumerate(expected_iterations): + assert thought in trajectory_str, f"Trajectory should contain thought_{i}: {thought}" + assert f"thought_{i}" in trajectory_str + assert f"tool_name_{i}" in trajectory_str + assert f"observation_{i}" in trajectory_str + + # NO extra iterations (validates no duplicates) + assert f"thought_{num_iterations}" not in trajectory_str, \ + f"Should not have duplicate iteration {num_iterations}" + + # Validate Generated Outputs contain the expected answer + outputs = actual_example["Generated Outputs"] + # Answer could be in "answer", "analysis", "findings", etc depending on module signature + # Just check the expected answer value appears somewhere in the outputs + output_str = str(outputs) + assert answer in output_str, f"Expected answer '{answer}' not found in outputs: {outputs}" + + # Validate Feedback exists + assert isinstance(actual_example["Feedback"], str) + assert len(actual_example["Feedback"]) > 0 + + +def test_make_reflective_dataset_single_react(): + """Test reflective dataset captures complete trajectory for single ReAct module.""" + from dspy.teleprompt.gepa.gepa_utils import REACT_MODULE_PREFIX, DspyAdapter + + program = create_single_react_program() + + expected_iterations = [ + ("I should search", "search", {"query": "test"}), + ("Done", "finish", {}) + ] + expected_answer = "result" + + lm = DummyLM([ + {"next_thought": "I should search", "next_tool_name": "search", "next_tool_args": {"query": "test"}}, + {"next_thought": "Done", "next_tool_name": "finish", "next_tool_args": {}}, + {"reasoning": "Based on search", "answer": "result"}, + ] * 10) + dspy.settings.configure(lm=lm) + + adapter = DspyAdapter( + student_module=program, + metric_fn=simple_metric_for_reconstruction, + feedback_map={"react": simple_feedback}, + optimize_react_components=True + ) + + trainset = [Example(question="test", answer="result").with_inputs("question")] + 
eval_batch = adapter.evaluate(batch=trainset, candidate={}, capture_traces=True) + + result = adapter.make_reflective_dataset( + candidate={}, + eval_batch=eval_batch, + components_to_update=[REACT_MODULE_PREFIX] + ) + + assert REACT_MODULE_PREFIX in result + examples = result[REACT_MODULE_PREFIX] + assert len(examples) == 1, f"Should have 1 reflective example, got {len(examples)}" + + assert_reflective_example_has_trajectory( + actual_example=examples[0], + expected_iterations=expected_iterations, + answer=expected_answer + ) + +def test_make_reflective_dataset_orchestrator_with_workers(): + """Test reflective dataset for multi-agent system with 3 ReAct modules. + + Tests full path preservation in complex nested system: + - Orchestrator: multi_agent.orchestrator (3 iterations) + - Analyst: multi_agent.analyst (2 iterations) + - Researcher: multi_agent.researcher (2 iterations) + + Validates each module's trajectory captured separately with correct iteration counts. + """ + from dspy.teleprompt.gepa.gepa_utils import REACT_MODULE_PREFIX, DspyAdapter + + program = create_orchestrator_with_workers_program() + + orchestrator_iterations = [ + ("Let me use the analyst", "analyst", {"data": "test"}), + ("Now let me use the researcher", "researcher", {"topic": "test"}), + ("Done", "finish", {}) + ] + + analyst_iterations = [ + ("Analyzing the data", "analyze", {"data": "test"}), + ("Done", "finish", {}) + ] + + researcher_iterations = [ + ("Researching the topic", "research", {"topic": "test"}), + ("Done", "finish", {}) + ] + + lm = DummyLM([ + {"next_thought": "Let me use the analyst", "next_tool_name": "analyst", "next_tool_args": {"data": "test"}}, + {"next_thought": "Analyzing the data", "next_tool_name": "analyze", "next_tool_args": {"data": "test"}}, + {"next_thought": "Done", "next_tool_name": "finish", "next_tool_args": {}}, + {"reasoning": "Analysis complete", "analysis": "analyzed_data"}, + {"next_thought": "Now let me use the researcher", "next_tool_name": "researcher", "next_tool_args": {"topic": "test"}}, + {"next_thought": "Researching the topic", "next_tool_name": "research", "next_tool_args": {"topic": "test"}}, + {"next_thought": "Done", "next_tool_name": "finish", "next_tool_args": {}}, + {"reasoning": "Research complete", "findings": "research_findings"}, + {"next_thought": "Done", "next_tool_name": "finish", "next_tool_args": {}}, + {"reasoning": "Orchestration complete", "answer": "result"}, + ] * 10) + dspy.settings.configure(lm=lm) + + adapter = DspyAdapter( + student_module=program, + metric_fn=simple_metric_for_reconstruction, + feedback_map={ + "multi_agent.orchestrator.react": simple_feedback, + "multi_agent.analyst.react": simple_feedback, + "multi_agent.researcher.react": simple_feedback, + }, + optimize_react_components=True + ) + + trainset = [Example(question="test", answer="result").with_inputs("question")] + eval_batch = adapter.evaluate(batch=trainset, candidate={}, capture_traces=True) + + result = adapter.make_reflective_dataset( + candidate={}, + eval_batch=eval_batch, + components_to_update=[ + f"{REACT_MODULE_PREFIX}:multi_agent.orchestrator", + f"{REACT_MODULE_PREFIX}:multi_agent.analyst", + f"{REACT_MODULE_PREFIX}:multi_agent.researcher" + ] + ) + + assert f"{REACT_MODULE_PREFIX}:multi_agent.orchestrator" in result + assert f"{REACT_MODULE_PREFIX}:multi_agent.analyst" in result + assert f"{REACT_MODULE_PREFIX}:multi_agent.researcher" in result + assert len(result) == 3 + assert len(result[f"{REACT_MODULE_PREFIX}:multi_agent.orchestrator"]) == 1 + assert 
len(result[f"{REACT_MODULE_PREFIX}:multi_agent.analyst"]) == 1 + assert len(result[f"{REACT_MODULE_PREFIX}:multi_agent.researcher"]) == 1 + + orch_example = result[f"{REACT_MODULE_PREFIX}:multi_agent.orchestrator"][0] + assert_reflective_example_has_trajectory(orch_example, orchestrator_iterations, "result") + assert "question" in orch_example["Inputs"] + assert "answer" in orch_example["Generated Outputs"] + assert "analyst" in orch_example["Inputs"]["trajectory"] + + analyst_example = result[f"{REACT_MODULE_PREFIX}:multi_agent.analyst"][0] + assert_reflective_example_has_trajectory(analyst_example, analyst_iterations, "analyzed_data") + assert "data" in analyst_example["Inputs"] + assert "analysis" in analyst_example["Generated Outputs"] + assert "Analysis:" in analyst_example["Inputs"]["trajectory"] + + researcher_example = result[f"{REACT_MODULE_PREFIX}:multi_agent.researcher"][0] + assert_reflective_example_has_trajectory(researcher_example, researcher_iterations, "research_findings") + assert "topic" in researcher_example["Inputs"] + assert "findings" in researcher_example["Generated Outputs"] + assert "Research:" in researcher_example["Inputs"]["trajectory"] + +