diff --git a/cubed/extensions/rich.py b/cubed/extensions/rich.py new file mode 100644 index 00000000..f78c76be --- /dev/null +++ b/cubed/extensions/rich.py @@ -0,0 +1,112 @@ +import logging +import sys +from contextlib import contextmanager + +from rich.console import RenderableType +from rich.progress import ( + BarColumn, + MofNCompleteColumn, + Progress, + SpinnerColumn, + Task, + TaskProgressColumn, + TextColumn, + TimeElapsedColumn, +) +from rich.text import Text + +from cubed.runtime.pipeline import visit_nodes +from cubed.runtime.types import Callback + + +class RichProgressBar(Callback): + """Rich progress bar for a computation.""" + + def on_compute_start(self, event): + # Set the pulse_style to the background colour to disable pulsing, + # since Rich will pulse all non-started bars. + logger_aware_progress = LoggerAwareProgress( + SpinnerWhenRunningColumn(), + TextColumn("[progress.description]{task.description}"), + LeftJustifiedMofNCompleteColumn(), + BarColumn(bar_width=None, pulse_style="bar.back"), + TaskProgressColumn( + text_format="[progress.percentage]{task.percentage:>3.1f}%" + ), + TimeElapsedColumn(), + logger=logging.getLogger(), + ) + progress = logger_aware_progress.__enter__() + + progress_tasks = {} + for name, node in visit_nodes(event.dag, event.resume): + num_tasks = node["primitive_op"].num_tasks + progress_task = progress.add_task(f"{name}", start=False, total=num_tasks) + progress_tasks[name] = progress_task + + self.logger_aware_progress = logger_aware_progress + self.progress = progress + self.progress_tasks = progress_tasks + + def on_compute_end(self, event): + self.logger_aware_progress.__exit__(None, None, None) + + def on_operation_start(self, event): + self.progress.start_task(self.progress_tasks[event.name]) + + def on_task_end(self, event): + self.progress.update(self.progress_tasks[event.name], advance=event.num_tasks) + + +class SpinnerWhenRunningColumn(SpinnerColumn): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # Override so spinner is not shown when bar has not yet started + def render(self, task: "Task") -> RenderableType: + text = ( + self.finished_text + if not task.started or task.finished + else self.spinner.render(task.get_time()) + ) + return text + + +class LeftJustifiedMofNCompleteColumn(MofNCompleteColumn): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def render(self, task: "Task") -> Text: + """Show completed/total.""" + completed = int(task.completed) + total = int(task.total) if task.total is not None else "?" + total_width = len(str(total)) + return Text( + f"{completed}{self.separator}{total}".ljust(total_width + 1 + total_width), + style="progress.download", + ) + + +# Based on CustomProgress from https://github.com/Textualize/rich/discussions/1578 +@contextmanager +def LoggerAwareProgress(*args, **kwargs): + """Wrapper around rich.progress.Progress to manage logging output to stderr.""" + try: + __logger = kwargs.pop("logger", None) + streamhandlers = [ + x for x in __logger.root.handlers if type(x) is logging.StreamHandler + ] + + with Progress(*args, **kwargs) as progress: + for handler in streamhandlers: + __prior_stderr = handler.stream + handler.setStream(sys.stderr) + + yield progress + + finally: + streamhandlers = [ + x for x in __logger.root.handlers if type(x) is logging.StreamHandler + ] + for handler in streamhandlers: + handler.setStream(__prior_stderr) diff --git a/cubed/tests/test_executor_features.py b/cubed/tests/test_executor_features.py index 9117278a..747ed5d7 100644 --- a/cubed/tests/test_executor_features.py +++ b/cubed/tests/test_executor_features.py @@ -9,6 +9,7 @@ import cubed.array_api as xp import cubed.random from cubed.extensions.history import HistoryCallback +from cubed.extensions.rich import RichProgressBar from cubed.extensions.timeline import TimelineVisualizationCallback from cubed.extensions.tqdm import TqdmProgressBar from cubed.primitive.blockwise import apply_blockwise @@ -97,6 +98,19 @@ def test_callbacks(spec, executor): assert task_counter.value == num_created_arrays + 4 +def test_rich_progress_bar(spec, executor): + # test indirectly by checking it doesn't cause a failure + progress = RichProgressBar() + + a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec) + b = xp.asarray([[1, 1, 1], [1, 1, 1], [1, 1, 1]], chunks=(2, 2), spec=spec) + c = xp.add(a, b) + assert_array_equal( + c.compute(executor=executor, callbacks=[progress]), + np.array([[2, 3, 4], [5, 6, 7], [8, 9, 10]]), + ) + + @pytest.mark.cloud def test_callbacks_modal(spec, modal_executor): task_counter = TaskCounter(check_timestamps=False) diff --git a/pyproject.toml b/pyproject.toml index 137094fe..f341d83e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ diagnostics = [ "pydot", "pandas", "matplotlib", + "rich", "seaborn", ] beam = ["apache-beam", "gcsfs"]