Skip to content

Commit 23c320d

Browse files
move async handling to parallelizer
1 parent 8ee8ade commit 23c320d

File tree

3 files changed

+164
-48
lines changed

3 files changed

+164
-48
lines changed

dspy/evaluate/evaluate.py

Lines changed: 16 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
if TYPE_CHECKING:
77
import pandas as pd
88

9-
import asyncio
109

1110
import tqdm
1211

@@ -161,7 +160,7 @@ async def acall(
161160
if callback_metadata:
162161
logger.debug(f"Evaluate.acall is called with callback metadata: {callback_metadata}")
163162
tqdm.tqdm._instances.clear()
164-
results = await self._execute_with_event_loop(program, metric, devset, num_threads)
163+
results = await self._execute_with_event_loop(program, metric, devset, num_threads, display_progress)
165164
return self._process_evaluate_result(devset, results, metric, display_table)
166165

167166
def _resolve_call_args(self, metric, devset, num_threads, display_progress, display_table):
@@ -220,41 +219,22 @@ async def _execute_with_event_loop(
220219
metric: Callable,
221220
devset: list["dspy.Example"],
222221
num_threads: int,
222+
disable_progress_bar: bool,
223223
):
224-
queue = asyncio.Queue()
225-
results = [None] * len(devset)
226-
for i, example in enumerate(devset):
227-
await queue.put((i, example))
228-
229-
for _ in range(num_threads):
230-
# Add a sentinel value to indicate that the worker should exit
231-
await queue.put((-1, None))
232-
233-
# Create tqdm progress bar
234-
pbar = tqdm.tqdm(total=len(devset), dynamic_ncols=True)
235-
236-
async def worker():
237-
while True:
238-
index, example = await queue.get()
239-
if index == -1:
240-
break
241-
prediction = await program.acall(**example.inputs())
242-
score = metric(example, prediction)
243-
results[index] = (prediction, score)
244-
245-
vals = [r[-1] for r in results if r is not None]
246-
nresults = sum(vals)
247-
ntotal = len(vals)
248-
pct = round(100 * nresults / ntotal, 1) if ntotal else 0
249-
pbar.set_description(f"Average Metric: {nresults:.2f} / {ntotal} ({pct}%)")
250-
pbar.update(1)
251-
queue.task_done()
252-
253-
workers = [asyncio.create_task(worker()) for _ in range(num_threads)]
254-
await asyncio.gather(*workers)
255-
pbar.close()
256-
257-
return results
224+
executor = ParallelExecutor(
225+
num_threads=num_threads,
226+
disable_progress_bar=disable_progress_bar,
227+
max_errors=(self.max_errors or dspy.settings.max_errors),
228+
provide_traceback=self.provide_traceback,
229+
compare_results=True,
230+
)
231+
232+
async def process_item(example):
233+
prediction = await program.acall(**example.inputs())
234+
score = metric(example, prediction)
235+
return prediction, score
236+
237+
return await executor.aexecute(process_item, devset)
258238

259239
def _construct_result_table(
260240
self, results: list[tuple["dspy.Example", "dspy.Example", Any]], metric_name: str

dspy/utils/parallelizer.py

Lines changed: 88 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import contextlib
23
import copy
34
import logging
@@ -42,29 +43,60 @@ def __init__(
4243
self.error_lock = threading.Lock()
4344
self.cancel_jobs = threading.Event()
4445

46+
self.error_lock_async = asyncio.Lock()
47+
self.cancel_jobs_async = asyncio.Event()
48+
4549
def execute(self, function, data):
4650
tqdm.tqdm._instances.clear()
47-
wrapped = self._wrap_function(function)
51+
wrapped = self._wrap_function(function, async_mode=False)
4852
return self._execute_parallel(wrapped, data)
4953

50-
def _wrap_function(self, user_function):
51-
def safe_func(item):
54+
async def aexecute(self, function, data):
55+
tqdm.tqdm._instances.clear()
56+
wrapped = self._wrap_function(function, async_mode=True)
57+
return await self._execute_parallel_async(wrapped, data)
58+
59+
def _handle_error(self, item, e):
60+
with self.error_lock:
61+
self.error_count += 1
62+
if self.error_count >= self.max_errors:
63+
self.cancel_jobs.set()
64+
if self.provide_traceback:
65+
logger.error(f"Error for {item}: {e}\n{traceback.format_exc()}")
66+
else:
67+
logger.error(f"Error for {item}: {e}. Set `provide_traceback=True` for traceback.")
68+
69+
async def _handle_error_async(self, item, e):
70+
async with self.error_lock_async:
71+
self.error_count += 1
72+
if self.error_count >= self.max_errors:
73+
self.cancel_jobs_async.set()
74+
if self.provide_traceback:
75+
logger.error(f"Error for {item}: {e}\n{traceback.format_exc()}")
76+
77+
def _wrap_function(self, user_function, async_mode=False):
78+
async def _async_safe_func(item):
79+
if self.cancel_jobs.is_set():
80+
return None
81+
try:
82+
return await user_function(item)
83+
except Exception as e:
84+
await self._handle_error_async(item, e)
85+
return None
86+
87+
def _sync_safe_func(item):
5288
if self.cancel_jobs.is_set():
5389
return None
5490
try:
5591
return user_function(item)
5692
except Exception as e:
57-
with self.error_lock:
58-
self.error_count += 1
59-
if self.error_count >= self.max_errors:
60-
self.cancel_jobs.set()
61-
if self.provide_traceback:
62-
logger.error(f"Error for {item}: {e}\n{traceback.format_exc()}")
63-
else:
64-
logger.error(f"Error for {item}: {e}. Set `provide_traceback=True` for traceback.")
93+
self._handle_error(item, e)
6594
return None
6695

67-
return safe_func
96+
if async_mode:
97+
return _async_safe_func
98+
else:
99+
return _sync_safe_func
68100

69101
def _execute_parallel(self, function, data):
70102
results = [None] * len(data)
@@ -204,6 +236,50 @@ def all_done():
204236

205237
return results
206238

239+
async def _execute_parallel_async(self, function, data):
240+
queue = asyncio.Queue()
241+
results = [None] * len(data)
242+
for i, example in enumerate(data):
243+
await queue.put((i, example))
244+
245+
for _ in range(self.num_threads):
246+
# Add a sentinel value to indicate that the worker should exit
247+
await queue.put((-1, None))
248+
249+
# Create tqdm progress bar
250+
pbar = tqdm.tqdm(total=len(data), dynamic_ncols=True)
251+
252+
async def worker():
253+
while True:
254+
if self.cancel_jobs_async.is_set():
255+
break
256+
index, example = await queue.get()
257+
if index == -1:
258+
break
259+
function_outputs = await function(example)
260+
results[index] = function_outputs
261+
262+
if self.compare_results:
263+
vals = [r[-1] for r in results if r is not None]
264+
self._update_progress(pbar, sum(vals), len(vals))
265+
else:
266+
self._update_progress(
267+
pbar,
268+
len([r for r in results if r is not None]),
269+
len(data),
270+
)
271+
272+
queue.task_done()
273+
274+
workers = [asyncio.create_task(worker()) for _ in range(self.num_threads)]
275+
await asyncio.gather(*workers)
276+
pbar.close()
277+
if self.cancel_jobs_async.is_set():
278+
logger.warning("Execution cancelled due to errors or interruption.")
279+
raise Exception("Execution cancelled due to errors or interruption.")
280+
281+
return results
282+
207283
def _update_progress(self, pbar, nresults, ntotal):
208284
if self.compare_results:
209285
pct = round(100 * nresults / ntotal, 1) if ntotal else 0

tests/utils/test_parallelizer.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,3 +59,63 @@ def task(item):
5959

6060
# Verify that the results exclude the failed task
6161
assert results == [1, 2, None, 4, 5]
62+
63+
64+
@pytest.mark.asyncio
65+
async def test_worker_threads_independence_async():
66+
async def task(item):
67+
# Simple async task with no shared state: just doubles the item
68+
return item * 2
69+
70+
data = [1, 2, 3, 4, 5]
71+
executor = ParallelExecutor(num_threads=3)
72+
results = await executor.aexecute(task, data)
73+
74+
assert results == [2, 4, 6, 8, 10]
75+
76+
77+
@pytest.mark.asyncio
78+
async def test_parallel_execution_speed_async():
79+
async def task(item):
80+
time.sleep(0.1) # Simulate a time-consuming task (NOTE(review): blocking sleep — asyncio workers will run these sequentially; asyncio.sleep would test true overlap)
81+
return item
82+
83+
data = [1, 2, 3, 4, 5]
84+
executor = ParallelExecutor(num_threads=5)
85+
86+
start_time = time.time()
87+
await executor.aexecute(task, data)
88+
end_time = time.time()
89+
90+
assert end_time - start_time < len(data)
91+
92+
93+
@pytest.mark.asyncio
94+
async def test_max_errors_handling_async():
95+
async def task(item):
96+
if item == 3:
97+
raise ValueError("Intentional error")
98+
return item
99+
100+
data = [1, 2, 3, 4, 5]
101+
executor = ParallelExecutor(num_threads=3, max_errors=1)
102+
103+
with pytest.raises(Exception, match="Execution cancelled due to errors or interruption."):
104+
await executor.aexecute(task, data)
105+
106+
107+
@pytest.mark.asyncio
108+
async def test_max_errors_not_met_async():
109+
async def task(item):
110+
if item == 3:
111+
raise ValueError("Intentional error")
112+
return item
113+
114+
data = [1, 2, 3, 4, 5]
115+
executor = ParallelExecutor(num_threads=3, max_errors=2)
116+
117+
# Ensure that the execution completes without crashing when max_errors is not met
118+
results = await executor.aexecute(task, data)
119+
120+
# Verify that the results exclude the failed task
121+
assert results == [1, 2, None, 4, 5]

0 commit comments

Comments
 (0)