
Conversation

@klopsahlong (Collaborator) commented Sep 4, 2025:

This PR makes the following updates:

  • Support for teacher and prompt models. The teacher model is used to generate 1 of the N trajectories for each example, so that rule generation still targets the task model's failure modes.
  • Don't append_demo if the demo score is at or below the 10th percentile. This helps us avoid adding poor demos.
  • Support for metric metadata. Allow additional metadata to be passed back in a dspy.Prediction object, alongside the score. One downside is that users must know the score field has to be named 'score'. To address this for now, we've added an error message telling users to add 'score' to their dspy.Prediction object.
  • Fix for Optional fields. Fixes a bug in the parse_value function that parsed strings like "152" as numbers when the annotation is Optional[str].

@klopsahlong klopsahlong changed the title simba updates + handling optional fields SIMBA Improvements Sep 4, 2025

def parse_value(value, annotation):
    annotation = _strip_optional(annotation)
klopsahlong (Collaborator, Author):

This fixes the following failure case:

Previously, parsing failed for fields annotated Union[str, None] with values of type str that could also be parsed as ints, e.g. "9812750".

The problem is in this sequence:
1. value = "9812750" (string)
2. annotation = typing.Optional[str]
3. candidate = json_repair.loads("9812750") → 9812750 (parsed as an integer, not a str)
4. TypeAdapter(typing.Optional[str]).validate_python(9812750) → fails with pydantic.ValidationError, since the value is neither a str nor None
5. The exception handler except pydantic.ValidationError as e: is triggered
6. We then hit issubclass(annotation, Type), which throws "issubclass() arg 1 must be a class", because typing.Optional[str] is a type annotation (a Union), not a class

The fix first strips the Optional annotation to get the expected non-None type and parses the value against that. Pydantic then handles the coercion correctly as str → str, rather than int → int, when the non-None annotation type is str.
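The fix described above can be sketched as follows. This is a simplified stand-in, not the actual implementation in dspy/adapters/utils.py: the strip_optional helper and the fallback branches are illustrative, and the pydantic/json_repair machinery is omitted.

```python
from typing import Optional, Union, get_args, get_origin

NoneType = type(None)

def strip_optional(ann):
    """If ann is Optional[T] (i.e. Union[T, None]), return T; otherwise return ann."""
    if get_origin(ann) is Union and NoneType in get_args(ann):
        return next(a for a in get_args(ann) if a is not NoneType)
    return ann

def parse_value(value, annotation):
    # Simplified stand-in for dspy.adapters.utils.parse_value: strip the
    # Optional wrapper first, so "9812750" is validated against str (and
    # stays a string) instead of being coerced to an int.
    annotation = strip_optional(annotation)
    if annotation is str:
        return str(value)
    return annotation(value)

print(parse_value("9812750", Optional[str]))  # prints 9812750, still a str
```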

Collaborator:

I think we just need to change the condition from issubclass(annotation, Type) to inspect.isclass(annotation) and issubclass(annotation, Type)
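For reference, the suggested guard works because inspect.isclass short-circuits before issubclass ever sees a non-class annotation. A toy check (str stands in for the Type class mentioned above; this is not the actual patch):

```python
import inspect
from typing import Optional

ann = Optional[str]  # a typing construct (a Union), not a class
# issubclass(ann, str) on its own would raise:
#   TypeError: issubclass() arg 1 must be a class
safe = inspect.isclass(ann) and issubclass(ann, str)
print(safe)  # prints False, without raising
```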

Collaborator:

The current approach doesn't handle str | None IIUC. Let me file a PR to fix this issue so that we can keep this PR focused on SIMBA.

Collaborator:

#8774, which handles the parse issue

@@ -31,6 +31,8 @@ def __init__(
num_candidates: int = 6,
max_steps: int = 8,
max_demos: int = 4,
prompt_model: Optional[Any] = None,
teacher_settings: Optional[Dict] = None,
klopsahlong (Collaborator, Author):

Adding support for prompt / teacher models



start_rollout_idx, models = 0, []
# If we have a teacher model, use this as the first model
if teacher_settings:
klopsahlong (Collaborator, Author):

This has been updated to add support for teacher model (used for 1 of the N trajectories)
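A rough sketch of the model-list construction this refers to. The function name and the dict layout are illustrative, not the actual SIMBA code; the point is that the teacher LM takes the first rollout slot and the task model fills the rest.

```python
def build_rollout_models(task_lm, teacher_lm, rollout_ids):
    # If a teacher LM is configured, it produces exactly 1 of the N
    # trajectories; the task model produces the remaining N - 1, so rule
    # generation can still target the task model's failure modes.
    models, start_rollout_idx = [], 0
    if teacher_lm is not None:
        models.append({"lm": teacher_lm, "rollout_id": rollout_ids[0]})
        start_rollout_idx = 1
    for rid in rollout_ids[start_rollout_idx:]:
        models.append({"lm": task_lm, "rollout_id": rid})
    return models

models = build_rollout_models("task_lm", "teacher_lm", [0, 1, 2, 3])
print([m["lm"] for m in models])  # ['teacher_lm', 'task_lm', 'task_lm', 'task_lm']
```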

elif isinstance(output, dspy.Prediction):
    if not hasattr(output, 'score'):
        raise ValueError("dspy.Prediction must contain a 'score' attribute")
    score = output.score
klopsahlong (Collaborator, Author):

Updated to handle additional metric metadata alongside the score. To do this, we check whether the output from the metric is a float or int (in which case we use it directly as the score) or a dspy.Prediction object, which contains a score plus potentially additional metadata.
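The branching described above could look roughly like this. Prediction here is a minimal stand-in for dspy.Prediction, and extract_score_and_metadata is a hypothetical helper, not the actual SIMBA code:

```python
class Prediction:
    """Minimal stand-in for dspy.Prediction (fields live in _store)."""
    def __init__(self, **fields):
        self._store = dict(fields)
    def __getattr__(self, name):
        try:
            return self._store[name]
        except KeyError:
            raise AttributeError(name)

def extract_score_and_metadata(output):
    # Plain numbers are used directly as the score; a Prediction must
    # carry a 'score' field, and any other fields become metadata.
    if isinstance(output, (int, float)):
        return float(output), {}
    if isinstance(output, Prediction):
        if "score" not in output._store:
            raise ValueError("dspy.Prediction must contain a 'score' attribute")
        metadata = {k: v for k, v in output._store.items() if k != "score"}
        return float(output.score), metadata
    raise TypeError(f"Unsupported metric output type: {type(output)!r}")

score, meta = extract_score_and_metadata(Prediction(score=0.8, feedback="missed one entity"))
print(score, meta)  # 0.8 {'feedback': 'missed one entity'}
```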

@TomeHirata (Collaborator) commented Sep 5, 2025:

We can also use float(output), which might be more intuitive?

name2demo = {}

if good["score"] <= batch_10p_score:
klopsahlong (Collaborator, Author):

Double checking that the demo we're appending is not below the 10th percentile of scores
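For illustration, the 10th/90th percentile cutoffs used in these checks could be computed along these lines. This is a hypothetical helper using the stdlib; SIMBA's actual percentile computation may differ.

```python
import statistics

def batch_percentile_cutoffs(scores):
    # quantiles(n=10) returns the 9 cut points at the 10th..90th percentiles.
    qs = statistics.quantiles(scores, n=10, method="inclusive")
    return qs[0], qs[-1]  # (10th percentile, 90th percentile)

scores = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
batch_10p_score, batch_90p_score = batch_percentile_cutoffs(scores)
# A demo with good["score"] <= batch_10p_score is skipped; here
# batch_10p_score ≈ 0.19 and batch_90p_score ≈ 0.91.
```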

@@ -117,12 +147,17 @@ def append_a_rule(bucket, system, **kwargs):
"worse_program_outputs": dict(bad["prediction"] or {}),
"worse_reward_value": bad["score"],
"better_reward_value": good["score"],
"worse_reward_info": bad["output_metadata"],
"better_reward_info": good["output_metadata"],
"module_names": module_names,
klopsahlong (Collaborator, Author):

Adding the metric metadata (e.g. feedback from a judge) to help generate a better set of rules.

@klopsahlong klopsahlong marked this pull request as ready for review September 4, 2025 15:27
@TomeHirata TomeHirata requested a review from Copilot September 5, 2025 08:02
@Copilot (Contributor) left a comment:

Pull Request Overview

This PR introduces several key improvements to the SIMBA (Sampling-based Iterative Multi-round Bootstrapping Algorithm) system for optimizing DSPy programs. The changes focus on enhancing model flexibility, improving demo quality control, and expanding metric capabilities.

  • Support for teacher and prompt models to generate more targeted trajectories
  • Quality filtering to prevent appending poor-performing demos below the 10th percentile
  • Enhanced metric system supporting metadata alongside scores

Reviewed Changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.

Files reviewed:
  • dspy/teleprompt/simba_utils.py: Core logic updates for teacher model support, demo filtering, and metric metadata handling
  • dspy/teleprompt/simba.py: Constructor updates to accept new prompt_model and teacher_settings parameters
  • dspy/adapters/utils.py: Bug fix for Optional field parsing to handle string values correctly


Comment on lines +82 to +84
if good["score"] <= batch_10p_score:
logger.info(f"Skipping appending a demo as good score {good['score']} is at or below the 10th percentile.")
return False
Copilot AI commented Sep 5, 2025:

[nitpick] The condition check and logic for skipping demo appending is duplicated between append_a_demo_ and append_a_rule functions. Consider extracting this into a shared helper function to reduce code duplication.


Comment on lines +25 to +26
teacher_lm.kwargs["rollout_id"] = rollout_ids[start_rollout_idx]
models.append(teacher_lm)
Copilot AI commented Sep 5, 2025:

[nitpick] Direct mutation of teacher_lm.kwargs could cause side effects if the teacher model is reused elsewhere. Consider using teacher_lm.copy() and setting the rollout_id on the copy instead.

Suggested change
teacher_lm.kwargs["rollout_id"] = rollout_ids[start_rollout_idx]
models.append(teacher_lm)
models.append(teacher_lm.copy(rollout_id=rollout_ids[start_rollout_idx]))


Comment on lines +137 to +142
def _strip_optional(ann):
"""If ann is Union[..., NoneType] return the non‑None part, else ann."""
if get_origin(ann) is Union and NoneType in get_args(ann):
# keep the first non‑None member (there will be only one in Optional[T])
return next(a for a in get_args(ann) if a is not NoneType)
return ann
Copilot AI commented Sep 5, 2025:

The docstring uses an en-dash character (‑) instead of a regular hyphen (-) in 'non‑None'. This should be corrected for consistency and readability.


@@ -31,6 +31,8 @@ def __init__(
num_candidates: int = 6,
max_steps: int = 8,
max_demos: int = 4,
prompt_model: Any | None = None,
Collaborator:

Suggested change
prompt_model: Any | None = None,
prompt_model: dspy.LM | None = None,

Collaborator:

btw, can we add type annotations for other arguments too?

@@ -62,6 +64,8 @@ def __init__(
self.num_candidates = num_candidates
self.max_steps = max_steps
self.max_demos = max_demos
self.prompt_model = prompt_model if prompt_model else dspy.settings.lm
Collaborator:

Suggested change
self.prompt_model = prompt_model if prompt_model else dspy.settings.lm
self.prompt_model = prompt_model or dspy.settings.lm

score = output.score
# Just extract fields from _store, excluding 'score'
output_metadata = {
    k: v for k, v in output._store.items() if k != "score"
Collaborator:

can't we use output.items()?

@@ -77,14 +106,15 @@ def append_a_demo_(bucket, system, **kwargs):
def append_a_rule(bucket, system, **kwargs):
predictor2name = kwargs["predictor2name"]
batch_10p_score, batch_90p_score = kwargs["batch_10p_score"], kwargs["batch_90p_score"]
prompt_model = kwargs["prompt_model"] or dspy.settings.lm
Collaborator:

q: is it possible that prompt_model is not passed? Maybe kwargs.get("prompt_model") is safer

if good["score"] < batch_10p_score or bad["score"] > batch_90p_score:
logger.info(f"Skipping rule generation as good score {good['score']} is below the 10th percentile "
f"*or* bad score {bad['score']} is above the 90th percentile.")
if good["score"] <= batch_10p_score or bad["score"] >= batch_90p_score:
Collaborator:

Suggested change
if good["score"] <= batch_10p_score or bad["score"] >= batch_90p_score:
if good <= batch_10p_score or bad >= batch_90p_score:

@@ -1,3 +1,4 @@

Collaborator:

nit: blank line

name2demo = {}

if good["score"] <= batch_10p_score:
Collaborator:

ditto

@chenmoneygithub (Collaborator) left a comment:

Most of my comments overlap with Tome's. Basically the function looks good; we can merge after addressing the comments.

@@ -26,33 +38,51 @@ def wrapped_program(example):
try:
prediction = program(**example.inputs())
except Exception as e:
print(e)
logger.info(e)
Collaborator:

should this be logger.warning or logger.error?
