|
| 1 | +import asyncio |
1 | 2 | import contextlib
|
2 | 3 | import copy
|
3 | 4 | import logging
|
@@ -42,29 +43,60 @@ def __init__(
|
42 | 43 | self.error_lock = threading.Lock()
|
43 | 44 | self.cancel_jobs = threading.Event()
|
44 | 45 |
|
| 46 | + self.error_lock_async = asyncio.Lock() |
| 47 | + self.cancel_jobs_async = asyncio.Event() |
| 48 | + |
45 | 49 | def execute(self, function, data):
|
46 | 50 | tqdm.tqdm._instances.clear()
|
47 |
| - wrapped = self._wrap_function(function) |
| 51 | + wrapped = self._wrap_function(function, async_mode=False) |
48 | 52 | return self._execute_parallel(wrapped, data)
|
49 | 53 |
|
50 |
| - def _wrap_function(self, user_function): |
51 |
| - def safe_func(item): |
| 54 | + async def aexecute(self, function, data): |
| 55 | + tqdm.tqdm._instances.clear() |
| 56 | + wrapped = self._wrap_function(function, async_mode=True) |
| 57 | + return await self._execute_parallel_async(wrapped, data) |
| 58 | + |
| 59 | + def _handle_error(self, item, e): |
| 60 | + with self.error_lock: |
| 61 | + self.error_count += 1 |
| 62 | + if self.error_count >= self.max_errors: |
| 63 | + self.cancel_jobs.set() |
| 64 | + if self.provide_traceback: |
| 65 | + logger.error(f"Error for {item}: {e}\n{traceback.format_exc()}") |
| 66 | + else: |
| 67 | + logger.error(f"Error for {item}: {e}. Set `provide_traceback=True` for traceback.") |
| 68 | + |
| 69 | + async def _handle_error_async(self, item, e): |
| 70 | + async with self.error_lock_async: |
| 71 | + self.error_count += 1 |
| 72 | + if self.error_count >= self.max_errors: |
| 73 | + self.cancel_jobs_async.set() |
| 74 | + if self.provide_traceback: |
| 75 | + logger.error(f"Error for {item}: {e}\n{traceback.format_exc()}") |
| 76 | + |
| 77 | + def _wrap_function(self, user_function, async_mode=False): |
| 78 | + async def _async_safe_func(item): |
| 79 | + if self.cancel_jobs.is_set(): |
| 80 | + return None |
| 81 | + try: |
| 82 | + return await user_function(item) |
| 83 | + except Exception as e: |
| 84 | + await self._handle_error_async(item, e) |
| 85 | + return None |
| 86 | + |
| 87 | + def _sync_safe_func(item): |
52 | 88 | if self.cancel_jobs.is_set():
|
53 | 89 | return None
|
54 | 90 | try:
|
55 | 91 | return user_function(item)
|
56 | 92 | except Exception as e:
|
57 |
| - with self.error_lock: |
58 |
| - self.error_count += 1 |
59 |
| - if self.error_count >= self.max_errors: |
60 |
| - self.cancel_jobs.set() |
61 |
| - if self.provide_traceback: |
62 |
| - logger.error(f"Error for {item}: {e}\n{traceback.format_exc()}") |
63 |
| - else: |
64 |
| - logger.error(f"Error for {item}: {e}. Set `provide_traceback=True` for traceback.") |
| 93 | + self._handle_error(item, e) |
65 | 94 | return None
|
66 | 95 |
|
67 |
| - return safe_func |
| 96 | + if async_mode: |
| 97 | + return _async_safe_func |
| 98 | + else: |
| 99 | + return _sync_safe_func |
68 | 100 |
|
69 | 101 | def _execute_parallel(self, function, data):
|
70 | 102 | results = [None] * len(data)
|
@@ -204,6 +236,50 @@ def all_done():
|
204 | 236 |
|
205 | 237 | return results
|
206 | 238 |
|
| 239 | + async def _execute_parallel_async(self, function, data): |
| 240 | + queue = asyncio.Queue() |
| 241 | + results = [None] * len(data) |
| 242 | + for i, example in enumerate(data): |
| 243 | + await queue.put((i, example)) |
| 244 | + |
| 245 | + for _ in range(self.num_threads): |
| 246 | + # Add a sentinel value to indicate that the worker should exit |
| 247 | + await queue.put((-1, None)) |
| 248 | + |
| 249 | + # Create tqdm progress bar |
| 250 | + pbar = tqdm.tqdm(total=len(data), dynamic_ncols=True) |
| 251 | + |
| 252 | + async def worker(): |
| 253 | + while True: |
| 254 | + if self.cancel_jobs_async.is_set(): |
| 255 | + break |
| 256 | + index, example = await queue.get() |
| 257 | + if index == -1: |
| 258 | + break |
| 259 | + function_outputs = await function(example) |
| 260 | + results[index] = function_outputs |
| 261 | + |
| 262 | + if self.compare_results: |
| 263 | + vals = [r[-1] for r in results if r is not None] |
| 264 | + self._update_progress(pbar, sum(vals), len(vals)) |
| 265 | + else: |
| 266 | + self._update_progress( |
| 267 | + pbar, |
| 268 | + len([r for r in results if r is not None]), |
| 269 | + len(data), |
| 270 | + ) |
| 271 | + |
| 272 | + queue.task_done() |
| 273 | + |
| 274 | + workers = [asyncio.create_task(worker()) for _ in range(self.num_threads)] |
| 275 | + await asyncio.gather(*workers) |
| 276 | + pbar.close() |
| 277 | + if self.cancel_jobs_async.is_set(): |
| 278 | + logger.warning("Execution cancelled due to errors or interruption.") |
| 279 | + raise Exception("Execution cancelled due to errors or interruption.") |
| 280 | + |
| 281 | + return results |
| 282 | + |
207 | 283 | def _update_progress(self, pbar, nresults, ntotal):
|
208 | 284 | if self.compare_results:
|
209 | 285 | pct = round(100 * nresults / ntotal, 1) if ntotal else 0
|
|
0 commit comments