
Commit 239e6c0

Fix GEPA usage tracking with tuple outputs (#8739)
* Fix GEPA usage tracking with tuple outputs

  - Fix AttributeError when track_usage=True and GEPA returns a tuple
  - Module.__call__ now handles both prediction objects and tuples
  - Add regression test for GEPA compile with usage tracking
  - Resolves the hanging/infinite loop when GEPA patches return tuples

  Fixes the issue where GEPA bootstrap tracing returns (prediction, trace) tuples but usage tracking expected only prediction objects.

* Rework test to simplify, make it self-contained, and make the fix more readable

* Revert other code that had been reformatted
1 parent 0bb0d93 commit 239e6c0
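
For context, a minimal standalone sketch of the failure mode described above (illustrative only; the Prediction stand-in below is not dspy's actual class): an optimizer temporarily patches forward() to return a (prediction, trace) tuple, and the pre-fix usage-tracking path then calls set_lm_usage on the tuple itself.

# Minimal standalone sketch of the pre-fix failure mode (illustrative only;
# this stand-in Prediction class is not dspy's real Prediction).

class Prediction:
    def set_lm_usage(self, usage):
        self._lm_usage = usage


def patched_forward():
    # GEPA-style bootstrap tracing wraps forward() so it returns (prediction, trace).
    return Prediction(), ["trace-step-1"]


output = patched_forward()
try:
    # Pre-fix code path: assumes output is always a Prediction.
    output.set_lm_usage({"total_tokens": 42})
except AttributeError as exc:
    print(exc)  # 'tuple' object has no attribute 'set_lm_usage'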

2 files changed: +72 / -1 lines changed

dspy/primitives/module.py

Lines changed: 15 additions & 1 deletion
@@ -72,7 +72,21 @@ def __call__(self, *args, **kwargs) -> Prediction:
         if settings.track_usage and thread_local_overrides.get().get("usage_tracker") is None:
             with track_usage() as usage_tracker:
                 output = self.forward(*args, **kwargs)
-            output.set_lm_usage(usage_tracker.get_total_tokens())
+            tokens = usage_tracker.get_total_tokens()
+
+            # Some optimizers (e.g., GEPA bootstrap tracing) temporarily patch
+            # module.forward to return a tuple: (prediction, trace).
+            # When usage tracking is enabled, ensure we attach usage to the
+            # prediction object if present.
+            prediction_in_output = None
+            if isinstance(output, Prediction):
+                prediction_in_output = output
+            elif isinstance(output, tuple) and len(output) > 0 and isinstance(output[0], Prediction):
+                prediction_in_output = output[0]
+            if not prediction_in_output:
+                raise ValueError("No prediction object found in output to call set_lm_usage on.")
+
+            prediction_in_output.set_lm_usage(tokens)
             return output
 
         return self.forward(*args, **kwargs)
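
Read in isolation, the unwrapping logic added above accepts two output shapes: a bare prediction, or a (prediction, trace) tuple. A standalone restatement as a small helper (illustrative only; attach_usage and the stand-in Prediction class below are not part of dspy):

# Standalone restatement of the unwrapping logic in the diff above, with a
# stand-in Prediction class for illustration; dspy's real class is used in module.py.

class Prediction:
    def set_lm_usage(self, usage):
        self.lm_usage = usage


def attach_usage(output, tokens):
    prediction_in_output = None
    if isinstance(output, Prediction):
        prediction_in_output = output
    elif isinstance(output, tuple) and len(output) > 0 and isinstance(output[0], Prediction):
        prediction_in_output = output[0]
    if not prediction_in_output:
        raise ValueError("No prediction object found in output to call set_lm_usage on.")
    prediction_in_output.set_lm_usage(tokens)
    return output


pred = Prediction()
attach_usage(pred, {"total_tokens": 10})                # plain Prediction
attach_usage((pred, ["trace"]), {"total_tokens": 10})   # (prediction, trace) tuple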

tests/teleprompt/test_gepa.py

Lines changed: 57 additions & 0 deletions
@@ -1,11 +1,14 @@
 import json
+import threading
+from typing import Any
 
 import pytest
 
 import dspy
 import dspy.clients
 from dspy import Example
 from dspy.predict import Predict
+from dspy.utils.dummies import DummyLM
 
 
 class SimpleModule(dspy.Module):
@@ -66,3 +69,57 @@ def test_metric_requires_feedback_signature():
     reflection_lm = DictDummyLM([])
     with pytest.raises(TypeError):
         dspy.GEPA(metric=bad_metric, reflection_lm=reflection_lm, max_metric_calls=1)
+
+
+def any_metric(
+    gold: dspy.Example,
+    pred: dspy.Prediction,
+    trace: Any = None,
+    pred_name: str | None = None,
+    pred_trace: Any = None,
+) -> float:
+    """
+    For this test, we only care that the program runs, not the score.
+    """
+    return 0.0  # Just returns 0.0; doesn't access any prediction attributes.
+
+
+def test_gepa_compile_with_track_usage_no_tuple_error(caplog):
+    """
+    GEPA.compile should not log the tuple-usage error when track_usage=True, and it should complete without hanging.
+    Before the fix, compile would hang and/or repeatedly log "'tuple' object has no attribute 'set_lm_usage'".
+    """
+    student = dspy.Predict("question -> answer")
+    trainset = [dspy.Example(question="What is 2+2?", answer="4").with_inputs("question")]
+
+    task_lm = DummyLM([{"answer": "mock answer 1"}])
+    reflection_lm = DummyLM([{"new_instruction": "Something new."}])
+
+    compiled_container: dict[str, Any] = {}
+    exc_container: dict[str, BaseException] = {}
+
+    def run_compile():
+        try:
+            with dspy.context(lm=task_lm, track_usage=True):
+                optimizer = dspy.GEPA(metric=any_metric, reflection_lm=reflection_lm, max_metric_calls=3)
+                compiled_container["prog"] = optimizer.compile(student, trainset=trainset, valset=trainset)
+        except BaseException as e:
+            exc_container["e"] = e
+
+    t = threading.Thread(target=run_compile, daemon=True)
+    t.start()
+    t.join(timeout=1.0)
+
+    # Assert compile did not hang (pre-fix behavior would time out here)
+    assert not t.is_alive(), "GEPA.compile did not complete within timeout (likely pre-fix behavior)."
+
+    # Assert the tuple-usage error is no longer logged
+    assert "'tuple' object has no attribute 'set_lm_usage'" not in caplog.text
+
+    # If any exception occurred, fail explicitly
+    if "e" in exc_container:
+        pytest.fail(f"GEPA.compile raised unexpectedly: {exc_container['e']}")
+
+    # No timeout, no exception -> the compiled program must exist
+    if "prog" not in compiled_container:
+        pytest.fail("GEPA.compile did not return a program (likely pre-fix behavior).")
