From 22df91d8f132014d14999946e6cbbbf7f3cc87b4 Mon Sep 17 00:00:00 2001 From: Krista Opsahl-Ong Date: Sat, 3 May 2025 15:28:51 -0400 Subject: [PATCH 1/5] adding in error messages & timeout for user permission message --- dspy/teleprompt/mipro_optimizer_v2.py | 67 +++++++++++++++++++-------- 1 file changed, 48 insertions(+), 19 deletions(-) diff --git a/dspy/teleprompt/mipro_optimizer_v2.py b/dspy/teleprompt/mipro_optimizer_v2.py index d650d7da71..3b33a25306 100644 --- a/dspy/teleprompt/mipro_optimizer_v2.py +++ b/dspy/teleprompt/mipro_optimizer_v2.py @@ -3,11 +3,14 @@ import textwrap from collections import defaultdict from typing import Any, Callable, Dict, List, Literal, Optional, Tuple +import select +import sys +import time import numpy as np import optuna from optuna.distributions import CategoricalDistribution - +import math import dspy from dspy.evaluate.evaluate import Evaluate from dspy.propose import GroundedProposer @@ -53,10 +56,8 @@ def __init__( teacher_settings: Dict = {}, max_bootstrapped_demos: int = 4, max_labeled_demos: int = 4, - auto: Optional[Literal["light", "medium", "heavy"]] = "medium", - num_candidates: int = 10, - num_fewshot_candidates: Optional[int] = None, - num_instruct_candidates: Optional[int] = None, + auto: Optional[Literal["light", "medium", "heavy"]] = "light", + num_candidates: Optional[int] = None, num_threads: Optional[int] = None, max_errors: int = 10, seed: int = 9, @@ -71,9 +72,9 @@ def __init__( if auto not in allowed_modes: raise ValueError(f"Invalid value for auto: {auto}. Must be one of {allowed_modes}.") self.auto = auto - - self.num_fewshot_candidates = num_fewshot_candidates or num_candidates - self.num_instruct_candidates = num_instruct_candidates or num_candidates + self.num_fewshot_candidates = num_candidates + self.num_instruct_candidates = num_candidates + self.num_candidates = num_candidates self.metric = metric self.init_temperature = init_temperature self.task_model = task_model if task_model else dspy.settings.lm @@ -99,7 +100,7 @@ def compile( trainset: List, teacher: Any = None, valset: Optional[List] = None, - num_trials: int = 30, + num_trials: Optional[int] = None, max_bootstrapped_demos: Optional[int] = None, max_labeled_demos: Optional[int] = None, seed: Optional[int] = None, @@ -114,6 +115,21 @@ def compile( requires_permission_to_run: bool = True, provide_traceback: Optional[bool] = None, ) -> Any: + + zeroshot_opt = (self.max_bootstrapped_demos == 0) and (self.max_labeled_demos == 0) + + # If auto is None, and num_trials is not provided (but num_candidates is), raise an error that suggests a good num_trials value + if self.auto is None and (self.num_candidates is not None and num_trials is None): + raise ValueError(f"If auto is None, num_trials must also be provided. Given num_candidates={self.num_candidates}, we'd recommend setting num_trials to ~{self._set_num_trials_from_num_candidates(student, zeroshot_opt, self.num_candidates)}.") + + # If auto is None, and num_candidates or num_trials is None, raise an error + if self.auto is None and (self.num_candidates is None or num_trials is None): + raise ValueError("If auto is None, num_candidates must also be provided.") + + # If auto is provided, and either num_candidates or num_trials is not None, raise an error + if self.auto is not None and (self.num_candidates is not None or num_trials is not None): + raise ValueError("If auto is not None, num_candidates and num_trials cannot be set, since they would be overrided by the auto settings. 
Please either set auto to None, or do not specify num_candidates and num_trials.") + # Set random seeds seed = seed or self.seed self._set_random_seeds(seed) @@ -128,7 +144,6 @@ def compile( trainset, valset = self._set_and_validate_datasets(trainset, valset) # Set hyperparameters based on run mode (if set) - zeroshot_opt = (self.max_bootstrapped_demos == 0) and (self.max_labeled_demos == 0) num_trials, valset, minibatch = self._set_hyperparams_from_run_mode( student, num_trials, minibatch, zeroshot_opt, valset ) @@ -204,6 +219,15 @@ def _set_random_seeds(self, seed): self.rng = random.Random(seed) np.random.seed(seed) + def _set_num_trials_from_num_candidates(self, program, zeroshot_opt, num_candidates): + num_vars = len(program.predictors()) + if not zeroshot_opt: + num_vars *= 2 # Account for few-shot examples + instruction variables + # Trials = MAX(c*M*log(N), c=2, 3/2*N) + num_trials = int(max(2 * num_vars * np.log2(num_candidates), 1.5 * num_candidates)) + + return num_trials + def _set_hyperparams_from_run_mode( self, program: Any, @@ -226,11 +250,7 @@ def _set_hyperparams_from_run_mode( self.num_instruct_candidates = auto_settings["n"] if zeroshot_opt else int(auto_settings["n"] * 0.5) self.num_fewshot_candidates = auto_settings["n"] - num_vars = len(program.predictors()) - if not zeroshot_opt: - num_vars *= 2 # Account for few-shot examples + instruction variables - # Trials = MAX(c*M*log(N), c=2, 3/2*N) - num_trials = max(2 * num_vars * np.log(auto_settings["n"]), 1.5 * auto_settings["n"]) + num_trials = self._set_num_trials_from_num_candidates(program, zeroshot_opt, auto_settings["n"]) return num_trials, valset, minibatch @@ -353,6 +373,7 @@ def _get_user_confirmation( user_confirmation_message = textwrap.dedent( f"""\ To proceed with the execution of this program, please confirm by typing {BLUE}'y'{ENDC} for yes or {BLUE}'n'{ENDC} for no. + If no input is received within 20 seconds, the program will proceed automatically. If you would like to bypass this confirmation step in future executions, set the {YELLOW}`requires_permission_to_run`{ENDC} flag to {YELLOW}`False`{ENDC} when calling compile. @@ -360,10 +381,18 @@ def _get_user_confirmation( """ ) - user_input = ( - input(f"{user_message}\n{user_confirmation_message}\nDo you wish to continue? (y/n): ").strip().lower() - ) - return user_input == "y" + print(f"{user_message}\n{user_confirmation_message}\nDo you wish to continue? (y/n): ", end='', flush=True) + + # Wait for input with timeout + start_time = time.time() + while time.time() - start_time < 20: + if select.select([sys.stdin], [], [], 0.1)[0]: + user_input = sys.stdin.readline().strip().lower() + return user_input == "y" + time.sleep(0.1) + + print("\nNo input received within 20 seconds. 
Proceeding with execution...") + return True def _bootstrap_fewshot_examples(self, program: Any, trainset: List, seed: int, teacher: Any) -> Optional[List]: logger.info("\n==> STEP 1: BOOTSTRAP FEWSHOT EXAMPLES <==") From f2408fab373fc6aa6980605c03994fffa2a582b1 Mon Sep 17 00:00:00 2001 From: Krista Opsahl-Ong Date: Mon, 2 Jun 2025 14:11:56 -0400 Subject: [PATCH 2/5] wip --- dspy/teleprompt/simba_utils.py | 89 ++++++++++++++++++++++++++-------- pyproject.toml | 2 +- 2 files changed, 69 insertions(+), 22 deletions(-) diff --git a/dspy/teleprompt/simba_utils.py b/dspy/teleprompt/simba_utils.py index 3765a33f1f..3a23eb6d11 100644 --- a/dspy/teleprompt/simba_utils.py +++ b/dspy/teleprompt/simba_utils.py @@ -3,20 +3,40 @@ import inspect import logging import textwrap +import re from dspy.adapters.utils import get_field_description_string from dspy.signatures import InputField, OutputField -from typing import Callable +from typing import Callable, Optional, Dict, Any logger = logging.getLogger(__name__) +def prepare_models_for_resampling(program: dspy.Module, n: int, teacher_settings: Optional[Dict] = None): + + models = [] + if teacher_settings: + with dspy.settings.context(trace=[], **teacher_settings): + lm = dspy.settings.lm + models.append(lm) -def prepare_models_for_resampling(program: dspy.Module, n: int): lm = program.get_lm() or dspy.settings.lm - temps = [lm.kwargs["temperature"]] + [0.5 + i * (0.5 / n) for i in range(n)] - temps = list(dict.fromkeys(temps))[:n] - return [lm.copy(temperature=t) for t in temps] + # Check to see if our model is a reasoning model, which means temp must stay as 1.0 + model_family = lm.model.split("/")[-1].lower() if "/" in lm.model else lm.model.lower() + model_pattern = re.match(r"^o([13])(?:-mini)?", model_family) + + if model_pattern: # Vary the seed + start_seed = 0 if "seed" not in lm.kwargs else lm.kwargs["seed"] + seeds = [start_seed + 1 + i for i in range(n-len(models))] + seeds = list(dict.fromkeys(seeds))[:(n-len(models))] + models.extend([lm.copy(seed=seed) for seed in seeds]) + else: # Vary the temperature + start_temp = 0 if "temperature" not in lm.kwargs else lm.kwargs["temperature"] + temps = [start_temp + 0.5 + i * (0.5 / n) for i in range(n-len(models))] + temps = list(dict.fromkeys(temps))[:(n-len(models))] + models.extend([lm.copy(temperature=t) for t in temps]) + + return models def wrap_program(program: dspy.Module, metric: Callable): def wrapped_program(example): @@ -25,33 +45,53 @@ def wrapped_program(example): try: prediction = program(**example.inputs()) except Exception as e: - print(e) + logger.info(e) trace = dspy.settings.trace.copy() + output = None + score = 0.0 + output_metadata = {} + try: - score = metric(example, prediction) + output = metric(example, prediction) + if isinstance(output, (int, float)): + score = output + elif isinstance(output, dspy.Prediction): + if not hasattr(output, 'score'): + raise ValueError("dspy.Prediction must contain a 'score' attribute") + score = output.score + # Just extract fields from _store, excluding 'score' + output_metadata = { + k: v for k, v in output._store.items() if k != "score" + } except Exception as e: - print(e) + logger.info(e) - # Include the `example` in the output for subsequent usage in buckets/strategies. 
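#         (Illustrative aside, not part of this patch: the metric-handling branch above
#         accepts either a plain number or a dspy.Prediction whose required `score` is
#         accompanied by free-form feedback fields, which land in `output_metadata` for
#         the strategies downstream. A hypothetical metric of that second kind, with an
#         assumed `answer` field, might look like:
#
#             def metric_with_feedback(example, prediction):
#                 correct = prediction.answer == example.answer
#                 return dspy.Prediction(
#                     score=1.0 if correct else 0.0,
#                     feedback="exact match" if correct else "answer differed from the label",
#                 )
#         )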
return { "prediction": prediction, "trace": trace, "score": score, - "example": example + "example": example, + "output_metadata": output_metadata } return wrapped_program - - def append_a_demo(demo_input_field_maxlen): def append_a_demo_(bucket, system, **kwargs): predictor2name, name2predictor = kwargs["predictor2name"], kwargs["name2predictor"] + batch_10p_score = kwargs["batch_10p_score"] - trace = bucket[0]["trace"] + logger.info(f"Appending a demo with max length {demo_input_field_maxlen}") + + good = bucket[0] + trace = good["trace"] name2demo = {} + if good["score"] <= batch_10p_score: + logger.info(f"Skipping appending a demo as good score {good['score']} is at or below the 10th percentile (<={batch_10p_score}).") + return False + for step in trace: predictor, _inputs, _outputs = step @@ -62,28 +102,29 @@ def append_a_demo_(bucket, system, **kwargs): demo = dspy.Example(augmented=True, **_inputs, **_outputs) name = predictor2name[id(predictor)] name2demo[name] = demo # keep the last demo for each predictor - for name, demo in name2demo.items(): predictor = name2predictor[name] predictor.demos.append(demo) - logger.info(f"Added {len(name2demo)} demos (one each) across all predictors.") + logger.info(f"Added {len(name2demo)} demos (one each) across all predictors. Each predictor now has {len(predictor.demos)} demos total.") return True return append_a_demo_ def append_a_rule(bucket, system, **kwargs): + # Read in kwargs predictor2name = kwargs["predictor2name"] batch_10p_score, batch_90p_score = kwargs["batch_10p_score"], kwargs["batch_90p_score"] + prompt_model = kwargs["prompt_model"] or dspy.settings.lm module_names = [name for name, _ in system.named_predictors()] good, bad = bucket[0], bucket[-1] example = good["example"] - if good["score"] < batch_10p_score or bad["score"] > batch_90p_score: - logger.info(f"Skipping rule generation as good score {good['score']} is below the 10th percentile " - f"*or* bad score {bad['score']} is above the 90th percentile.") + if good["score"] <= batch_10p_score or bad["score"] >= batch_90p_score: + logger.info(f"Skipping rule generation as good score {good['score']} is at or below the 10th percentile (<={batch_10p_score}) " + f"*or* bad score {bad['score']} is at or above the 90th percentile, (>={batch_90p_score}).") return False if good["score"] <= bad["score"]: @@ -116,12 +157,17 @@ def append_a_rule(bucket, system, **kwargs): worse_program_outputs=dict(bad["prediction"] or {}), worse_reward_value=bad["score"], better_reward_value=good["score"], + worse_reward_info=bad["output_metadata"], + better_reward_info=good["output_metadata"], module_names=module_names, ) kwargs = {k: v if isinstance(v, str) else ujson.dumps(recursive_mask(v), indent=2) for k, v in kwargs.items()} - advice = dspy.Predict(OfferFeedback)(**kwargs).module_advice + + with dspy.settings.context(trace=[], lm=prompt_model): + advice_program = dspy.Predict(OfferFeedback) + advice = advice_program(**kwargs).module_advice for name, predictor in system.named_predictors(): if name in advice: @@ -155,11 +201,13 @@ class OfferFeedback(dspy.Signature): ) worse_program_outputs: str = InputField(desc="The outputs of the program that we are analyzing") worse_reward_value: float = InputField(desc="The reward value assigned to the program's outputs") + worse_reward_info: str = InputField(desc="Additional information that might be helpful to understanding the assigned reward value.") better_program_trajectory: str = InputField( desc="The trajectory of the program's execution, showing each 
module's I/O" ) better_program_outputs: str = InputField(desc="The outputs of the program that we are analyzing") better_reward_value: float = InputField(desc="The reward value assigned to the program's outputs") + better_reward_info: str = InputField(desc="Additional information that might be helpful to understanding the assigned reward value.") module_names: list[str] = InputField(desc="The names of the modules in the program, for which we seek advice") discussion: str = OutputField(desc="Discussing blame of where each module went wrong, if it did") module_advice: dict[str, str] = OutputField( @@ -169,7 +217,6 @@ class OfferFeedback(dspy.Signature): "like the successful trajectory rather than the lower-scoring trajectory." ) - def inspect_modules(program): separator = "-" * 80 output = [separator] @@ -209,4 +256,4 @@ def recursive_mask(o): return tuple(recursive_mask(v) for v in o) # Otherwise, replace it with a placeholder string (or use repr(o)). else: - return f"" + return f"" \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 78c1e637d8..6eb315a221 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ classifiers = [ dependencies = [ "backoff>=2.2", "joblib~=1.3", - "openai>=0.28.1", + "openai>=0.28.1,<=1.67.0", "pandas>=2.1.1", "regex>=2023.10.3", "ujson>=5.8.0", From 8b3b6e60afdabd91505002a1463354c25db72d9a Mon Sep 17 00:00:00 2001 From: Krista Opsahl-Ong Date: Mon, 2 Jun 2025 14:29:57 -0400 Subject: [PATCH 3/5] wip --- dspy/teleprompt/mipro_optimizer_v2.py | 37 +-------------------------- 1 file changed, 1 insertion(+), 36 deletions(-) diff --git a/dspy/teleprompt/mipro_optimizer_v2.py b/dspy/teleprompt/mipro_optimizer_v2.py index 879d25711a..05dd4eec1a 100644 --- a/dspy/teleprompt/mipro_optimizer_v2.py +++ b/dspy/teleprompt/mipro_optimizer_v2.py @@ -7,14 +7,9 @@ from typing import TYPE_CHECKING from collections import defaultdict from typing import Any, Callable, Dict, List, Literal, Optional, Tuple -import select -import sys -import time import numpy as np -import optuna -from optuna.distributions import CategoricalDistribution -import math + import dspy from dspy.evaluate.evaluate import Evaluate from dspy.propose import GroundedProposer @@ -122,23 +117,6 @@ def compile( requires_permission_to_run: bool = True, provide_traceback: Optional[bool] = None, ) -> Any: -<<<<<<< HEAD - - zeroshot_opt = (self.max_bootstrapped_demos == 0) and (self.max_labeled_demos == 0) - - # If auto is None, and num_trials is not provided (but num_candidates is), raise an error that suggests a good num_trials value - if self.auto is None and (self.num_candidates is not None and num_trials is None): - raise ValueError(f"If auto is None, num_trials must also be provided. Given num_candidates={self.num_candidates}, we'd recommend setting num_trials to ~{self._set_num_trials_from_num_candidates(student, zeroshot_opt, self.num_candidates)}.") - - # If auto is None, and num_candidates or num_trials is None, raise an error - if self.auto is None and (self.num_candidates is None or num_trials is None): - raise ValueError("If auto is None, num_candidates must also be provided.") - - # If auto is provided, and either num_candidates or num_trials is not None, raise an error - if self.auto is not None and (self.num_candidates is not None or num_trials is not None): - raise ValueError("If auto is not None, num_candidates and num_trials cannot be set, since they would be overrided by the auto settings. 
Please either set auto to None, or do not specify num_candidates and num_trials.") - -======= zeroshot_opt = (self.max_bootstrapped_demos == 0) and (self.max_labeled_demos == 0) @@ -154,7 +132,6 @@ def compile( if self.auto is not None and (self.num_candidates is not None or num_trials is not None): raise ValueError("If auto is not None, num_candidates and num_trials cannot be set, since they would be overrided by the auto settings. Please either set auto to None, or do not specify num_candidates and num_trials.") ->>>>>>> 82d3878b12b4632b3c549d9c4e85eaef360ad1f7 # Set random seeds seed = seed or self.seed self._set_random_seeds(seed) @@ -252,11 +229,7 @@ def _set_num_trials_from_num_candidates(self, program, zeroshot_opt, num_candida num_trials = int(max(2 * num_vars * np.log2(num_candidates), 1.5 * num_candidates)) return num_trials -<<<<<<< HEAD - -======= ->>>>>>> 82d3878b12b4632b3c549d9c4e85eaef360ad1f7 def _set_hyperparams_from_run_mode( self, program: Any, @@ -411,11 +384,7 @@ def _get_user_confirmation( ) print(f"{user_message}\n{user_confirmation_message}\nDo you wish to continue? (y/n): ", end='', flush=True) -<<<<<<< HEAD - -======= ->>>>>>> 82d3878b12b4632b3c549d9c4e85eaef360ad1f7 # Wait for input with timeout start_time = time.time() while time.time() - start_time < 20: @@ -423,11 +392,7 @@ def _get_user_confirmation( user_input = sys.stdin.readline().strip().lower() return user_input == "y" time.sleep(0.1) -<<<<<<< HEAD - -======= ->>>>>>> 82d3878b12b4632b3c549d9c4e85eaef360ad1f7 print("\nNo input received within 20 seconds. Proceeding with execution...") return True From 13249668d913031607d6580a6d5d3382c08280cd Mon Sep 17 00:00:00 2001 From: Krista Opsahl-Ong Date: Fri, 11 Jul 2025 15:47:51 -0400 Subject: [PATCH 4/5] adding simba fast + simba utils that allow for field descrip. 
changes --- dspy/teleprompt/simba_fast.py | 397 +++++++++++++++++++++++++++++++++ dspy/teleprompt/simba_utils.py | 279 +++++++++++++++++++---- 2 files changed, 632 insertions(+), 44 deletions(-) create mode 100644 dspy/teleprompt/simba_fast.py diff --git a/dspy/teleprompt/simba_fast.py b/dspy/teleprompt/simba_fast.py new file mode 100644 index 0000000000..76a3682758 --- /dev/null +++ b/dspy/teleprompt/simba_fast.py @@ -0,0 +1,397 @@ +import dspy +import random +import logging + +import numpy as np +from typing import Callable, Optional, Any, Dict +from dspy.teleprompt.teleprompt import Teleprompter +from dspy.teleprompt.simba_utils import prepare_models_for_resampling, wrap_program, append_a_demo, append_a_rule, update_fields +from dspy.teleprompt.utils import log_token_usage + +logger = logging.getLogger(__name__) + + +# Stochastic Introspective Mini-Batch Ascent +class SIMBAFast(Teleprompter): + def __init__( + self, + *, + metric: Callable, + bsize=32, + num_candidates=6, + max_steps=8, + max_demos=4, + prompt_model: Optional[Any] = None, + teacher_settings: Optional[Dict] = None, + demo_input_field_maxlen=100_000, + num_threads=16, + temperature_for_sampling=0.2, + temperature_for_candidates=0.2, + ): + """ + :param metric: A function (Example, prediction_dict) -> float + :param bsize: mini-batch size + :param num_candidates: how many new candidate programs to produce per iteration + :param max_steps: how many optimization steps to run + :param max_demos: how many demos we allow a predictor to hold before we must drop some + :param demo_input_field_maxlen: how many characters of an input field to keep when building a new demo + :param num_threads: how many threads for run_parallel + :param temperature_for_sampling: temperature used for picking programs for the trajectory-sampling step + :param temperature_for_candidates: temperature used for picking the source program for building new candidates + """ + self.metric = metric + self.bsize = bsize + self.num_candidates = num_candidates + self.max_steps = max_steps + self.max_demos = max_demos + self.prompt_model = prompt_model if prompt_model else dspy.settings.lm + self.teacher_settings = teacher_settings if teacher_settings else {} + self.demo_input_field_maxlen = demo_input_field_maxlen + self.num_threads = num_threads + + self.temperature_for_sampling = temperature_for_sampling + self.temperature_for_candidates = temperature_for_candidates + + if self.max_demos > 0: + # self.strategies = [append_a_demo(demo_input_field_maxlen), append_a_rule] + self.strategies = [update_fields] + else: + self.strategies = [append_a_rule] + + def compile(self, student: dspy.Module, *, trainset: list[dspy.Example], seed: int = 0): + # Basic checks + assert len(trainset) >= self.bsize, f"Trainset too small: {len(trainset)} < {self.bsize}" + + # Initialize RNG + rng = random.Random(seed) + rng_np = np.random.default_rng(seed) + + programs = [] + program_scores = {} + program_batch_idx = {} + next_program_idx = 0 + batch_idx_to_baseline_scores = {} + + # Helper functions + def calc_average_score(prog_idx: int) -> float: + scores = program_scores.get(prog_idx, []) + if not scores: + return 0.0 + return sum(scores) / len(scores) + + def calc_average_adjusted_score(prog_idx: int) -> float: + prog_scores = program_scores.get(prog_idx, []) + baseline_scores = batch_idx_to_baseline_scores.get(program_batch_idx[prog_idx], []) + + # If either list is empty or not the same length, return 0 or handle how you prefer + if not prog_scores or not baseline_scores: + 
return 0.0 + if len(prog_scores) != len(baseline_scores): + # You can decide how you want to handle mismatch + return 0.0 + + # Elementwise subtraction + adjusted_scores = [p - b for p, b in zip(prog_scores, baseline_scores)] + return sum(adjusted_scores) / len(adjusted_scores) + + def adjusted_top_k_plus_baseline(k: int) -> list[int]: + # Sort all programs by descending average score + scored_programs = sorted(programs, key=lambda p: calc_average_adjusted_score(p.simba_idx), reverse=True) + top_k = [p.simba_idx for p in scored_programs[:k]] + # Ensure baseline=0 is in there: + if 0 not in top_k and len(top_k) > 0: + top_k[-1] = 0 + return list(dict.fromkeys(top_k)) + + def top_k_plus_baseline(k: int) -> list[int]: + # Sort all programs by descending average score + scored_programs = sorted(programs, key=lambda p: calc_average_score(p.simba_idx), reverse=True) + top_k = [p.simba_idx for p in scored_programs[:k]] + # Ensure baseline=0 is in there: + if 0 not in top_k and len(top_k) > 0: + top_k[-1] = 0 + return list(dict.fromkeys(top_k)) + + def softmax_sample(rng_obj: random.Random, program_idxs: list[int], temperature: float) -> int: + if not program_idxs: + raise ValueError("No programs available for softmax sampling.") + + # Unnormalized weights + scores = [calc_average_score(idx) for idx in program_idxs] + exps = [np.exp(s / temperature) for s in scores] + sum_exps = sum(exps) + if sum_exps <= 0: + # Fallback: uniform if all exps are zero + return rng_obj.choice(program_idxs) + + # Weighted random choice + probs = [val / sum_exps for val in exps] + return rng_obj.choices(program_idxs, weights=probs, k=1)[0] + + def register_new_program(prog: dspy.Module, score_list: list[float], batch_idx: int): + nonlocal next_program_idx + next_program_idx += 1 + new_idx = next_program_idx + prog.simba_idx = new_idx + programs.append(prog) + program_scores[new_idx] = score_list + program_batch_idx[new_idx] = batch_idx + + # Initialize the baseline program: index=0 + student = student.deepcopy() + student.simba_idx = 0 + programs.append(student) + program_scores[0] = [] + program_batch_idx[0] = 0 + + winning_programs = [(0,student)] + + # Data shuffling + data_indices = list(range(len(trainset))) + rng.shuffle(data_indices) + instance_idx = 0 + + # Parallel runner + logger.info(f"Creating parallel runner with num_threads: {self.num_threads}") + run_parallel = dspy.Parallel(access_examples=False, num_threads=self.num_threads) + + trial_logs = {} + + # Initialize for hybrid execution reuse + last_batch_outputs = None + + predictor2name = {} + + M = self.max_steps - 1 + N = self.num_candidates + 1 + program_idxs = [0] * N if M < 1 else [round(i * M / (N - 1)) for i in range(N)] + program_idxs = list(dict.fromkeys(program_idxs)) + + # Compute baseline student score on the full trainset + logger.info(f"Evaluating student program on full trainset.") + exec_pairs = [(wrap_program(student, self.metric), ex) for ex in trainset] + full_outputs = run_parallel(exec_pairs) + baseline_scores = [o["score"] for o in full_outputs] + + # Compute average score for the baseline program + avg_baseline_score = sum(baseline_scores) / len(baseline_scores) + logger.info(f"Baseline program (index 0) score: {avg_baseline_score}\n") + + final_candidate_programs = [student] + final_candidate_scores = [avg_baseline_score] + validated_program_outputs = {} # {prog_idx: {example_idx: output_dict}} + + for batch_idx in range(self.max_steps): + trial_logs[batch_idx+1] = {} + + logger.info(f"Starting batch {batch_idx+1} of 
{self.max_steps}.") + + # STEP 1: Get next batch + if instance_idx + self.bsize > len(trainset): + rng.shuffle(data_indices) + instance_idx = 0 + + batch_indices = data_indices[instance_idx : instance_idx + self.bsize] + batch = [trainset[i] for i in batch_indices] + instance_idx += self.bsize + + # Compute student baseline on batch + batch_idx_to_baseline_scores[batch_idx] = [score for i, score in enumerate(baseline_scores) if i in batch_indices] + + # STEP 2 (or hybrid): Collect execution results for bucket building + models = prepare_models_for_resampling(programs[0], self.num_candidates, self.teacher_settings) + top_programs = top_k_plus_baseline(self.num_candidates) + + exec_pairs = [] + + if batch_idx == 0: + # First round — use full trajectory sampling + for model in models: + for example in batch: + chosen_prog_idx = softmax_sample(rng, top_programs, self.temperature_for_sampling) + candidate_system = programs[chosen_prog_idx].deepcopy() + candidate_system.set_lm(model) + + for name, predictor in candidate_system.named_predictors(): + predictor2name[id(predictor)] = name + + wrapped_candidate_system = wrap_program(candidate_system, self.metric) + exec_pairs.append((wrapped_candidate_system, example)) + + logger.info(f"Sampling program trajectories on {self.bsize} examples x {self.num_candidates} samples.") + outputs = run_parallel(exec_pairs) + else: + outputs = last_batch_outputs.copy() if last_batch_outputs else [] + for prog_idx, prog_cache in validated_program_outputs.items(): + for i in batch_indices: + if i in prog_cache: + outputs.append(prog_cache[i]) + + dspy.settings.lm.history.extend([entry for model in models for entry in model.history]) + + # STEP 3: Sort the training buckets by (max-to-min gap, max score, and max-to-avg gap). + buckets = [] + largest_max_to_avg_gap = float("-inf") + batch_10th_percentile_score = np.percentile([float(o["score"]) for o in outputs], 10) + batch_90th_percentile_score = np.percentile([float(o["score"]) for o in outputs], 90) + + # We'll chunk `outputs` by example index, each chunk has length = num_candidates + for idx, example in enumerate(batch): + # gather all results for this example + bucket = [outputs[i] for i in range(idx, len(outputs), self.bsize)] + bucket.sort(key=lambda x: x["score"], reverse=True) + + max_score = float(bucket[0]["score"]) + min_score = float(bucket[-1]["score"]) + avg_score = sum(x["score"] for x in bucket) / len(bucket) + max_to_min_gap = max_score - min_score + max_to_avg_gap = max_score - avg_score + if max_to_avg_gap > largest_max_to_avg_gap: + largest_max_to_avg_gap = max_to_avg_gap + + buckets.append((bucket, (max_to_min_gap, max_score, max_to_avg_gap))) + + # sort the buckets + buckets.sort(key=lambda x: x[1], reverse=True) + # TODO: if all buckets mave a max_to_min gap of 0 and max score <1.0, then we should do more trajectory sampling + + # Baseline for the batch is just the average of all runs + all_scores_in_this_batch = [o["score"] for o in outputs] + baseline_score = sum(all_scores_in_this_batch) / len(all_scores_in_this_batch) + logger.info(f"Batch {batch_idx+1}: Baseline mini-batch score: {baseline_score}\n") + + # summarize_batch([bucket[0] for bucket in buckets]) + # STEP 4: Build new candidate programs by applying a strategy to some top buckets. 
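+            # Worked example for the STEP 3 sort key above (made-up scores, not from a real run):
+            # a bucket whose candidate scores are [1.0, 0.7, 0.4, 0.0] yields
+            #     max_to_min_gap = 1.0 - 0.0   = 1.0
+            #     max_score      = 1.0
+            #     max_to_avg_gap = 1.0 - 0.525 = 0.475
+            # and buckets are sorted descending on (max_to_min_gap, max_score, max_to_avg_gap),
+            # so examples with both a high ceiling and a wide spread between trajectories are
+            # the first ones handed to a strategy in STEP 4 below.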
+ system_candidates = [] + for bucket_idx, (bucket, bucket_stats) in enumerate(buckets): + max_to_min_gap, max_score, max_to_avg_gap = bucket_stats + logger.info( + f"Batch {batch_idx+1}: Processing bucket #{bucket_idx+1}, with max score {max_score}, " + f"max-to-min gap {max_to_min_gap}, and max-to-avg gap {max_to_avg_gap}." + ) + + # pick source program + src_prog_idx = softmax_sample( + rng, top_k_plus_baseline(self.num_candidates), self.temperature_for_candidates + ) + system_candidate = programs[src_prog_idx].deepcopy() + + # Drop some demos from each predictor + name2predictor = {} + num_demos_list = [] + + max_demos_tmp = self.max_demos if self.max_demos > 0 else 3 + + for name, predictor in system_candidate.named_predictors(): + name2predictor[name] = predictor + num_demos_list.append(len(predictor.demos)) + + num_demos = max(num_demos_list) if num_demos_list else 0 + num_demos_to_drop = max(rng_np.poisson(num_demos / max_demos_tmp), int(num_demos >= max_demos_tmp)) + num_demos_to_drop = min(num_demos_to_drop, num_demos) + demos_to_drop = [rng.randrange(num_demos) for _ in range(num_demos_to_drop)] + + for name, predictor in name2predictor.items(): + predictor.demos = [demo for idxd, demo in enumerate(predictor.demos) if idxd not in demos_to_drop] + + # Pick a strategy + strategy = rng.choice(self.strategies) + logger.info( + f"Batch {batch_idx+1}: Invoking strategy: {strategy.__name__}" + + (f", having dropped {num_demos_to_drop} demos per predictor" if num_demos_to_drop else "") + ) + + for name, predictor in system_candidate.named_predictors(): + predictor2name[id(predictor)] = name + + try: + strategy( + bucket, + system_candidate, + predictor2name=predictor2name, + name2predictor=name2predictor, + batch_10p_score=batch_10th_percentile_score, + batch_90p_score=batch_90th_percentile_score, + prompt_model=self.prompt_model, + ) + except Exception as e: + logger.error(f"Strategy failed with error: {e}") + continue + + system_candidates.append(system_candidate) + logger.info("\n") + + if len(system_candidates) >= self.num_candidates: + break + + # STEP 5: Evaluate these new system_candidates on the same mini-batch + logger.info(f"Batch {batch_idx+1}: Evaluating {len(system_candidates)} programs on {self.bsize} examples.") + + exec_pairs = [(wrap_program(sys, self.metric), ex) for sys in system_candidates for ex in batch] + outputs = run_parallel(exec_pairs) + assert len(outputs) == len(exec_pairs) == len(system_candidates) * self.bsize + + # STEP 6: Compute average mini-batch scores for each new candidate + candidate_scores = [] + for idx_cand, cand_sys in enumerate(system_candidates): + start = idx_cand * self.bsize + end = (idx_cand + 1) * self.bsize + sys_scores = [outputs[i]["score"] for i in range(start, end)] + avg_sys_score = sum(sys_scores) / len(sys_scores) + candidate_scores.append(avg_sys_score) + + logger.info( + f"Scores after {batch_idx+1} batches: {candidate_scores}, " + f"Best: {max(candidate_scores) if candidate_scores else 'N/A'}\n" + ) + + trial_logs[batch_idx+1]["batch_scores"] = candidate_scores + + # STEP 7: Select the best among these new ones for "winning" record + if candidate_scores: + best_idx_among_candidates = candidate_scores.index(max(candidate_scores)) + best_program = system_candidates[best_idx_among_candidates] + winning_programs.append((batch_idx+1, best_program.deepcopy())) + + # STEP 8: If it's time for a full evaluation, evaluate the winning program on the full trainset + if batch_idx in program_idxs: + logger.info(f"Batch {batch_idx+1}: 
Evaluating winning program on full trainset.") + exec_pairs = [(wrap_program(best_program, self.metric), ex) for ex in trainset] + full_outputs = run_parallel(exec_pairs) + scores = [o["score"] for o in full_outputs] + avg_score = sum(scores) / len(scores) + logger.info(f"Batch {batch_idx+1}: Full trainset score: {avg_score}")simb + trial_logs[batch_idx + 1]["train_score"] = avg_score + + final_candidate_programs.append(best_program.deepcopy()) + final_candidate_scores.append(avg_score) + + prog_cache = {i: out for i, out in enumerate(full_outputs)} + validated_program_outputs[best_program.simba_idx] = prog_cache + + # STEP 9: Register all new candidate systems in our global pool + for idx_cand, cand_sys in enumerate(system_candidates): + start = idx_cand * self.bsize + end = (idx_cand + 1) * self.bsize + sys_scores = [outputs[i]["score"] for i in range(start, end)] + register_new_program(cand_sys, sys_scores, batch_idx) + + # Save for hybrid bucket building next round + last_batch_outputs = outputs.copy() + + log_token_usage(trial_logs, batch_idx+1, {"lm": dspy.settings.lm}) + + + best_idx = np.argmax(final_candidate_scores) if final_candidate_scores else 0 + # best_idx = scores.index(max(final_candidate_scores)) if final_candidate_scores else 0 + best_program = final_candidate_programs[best_idx] + logger.info( + f"Final trainset scores: {final_candidate_scores}, Best: {max(final_candidate_scores) if final_candidate_scores else 'N/A'} " + f"(at index {best_idx if final_candidate_scores else 'N/A'})\n\n\n" + ) + # FIXME: Attach all program candidates in decreasing average score to the best program. + best_program.candidate_programs = final_candidate_programs + best_program.winning_programs = winning_programs + best_program.trial_logs = trial_logs + + return best_program diff --git a/dspy/teleprompt/simba_utils.py b/dspy/teleprompt/simba_utils.py index b32f30609d..b36f12a33f 100644 --- a/dspy/teleprompt/simba_utils.py +++ b/dspy/teleprompt/simba_utils.py @@ -1,23 +1,42 @@ +import dspy +import ujson import inspect import logging import textwrap -from typing import Callable +import re -import ujson - -import dspy -from dspy.adapters.utils import get_field_description_string +from dspy.adapters.chat_adapter import enumerate_fields from dspy.signatures import InputField, OutputField +from typing import Callable, Optional, Dict, Any logger = logging.getLogger(__name__) +def prepare_models_for_resampling(program: dspy.Module, n: int, teacher_settings: Optional[Dict] = None): + + models = [] + if teacher_settings: + with dspy.settings.context(trace=[], **teacher_settings): + lm = dspy.settings.lm + models.append(lm) -def prepare_models_for_resampling(program: dspy.Module, n: int): lm = program.get_lm() or dspy.settings.lm - temps = [lm.kwargs["temperature"]] + [0.5 + i * (0.5 / n) for i in range(n)] - temps = list(dict.fromkeys(temps))[:n] - return [lm.copy(temperature=t) for t in temps] + # Check to see if our model is a reasoning model, which means temp must stay as 1.0 + model_family = lm.model.split("/")[-1].lower() if "/" in lm.model else lm.model.lower() + model_pattern = re.match(r"^o([13])(?:-mini)?", model_family) + + if model_pattern: # Vary the seed + start_seed = 0 if "seed" not in lm.kwargs else lm.kwargs["seed"] + seeds = [start_seed + 1 + i for i in range(n-len(models))] + seeds = list(dict.fromkeys(seeds))[:(n-len(models))] + models.extend([lm.copy(seed=seed) for seed in seeds]) + else: # Vary the temperature + start_temp = 0 if "temperature" not in lm.kwargs else 
lm.kwargs["temperature"] + temps = [start_temp + 0.5 + i * (0.5 / n) for i in range(n-len(models))] + temps = list(dict.fromkeys(temps))[:(n-len(models))] + models.extend([lm.copy(temperature=t) for t in temps]) + + return models def wrap_program(program: dspy.Module, metric: Callable): def wrapped_program(example): @@ -26,33 +45,56 @@ def wrapped_program(example): try: prediction = program(**example.inputs()) except Exception as e: - print(e) + logger.info(e) trace = dspy.settings.trace.copy() + output = None + score = 0.0 + output_metadata = {} + try: - score = metric(example, prediction) + output = metric(example, prediction) + if isinstance(output, (int, float)): + score = output + elif isinstance(output, dspy.Prediction): + if not hasattr(output, 'score'): + raise ValueError("dspy.Prediction must contain a 'score' attribute") + score = output.score + # Just extract fields from _store, excluding 'score' + output_metadata = { + k: v for k, v in output._store.items() if k != "score" + } except Exception as e: - print(e) + logger.info(e) - # Include the `example` in the output for subsequent usage in buckets/strategies. return { "prediction": prediction, "trace": trace, "score": score, - "example": example + "example": example, + "output_metadata": output_metadata } return wrapped_program - - def append_a_demo(demo_input_field_maxlen): def append_a_demo_(bucket, system, **kwargs): predictor2name, name2predictor = kwargs["predictor2name"], kwargs["name2predictor"] + batch_10p_score = kwargs["batch_10p_score"] - trace = bucket[0]["trace"] + logger.info(f"Appending a demo with max length {demo_input_field_maxlen}") + + good = bucket[0] + trace = good["trace"] name2demo = {} + # if good["score"] < batch_10p_score: + # logger.info(f"Skipping appending a demo as good score {good['score']} is below the 10th percentile.") + # return False + if good["score"] <= batch_10p_score: + logger.info(f"Skipping appending a demo as good score {good['score']} is at or below the 10th percentile.") + return False + for step in trace: predictor, _inputs, _outputs = step @@ -63,28 +105,47 @@ def append_a_demo_(bucket, system, **kwargs): demo = dspy.Example(augmented=True, **_inputs, **_outputs) name = predictor2name[id(predictor)] name2demo[name] = demo # keep the last demo for each predictor - for name, demo in name2demo.items(): predictor = name2predictor[name] predictor.demos.append(demo) - logger.info(f"Added {len(name2demo)} demos (one each) across all predictors.") + logger.info(f"Added {len(name2demo)} demos (one each) across all predictors. 
Each predictor now has {len(predictor.demos)} demos total.") return True - + return append_a_demo_ +def get_good_and_bad_examples(bucket): + """Get good and bad examples from bucket + """ + good, bad = bucket[0], bucket[-1] + return good["example"], bad["example"] + +def get_good_and_bad_trajectories(good_example, bad_example, predictor2name): + """Get good and bad trajectories from examples + """ + good_trajectory = [ + dict(module_name=predictor2name[id(p)], inputs=i, outputs=dict(o)) + for p, i, o in good_example["trace"] + ] + bad_trajectory = [ + dict(module_name=predictor2name[id(p)], inputs=i, outputs=dict(o)) + for p, i, o in bad_example["trace"] + ] + return good_trajectory, bad_trajectory def append_a_rule(bucket, system, **kwargs): + # Read in kwargs predictor2name = kwargs["predictor2name"] batch_10p_score, batch_90p_score = kwargs["batch_10p_score"], kwargs["batch_90p_score"] + prompt_model = kwargs["prompt_model"] or dspy.settings.lm module_names = [name for name, _ in system.named_predictors()] good, bad = bucket[0], bucket[-1] example = good["example"] - if good["score"] < batch_10p_score or bad["score"] > batch_90p_score: - logger.info(f"Skipping rule generation as good score {good['score']} is below the 10th percentile " - f"*or* bad score {bad['score']} is above the 90th percentile.") + if good["score"] <= batch_10p_score or bad["score"] >= batch_90p_score: + logger.info(f"Skipping rule generation as good score {good['score']} is at or below the 10th percentile " + f"*or* bad score {bad['score']} is at or above the 90th percentile.") return False if good["score"] <= bad["score"]: @@ -98,31 +159,36 @@ def append_a_rule(bucket, system, **kwargs): good["prediction"] = {"N/A": "Prediction not available"} better_trajectory = [ - {"module_name": predictor2name[id(p)], "inputs": i, "outputs": dict(o)} + dict(module_name=predictor2name[id(p)], inputs=i, outputs=dict(o)) for p, i, o in good["trace"] ] worse_trajectory = [ - {"module_name": predictor2name[id(p)], "inputs": i, "outputs": dict(o)} + dict(module_name=predictor2name[id(p)], inputs=i, outputs=dict(o)) for p, i, o in bad["trace"] ] - kwargs = { - "program_code": inspect.getsource(system.__class__), - "modules_defn": inspect_modules(system), - "program_inputs": {**example.inputs()}, - "oracle_metadata": {**example.labels()}, - "better_program_trajectory": better_trajectory, - "better_program_outputs": dict(good["prediction"]), - "worse_program_trajectory": worse_trajectory, - "worse_program_outputs": dict(bad["prediction"] or {}), - "worse_reward_value": bad["score"], - "better_reward_value": good["score"], - "module_names": module_names, - } + kwargs = dict( + program_code=inspect.getsource(system.__class__), + modules_defn=inspect_modules(system), + program_inputs={**example.inputs()}, + oracle_metadata={**example.labels()}, + better_program_trajectory=better_trajectory, + better_program_outputs=dict(good["prediction"]), + worse_program_trajectory=worse_trajectory, + worse_program_outputs=dict(bad["prediction"] or {}), + worse_reward_value=bad["score"], + better_reward_value=good["score"], + worse_reward_info=bad["output_metadata"], + better_reward_info=good["output_metadata"], + module_names=module_names, + ) kwargs = {k: v if isinstance(v, str) else ujson.dumps(recursive_mask(v), indent=2) for k, v in kwargs.items()} - advice = dspy.Predict(OfferFeedback)(**kwargs).module_advice + + with dspy.settings.context(trace=[], lm=prompt_model): + advice_program = dspy.Predict(OfferFeedback) + advice = 
advice_program(**kwargs).module_advice for name, predictor in system.named_predictors(): if name in advice: @@ -132,6 +198,130 @@ def append_a_rule(bucket, system, **kwargs): return True +def update_fields(bucket, system, **kwargs): + predictor2name = kwargs["predictor2name"] + batch_10p_score, batch_90p_score = kwargs["batch_10p_score"], kwargs["batch_90p_score"] + prompt_model = kwargs["prompt_model"] or dspy.settings.lm + + module_names = [name for name, _ in system.named_predictors()] + good, bad = bucket[0], bucket[-1] + example = good["example"] + + if good["score"] <= batch_10p_score or bad["score"] >= batch_90p_score: + logger.info(f"Skipping rule generation as good score {good['score']} is at or below the 10th percentile " + f"*or* bad score {bad['score']} is at or above the 90th percentile.") + return False + + if good["score"] <= bad["score"]: + if good["score"] > batch_90p_score: + bad["trace"] = [] + bad["score"] = "N/A" + bad["prediction"] = {"N/A": "Prediction not available"} + else: + good["trace"] = [] + good["score"] = "N/A" + good["prediction"] = {"N/A": "Prediction not available"} + + better_trajectory = [ + dict(module_name=predictor2name[id(p)], inputs=i, outputs=dict(o)) + for p, i, o in good["trace"] + ] + worse_trajectory = [ + dict(module_name=predictor2name[id(p)], inputs=i, outputs=dict(o)) + for p, i, o in bad["trace"] + ] + + # Get the current fields + current_fields = {} + for name, predictor in system.named_predictors(): + current_fields[name] = {} + for field_name, field in predictor.signature.input_fields.items(): + current_fields[name][field_name] = {} + current_fields[name][field_name]["name"] = field.json_schema_extra["prefix"] + current_fields[name][field_name]["desc"] = field.json_schema_extra["desc"] + for field_name, field in predictor.signature.output_fields.items(): + current_fields[name][field_name] = {} + current_fields[name][field_name]["name"] = field.json_schema_extra["prefix"] + current_fields[name][field_name]["desc"] = field.json_schema_extra["desc"] + + kwargs = dict( + program_code=inspect.getsource(system.__class__), + modules_defn=inspect_modules(system), + program_inputs={**example.inputs()}, + oracle_metadata={**example.labels()}, + current_fields=current_fields, + better_program_trajectory=better_trajectory, + better_program_outputs=dict(good["prediction"]), + worse_program_trajectory=worse_trajectory, + worse_program_outputs=dict(bad["prediction"] or {}), + worse_reward_value=bad["score"], + better_reward_value=good["score"], + worse_reward_info=bad["output_metadata"], + better_reward_info=good["output_metadata"], + ) + + kwargs = {k: v if isinstance(v, str) else ujson.dumps(recursive_mask(v), indent=2) + for k, v in kwargs.items()} + + # Get new prefixes and descriptions for each field + with dspy.settings.context(trace=[], lm=prompt_model): + update_fields_program = dspy.Predict(UpdateFields) + updated_fields = update_fields_program(**kwargs).updated_fields + + # Set the prefix and description of the fields + for name, predictor in system.named_predictors(): + if name in updated_fields: + for field_name, field in predictor.signature.input_fields.items(): + if field_name in updated_fields[name]: + if "name" in updated_fields[name][field_name]: + field.json_schema_extra["prefix"] = updated_fields[name][field_name]["name"] + if "desc" in updated_fields[name][field_name]: + field.json_schema_extra["desc"] = updated_fields[name][field_name]["desc"] + for field_name, field in predictor.signature.output_fields.items(): + if field_name 
in updated_fields[name]: + if "name" in updated_fields[name][field_name]: + field.json_schema_extra["prefix"] = updated_fields[name][field_name]["name"] + if "desc" in updated_fields[name][field_name]: + field.json_schema_extra["desc"] = updated_fields[name][field_name]["desc"] + + prompt_model.inspect_history(n=1) + print(f"Current fields: {current_fields}") + print(f"Updated fields: {updated_fields}") + + return True + +class UpdateFields(dspy.Signature): + """ + You will be given two trajectories of an LLM-driven program's execution: one that is successful and one that is not. + You will also be provided with the current fields of the program, which are being used to describe the desired inputs and outputs of the program to the LLM. + Your goal is to update the fields of the program to be more accurate and informative to ensure that the program + is able to learn from the successful trajectory and avoid the mistakes of the unsuccessful trajectory. You can update both the name and the description of the fields. + + These fields are important because they are used to provide the LLM with a description of the inputs it will receive, + and the outputs it will produce. + + """ + program_code: str = InputField(desc="The code of the program that we are analyzing") + modules_defn: str = InputField(desc="The definition of each module in the program, including its I/O") + program_inputs: str = InputField(desc="The inputs to the program that we are analyzing") + oracle_metadata: str = InputField(desc="Any (hidden) metadata about the training set instance we're analyzing") + worse_program_trajectory: str = InputField( + desc="The trajectory of the program's execution, showing each module's I/O" + ) + worse_program_outputs: str = InputField(desc="The outputs of the program that we are analyzing") + worse_reward_value: float = InputField(desc="The reward value assigned to the program's outputs") + worse_reward_info: str = InputField(desc="Additional information that might be helpful to understanding the assigned reward value.") + better_program_trajectory: str = InputField( + desc="The trajectory of the program's execution, showing each module's I/O" + ) + better_program_outputs: str = InputField(desc="The outputs of the program that we are analyzing") + better_reward_value: float = InputField(desc="The reward value assigned to the program's outputs") + better_reward_info: str = InputField(desc="Additional information that might be helpful to understanding the assigned reward value.") + current_fields: dict[str, dict[str, dict[str, str]]] = InputField(desc="A dictionary of current field names and descriptions for the program.") + discussion: str = OutputField(desc="Discussing blame of where each module went wrong, if it did.") + field_discussion: str = OutputField(desc="Discussing the changes to the fields that should be made for each model in the program.") + updated_fields: dict[str, dict[str, dict[str, str]]] = OutputField(desc="A dictionary of new field names and descriptions for each module in the program. These will be used to update the fields of the program to better clarify expected inputs & outputs to the LLM.") + class OfferFeedback(dspy.Signature): """ You will be given two trajectories of an LLM-driven program's execution. 
Your goal is to help the program's modules @@ -156,11 +346,13 @@ class OfferFeedback(dspy.Signature): ) worse_program_outputs: str = InputField(desc="The outputs of the program that we are analyzing") worse_reward_value: float = InputField(desc="The reward value assigned to the program's outputs") + worse_reward_info: str = InputField(desc="Additional information that might be helpful to understanding the assigned reward value.") better_program_trajectory: str = InputField( desc="The trajectory of the program's execution, showing each module's I/O" ) better_program_outputs: str = InputField(desc="The outputs of the program that we are analyzing") better_reward_value: float = InputField(desc="The reward value assigned to the program's outputs") + better_reward_info: str = InputField(desc="Additional information that might be helpful to understanding the assigned reward value.") module_names: list[str] = InputField(desc="The names of the modules in the program, for which we seek advice") discussion: str = OutputField(desc="Discussing blame of where each module went wrong, if it did") module_advice: dict[str, str] = OutputField( @@ -170,21 +362,20 @@ class OfferFeedback(dspy.Signature): "like the successful trajectory rather than the lower-scoring trajectory." ) - def inspect_modules(program): separator = "-" * 80 output = [separator] - for name, predictor in program.named_predictors(): + for idx, (name, predictor) in enumerate(program.named_predictors()): signature = predictor.signature instructions = textwrap.dedent(signature.instructions) instructions = ("\n" + "\t" * 2).join([""] + instructions.splitlines()) output.append(f"Module {name}") output.append("\n\tInput Fields:") - output.append(("\n" + "\t" * 2).join([""] + get_field_description_string(signature.input_fields).splitlines())) + output.append(("\n" + "\t" * 2).join([""] + enumerate_fields(signature.input_fields).splitlines())) output.append("\tOutput Fields:") - output.append(("\n" + "\t" * 2).join([""] + get_field_description_string(signature.output_fields).splitlines())) + output.append(("\n" + "\t" * 2).join([""] + enumerate_fields(signature.output_fields).splitlines())) output.append(f"\tOriginal Instructions: {instructions}") output.append(separator) @@ -210,4 +401,4 @@ def recursive_mask(o): return tuple(recursive_mask(v) for v in o) # Otherwise, replace it with a placeholder string (or use repr(o)). 
else: - return f"" + return f"" \ No newline at end of file From d1c054bf18a5c0a1db5ea99b2f42a2416fdd9f06 Mon Sep 17 00:00:00 2001 From: arnavsinghvi11 Date: Thu, 24 Jul 2025 14:12:46 -0700 Subject: [PATCH 5/5] support for nested schema fields --- dspy/teleprompt/simba_utils.py | 34 +++++++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/dspy/teleprompt/simba_utils.py b/dspy/teleprompt/simba_utils.py index b36f12a33f..ff0cddaf6c 100644 --- a/dspy/teleprompt/simba_utils.py +++ b/dspy/teleprompt/simba_utils.py @@ -7,7 +7,8 @@ from dspy.adapters.chat_adapter import enumerate_fields from dspy.signatures import InputField, OutputField -from typing import Callable, Optional, Dict, Any +from dspy.adapters.types.base_type import Type +from typing import Callable, Optional, Dict, Any, Union logger = logging.getLogger(__name__) @@ -231,6 +232,35 @@ def update_fields(bucket, system, **kwargs): for p, i, o in bad["trace"] ] + def extract_model_fields(model, prefix, current_fields_dict, module_name): + if hasattr(model, 'model_fields'): + for nested_field_name, nested_field in model.model_fields.items(): + nested_key = f"{prefix}.{nested_field_name}" + current_fields_dict[module_name][nested_key] = {} + current_fields_dict[module_name][nested_key]["name"] = nested_field_name + current_fields_dict[module_name][nested_key]["desc"] = nested_field.description or "" + + def process_field(field, field_name, current_fields_dict, module_name): + if hasattr(field.annotation, 'model_fields'): + extract_model_fields(field.annotation, field_name, current_fields_dict, module_name) + elif hasattr(field.annotation, '__origin__'): + origin = field.annotation.__origin__ + args = field.annotation.__args__ + + if origin is list: + extract_model_fields(args[0], field_name, current_fields_dict, module_name) + elif origin is Union: + for arg in args: + if arg != type(None): + extract_model_fields(arg, field_name, current_fields_dict, module_name) + elif origin is tuple: + for i, arg in enumerate(args): + extract_model_fields(arg, f"{field_name}.{i}", current_fields_dict, module_name) + elif origin is dict: + extract_model_fields(args[1], f"{field_name}.value", current_fields_dict, module_name) + elif hasattr(field.annotation, '__bases__') and any(issubclass(base, Type) for base in field.annotation.__bases__): + extract_model_fields(field.annotation, field_name, current_fields_dict, module_name) + # Get the current fields current_fields = {} for name, predictor in system.named_predictors(): @@ -239,10 +269,12 @@ def update_fields(bucket, system, **kwargs): current_fields[name][field_name] = {} current_fields[name][field_name]["name"] = field.json_schema_extra["prefix"] current_fields[name][field_name]["desc"] = field.json_schema_extra["desc"] + process_field(field, field_name, current_fields, name) for field_name, field in predictor.signature.output_fields.items(): current_fields[name][field_name] = {} current_fields[name][field_name]["name"] = field.json_schema_extra["prefix"] current_fields[name][field_name]["desc"] = field.json_schema_extra["desc"] + process_field(field, field_name, current_fields, name) kwargs = dict( program_code=inspect.getsource(system.__class__),
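
# A self-contained sketch of what the nested-field handling added in PATCH 5 produces.
# The signature, model, field names, and descriptions below are illustrative assumptions;
# only the dotted-key flattening mirrors process_field()/extract_model_fields() above.
import pydantic
import dspy


class Answer(pydantic.BaseModel):
    text: str = pydantic.Field(description="final answer text")
    confidence: float = pydantic.Field(description="confidence in [0, 1]")


class QA(dspy.Signature):
    """Answer the question."""

    question: str = dspy.InputField(desc="user question")
    answers: list[Answer] = dspy.OutputField(desc="ranked candidate answers")


# For a predictor built on QA, update_fields() records the usual top-level entries for
# "question" and "answers", and process_field() unwraps the list[Answer] annotation so
# extract_model_fields() adds flattened entries keyed by "<field>.<nested_field>":
#
#   current_fields["<predictor name>"]["answers.text"]       == {"name": "text",       "desc": "final answer text"}
#   current_fields["<predictor name>"]["answers.confidence"] == {"name": "confidence", "desc": "confidence in [0, 1]"}
#
# Union, tuple, and dict annotations are walked the same way (tuple members get a ".<index>"
# suffix and dict values a ".value" suffix), so UpdateFields can propose clearer names and
# descriptions for nested output structure as well as for top-level fields.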