SIMBA Improvements #8766
Conversation
def parse_value(value, annotation):
    annotation = _strip_optional(annotation)
This fixes the following failure case:

Previously, parsing failed for fields annotated Union[str, None] whose values were strings that could also be parsed as ints, e.g. "9812750".

The problem is in this sequence:
1. value = "9812750" (a string)
2. annotation = typing.Optional[str]
3. candidate = json_repair.loads("9812750") → 9812750 (parsed as an integer, not a str)
4. TypeAdapter(typing.Optional[str]).validate_python(9812750) → fails with pydantic.ValidationError, since the value is neither a str nor None
5. The exception handler is triggered: except pydantic.ValidationError as e:
6. We then hit issubclass(annotation, Type), which throws "issubclass() arg 1 must be a class", because typing.Optional[str] is not a class; it's a type annotation/Union.

The fix first unwraps the Optional annotation to get the expected non-None type and parses the value against that. Pydantic then handles the type coercion correctly (str → str instead of int → int) when the non-None annotation type is str.
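A minimal repro of the failing sequence, assuming pydantic v2 and the json_repair package:

import typing
import json_repair
import pydantic

annotation = typing.Optional[str]
candidate = json_repair.loads("9812750")  # -> 9812750: parsed as an int, not a str

try:
    pydantic.TypeAdapter(annotation).validate_python(candidate)
except pydantic.ValidationError:
    print("fails: 9812750 (int) is neither a str nor None")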
I think we just need to change the condition from issubclass(annotation, Type) to inspect.isclass(annotation) and issubclass(annotation, Type).
The current approach doesn't handle str | None, IIUC. Let me file a PR to fix this issue so that we can keep this PR focused on SIMBA.
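For illustration (Python 3.10–3.13, where PEP 604 unions have their own runtime type):

from typing import Optional, Union, get_origin

print(get_origin(Optional[str]) is Union)  # True  -> _strip_optional handles this
print(get_origin(str | None) is Union)     # False -> origin is types.UnionType, so the check misses it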
Filed #8774, which handles the parse issue.
dspy/teleprompt/simba.py (Outdated)
@@ -31,6 +31,8 @@ def __init__(
     num_candidates: int = 6,
     max_steps: int = 8,
     max_demos: int = 4,
+    prompt_model: Optional[Any] = None,
+    teacher_settings: Optional[Dict] = None,
Adding support for prompt / teacher models
start_rollout_idx, models = 0, []
# If we have a teacher model, use this as the first model
if teacher_settings:
This has been updated to add support for a teacher model (used for one of the N trajectories).
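A hedged sketch of that flow; rollout_ids and num_models are assumed surrounding variables, the "lm" key in teacher_settings is an assumption, and lm.copy(rollout_id=...) follows the copy pattern suggested later in this review:

start_rollout_idx, models = 0, []
if teacher_settings:
    # Assumption: the teacher LM is carried in teacher_settings under "lm"
    teacher_lm = teacher_settings.get("lm") or dspy.settings.lm
    models.append(teacher_lm.copy(rollout_id=rollout_ids[start_rollout_idx]))
    start_rollout_idx += 1
# Remaining trajectories use copies of the base LM with distinct rollout ids
for idx in range(start_rollout_idx, num_models):
    models.append(dspy.settings.lm.copy(rollout_id=rollout_ids[idx]))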
elif isinstance(output, dspy.Prediction):
    if not hasattr(output, 'score'):
        raise ValueError("dspy.Prediction must contain a 'score' attribute")
    score = output.score
Updated to handle additional metric metadata alongside the score. To do this, we check whether the metric's output is a float or int (in which case we use it directly as the score) or a dspy.Prediction object, which contains a score plus potentially additional metadata.
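For illustration, the two metric shapes this supports might look like this (hypothetical metric names):

def plain_metric(example, prediction, trace=None):
    return 1.0  # a bare float/int is used directly as the score

def rich_metric(example, prediction, trace=None):
    # a dspy.Prediction carries the score plus arbitrary extra metadata
    return dspy.Prediction(score=0.5, feedback="Answer was partially correct.")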
We can also use float(output), which might be more intuitive?
name2demo = {}
if good["score"] <= batch_10p_score:
Double-checking that the demo we're appending is not at or below the 10th percentile of scores.
@@ -117,12 +147,17 @@ def append_a_rule(bucket, system, **kwargs):
     "worse_program_outputs": dict(bad["prediction"] or {}),
     "worse_reward_value": bad["score"],
     "better_reward_value": good["score"],
+    "worse_reward_info": bad["output_metadata"],
+    "better_reward_info": good["output_metadata"],
     "module_names": module_names,
Adding the metric metadata (e.g. feedback from a judge) to help come up with a better set of rules.
Pull Request Overview
This PR introduces several key improvements to the SIMBA (Stochastic Introspective Mini-Batch Ascent) optimizer for DSPy programs. The changes focus on enhancing model flexibility, improving demo quality control, and expanding metric capabilities.
- Support for teacher and prompt models to generate more targeted trajectories
- Quality filtering to prevent appending poor-performing demos below the 10th percentile
- Enhanced metric system supporting metadata alongside scores
Reviewed Changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.
File | Description
---|---
dspy/teleprompt/simba_utils.py | Core logic updates for teacher model support, demo filtering, and metric metadata handling
dspy/teleprompt/simba.py | Constructor updates to accept new prompt_model and teacher_settings parameters
dspy/adapters/utils.py | Bug fix for Optional field parsing to handle string values correctly
if good["score"] <= batch_10p_score:
    logger.info(f"Skipping appending a demo as good score {good['score']} is at or below the 10th percentile.")
    return False
[nitpick] The condition check and logic for skipping demo appending is duplicated between the append_a_demo_ and append_a_rule functions. Consider extracting this into a shared helper function to reduce code duplication.
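One possible shape for such a helper (hypothetical name and signature):

def _is_above_quality_bar(score, batch_10p_score, what="demo"):
    # Shared gate for append_a_demo_ and append_a_rule
    if score <= batch_10p_score:
        logger.info(f"Skipping appending a {what} as good score {score} is at or below the 10th percentile.")
        return False
    return True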
teacher_lm.kwargs["rollout_id"] = rollout_ids[start_rollout_idx]
models.append(teacher_lm)
[nitpick] Direct mutation of teacher_lm.kwargs could cause side effects if the teacher model is reused elsewhere. Consider using teacher_lm.copy() and setting the rollout_id on the copy instead.
Suggested change:
-teacher_lm.kwargs["rollout_id"] = rollout_ids[start_rollout_idx]
-models.append(teacher_lm)
+models.append(teacher_lm.copy(rollout_id=rollout_ids[start_rollout_idx]))
def _strip_optional(ann):
    """If ann is Union[..., NoneType] return the non‑None part, else ann."""
    if get_origin(ann) is Union and NoneType in get_args(ann):
        # keep the first non‑None member (there will be only one in Optional[T])
        return next(a for a in get_args(ann) if a is not NoneType)
    return ann
The docstring uses an en-dash character (‑) instead of a regular hyphen (-) in 'non‑None'. This should be corrected for consistency and readability.
@@ -31,6 +31,8 @@ def __init__(
     num_candidates: int = 6,
     max_steps: int = 8,
     max_demos: int = 4,
+    prompt_model: Any | None = None,
Suggested change:
-prompt_model: Any | None = None,
+prompt_model: dspy.LM | None = None,
btw, can we add type annotations for other arguments too?
@@ -62,6 +64,8 @@ def __init__(
     self.num_candidates = num_candidates
     self.max_steps = max_steps
     self.max_demos = max_demos
+    self.prompt_model = prompt_model if prompt_model else dspy.settings.lm
Suggested change:
-self.prompt_model = prompt_model if prompt_model else dspy.settings.lm
+self.prompt_model = prompt_model or dspy.settings.lm
score = output.score
# Just extract fields from _store, excluding 'score'
output_metadata = {
    k: v for k, v in output._store.items() if k != "score"
}
Can't we use output.items()?
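Assuming dspy.Prediction exposes a dict-like items() over the same store, that would read:

output_metadata = {k: v for k, v in output.items() if k != "score"}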
@@ -77,14 +106,15 @@ def append_a_demo_(bucket, system, **kwargs):
 def append_a_rule(bucket, system, **kwargs):
     predictor2name = kwargs["predictor2name"]
     batch_10p_score, batch_90p_score = kwargs["batch_10p_score"], kwargs["batch_90p_score"]
+    prompt_model = kwargs["prompt_model"] or dspy.settings.lm
q: is it possible that prompt_model is not passed? Maybe kwargs.get("prompt_model") is safer.
if good["score"] < batch_10p_score or bad["score"] > batch_90p_score: | ||
logger.info(f"Skipping rule generation as good score {good['score']} is below the 10th percentile " | ||
f"*or* bad score {bad['score']} is above the 90th percentile.") | ||
if good["score"] <= batch_10p_score or bad["score"] >= batch_90p_score: |
if good["score"] <= batch_10p_score or bad["score"] >= batch_90p_score: | |
if good <= batch_10p_score or bad >= batch_90p_score: |
@@ -1,3 +1,4 @@
+
nit: blank line
name2demo = {}
if good["score"] <= batch_10p_score:
ditto
Most of my comments overlap with Tomu's. Basically the function looks good; we can merge after addressing the comments.
@@ -26,33 +38,51 @@ def wrapped_program(example):
     try:
         prediction = program(**example.inputs())
     except Exception as e:
-        print(e)
+        logger.info(e)
Should this be logger.warning or logger.error?
This PR makes the following updates:
- Fixes the parse_value function, which previously parsed strings like "152" as numbers when the annotation was Optional[str]
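Expected behavior after the fix, for illustration:

import typing
from dspy.adapters.utils import parse_value

parse_value("9812750", typing.Optional[str])  # -> "9812750" (remains a str instead of becoming an int)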