Skip to content

Commit 0a44512

Browse files
adding keep_module_scores flag
1 parent 7ec9abe commit 0a44512

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

dspy/teleprompt/gepa/gepa_utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def __init__(
5858
num_threads: int | None = None,
5959
add_format_failure_as_feedback: bool = False,
6060
rng: random.Random | None = None,
61+
keep_module_scores: bool = False,
6162
):
6263
self.student = student_module
6364
self.metric_fn = metric_fn
@@ -66,6 +67,7 @@ def __init__(
6667
self.num_threads = num_threads
6768
self.add_format_failure_as_feedback = add_format_failure_as_feedback
6869
self.rng = rng or random.Random(0)
70+
self.keep_module_scores = keep_module_scores
6971

7072
# Cache predictor names/signatures
7173
self.named_predictors = list(self.student.named_predictors())
@@ -214,8 +216,11 @@ def make_reflective_dataset(self, candidate, eval_batch, components_to_update):
214216
captured_trace=trace,
215217
)
216218
d["Feedback"] = fb["feedback"]
217-
assert fb["score"] == module_score, f"Currently, GEPA only supports feedback functions that return the same score as the module's score. However, the module-level score is {module_score} and the feedback score is {fb.score}."
218-
# d['score'] = fb.score
219+
if self.keep_module_scores:
220+
d['score'] = module_score
221+
else:
222+
d['score'] = fb['score']
223+
219224
items.append(d)
220225

221226
if len(items) == 0:

0 commit comments

Comments
 (0)