Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 14 additions & 78 deletions src/coffea/processor/_dask.py
Original file line number Diff line number Diff line change
@@ -1,84 +1,20 @@
from dask.distributed.diagnostics.progressbar import ProgressBar
from distributed.client import futures_of
from distributed.core import clean_exception
from distributed.utils import LoopRunner
from rich.progress import Progress
from rich.traceback import Traceback
from tornado.ioloop import IOLoop

from coffea.util import rich_bar


class RichProgressBar(ProgressBar):
__loop: IOLoop | None = None

def __init__(
self,
keys,
scheduler=None,
interval="100ms",
complete=False,
progress_bar=None,
description="Processing",
unit="tasks",
):
super().__init__(keys, scheduler, interval, complete)
if progress_bar is not None:
if not isinstance(progress_bar, Progress):
raise ValueError(
"progress_bar must be a rich.progress.Progress instance"
)
self.pbar = progress_bar
else:
self.pbar = rich_bar()
self.pbar.start()

self.task = self.pbar.add_task(description, total=len(keys), unit=unit)
from __future__ import annotations

self._loop_runner = LoopRunner(loop=None)
self._loop_runner.run_sync(self.listen)
from typing import Any

@property
def loop(self) -> IOLoop | None:
loop = self.__loop
if loop is None:
# If the loop is not running when this is called, the LoopRunner.loop
# property will raise a DeprecationWarning
# However subsequent calls might occur - eg atexit, where a stopped
# loop is still acceptable - so we cache access to the loop.
self.__loop = loop = self._loop_runner.loop
return loop

def _draw_stop(self, remaining, all, status, exception=None, **kwargs):
del kwargs

if status == "error":
_, exception, _ = clean_exception(exception)

rtc = Traceback.from_exception(
type(exception),
exception,
exception.__traceback__,
)
self.pbar.console.print(rtc)

if not remaining:
self.pbar.update(self.task, total=all, completed=all)
self.pbar.stop()
from rich.console import Group
from rich.live import Live
from rich.progress import Progress

def _draw_bar(self, remaining, all, **kwargs):
del kwargs
self.pbar.update(self.task, total=all, completed=all - remaining, refresh=True)
from coffea.util import coffea_console, rich_bar

_processing_sentinel = object()
_final_merge_sentinel = object()

def progress(*futures, complete=True, **kwargs):
# fallback to normal dask progress bar if any special kwargs are given
if "multi" in kwargs or "group_by" in kwargs:
from distributed import progress as dask_progress

dask_progress(*futures, complete=complete, **kwargs)
else:
futures = futures_of(futures)
if not isinstance(futures, (set, list)):
futures = [futures]
RichProgressBar(futures, complete=complete, **kwargs)
# group of progress bars for dask executor
def pbar_group(datasets: list[str]) -> tuple[Live, dict[Any, Progress]]:
pbars = {_processing_sentinel: rich_bar()}
pbars.update({ds: rich_bar() for ds in datasets})
pbars[_final_merge_sentinel] = rich_bar()
return Live(Group(*pbars.values()), console=coffea_console), pbars
Loading
Loading