From f97fb04406f337f92f46c43bdef556051becf622 Mon Sep 17 00:00:00 2001 From: Nick Date: Thu, 6 Jun 2024 23:28:15 +0300 Subject: [PATCH 1/3] feat: update pbar --- moai/engine/progressbar.py | 46 ++++++++++++++++++++++++++++++++++++++ moai/engine/runner.py | 4 +++- 2 files changed, 49 insertions(+), 1 deletion(-) create mode 100644 moai/engine/progressbar.py diff --git a/moai/engine/progressbar.py b/moai/engine/progressbar.py new file mode 100644 index 00000000..971ce1c0 --- /dev/null +++ b/moai/engine/progressbar.py @@ -0,0 +1,46 @@ +import rich.progress +from pytorch_lightning.callbacks import RichProgressBar +from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBarTheme + +__all__ = ["MoaiProgressBar"] + +# progress_bar = RichProgressBar( +# theme=RichProgressBarTheme( +# description="green_yellow", +# progress_bar="green1", +# progress_bar_finished="green1", +# progress_bar_pulse="#6206E0", +# batch_progress="green_yellow", +# time="grey82", +# processing_speed="grey82", +# metrics="grey82", +# metrics_text_delimiter="\n", +# metrics_format=".3e", +# ) + + +# NOTE: check https://github.com/Textualize/rich/discussions/482 +# NOTE: check https://github.com/facebookresearch/EGG/blob/a139946a73d45553360a7f897626d1ae20759f12/egg/core/callbacks.py#L335 +# NOTE: check https://github.com/Textualize/rich/discussions/921 +class MoaiProgressBar(RichProgressBar): + def __init__(self) -> None: + super().__init__( + theme=RichProgressBarTheme(metrics_text_delimiter="|"), + ) + + # return [ + # TextColumn("[progress.description]{task.description}"), + # CustomBarColumn( + # complete_style=self.theme.progress_bar, + # finished_style=self.theme.progress_bar_finished, + # pulse_style=self.theme.progress_bar_pulse, + # ), + # BatchesProcessedColumn(style=self.theme.batch_progress), + # CustomTimeColumn(style=self.theme.time), + # ProcessingSpeedColumn(style=self.theme.processing_speed), + # ] + def configure_columns(self, trainer: "pl.Trainer") -> list: + original = super().configure_columns(trainer) + moai_column = rich.progress.TextColumn(":moai:") + spinner_column = rich.progress.SpinnerColumn(spinner_name="dots5") + return [moai_column, spinner_column] + original diff --git a/moai/engine/runner.py b/moai/engine/runner.py index ed8f416a..f62d08b4 100644 --- a/moai/engine/runner.py +++ b/moai/engine/runner.py @@ -5,6 +5,7 @@ import pytorch_lightning as L from omegaconf.omegaconf import DictConfig +from moai.engine.progressbar import MoaiProgressBar from moai.engine.run_callback import RunCallback log = logging.getLogger(__name__) @@ -92,7 +93,8 @@ def __init__( [hyu.instantiate(logger) for logger in loggers.values()] if loggers else [] ) pytl_callbacks = [ - RunCallback() + RunCallback(), + MoaiProgressBar(), ] # TODO: only when moai model is used, should not be used for custom models pytl_callbacks.extend( [hyu.instantiate(c) for c in callbacks.values()] From 1f2f66a1ddf9ba9e5fb7382a7a651457b44d66d6 Mon Sep 17 00:00:00 2001 From: Nick Date: Fri, 7 Jun 2024 10:06:34 +0300 Subject: [PATCH 2/3] Update requirements.dev.txt --- requirements.dev.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements.dev.txt b/requirements.dev.txt index e6d310ae..dba23f7e 100644 --- a/requirements.dev.txt +++ b/requirements.dev.txt @@ -1,4 +1,5 @@ black==24.4.2 pre-commit==3.7.1 isort==5.13.2 -yamllint==1.35.1 \ No newline at end of file +yamllint==1.35.1 +pytest==8.2.0 \ No newline at end of file From de253010919b0b6fae60b69afa967069eeec5eaf Mon Sep 17 00:00:00 2001 From: Nick Date: Fri, 7 Jun 2024 11:19:07 +0300 Subject: [PATCH 3/3] cleanup: printing messages --- moai/data/datasets/generic/npz.py | 4 ++-- moai/parameters/initialization/schemes/zero_flow_params.py | 5 ++++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/moai/data/datasets/generic/npz.py b/moai/data/datasets/generic/npz.py index 3785c16b..37daac03 100644 --- a/moai/data/datasets/generic/npz.py +++ b/moai/data/datasets/generic/npz.py @@ -44,7 +44,7 @@ def __init__( filename: str = "", ): self.file = load_npz_file(filename) - log.info(f"Loaded an .npz file producing [{list(self.file.keys())}].") + log.info(f"Loaded an .npz file producing {list(self.file.keys())}.") def __len__(self) -> int: return len(self.file[toolz.first(self.file)]) @@ -63,7 +63,7 @@ def __init__( ): self.file = load_npz_file(filename) self.length = length - log.info(f"Loaded an .npz file producing [{list(self.file.keys())}].") + log.info(f"Loaded an .npz file producing {list(self.file.keys())}.") def __len__(self) -> int: return self.length diff --git a/moai/parameters/initialization/schemes/zero_flow_params.py b/moai/parameters/initialization/schemes/zero_flow_params.py index 45dcc69f..434f06c4 100644 --- a/moai/parameters/initialization/schemes/zero_flow_params.py +++ b/moai/parameters/initialization/schemes/zero_flow_params.py @@ -18,13 +18,16 @@ def __init__( self.keys = keys def __call__(self, module: torch.nn.Module) -> None: + zeroed_keys = [] for key in self.keys: try: m = get_parameter(module.named_flows, key) if m is not None: - log.info(f"Zeroing out parameter: [cyan italic]{key}[/].") with torch.no_grad(): # TODO: remove this and add in root apply call m.zero_() m.grad = None + zeroed_keys.append(key) except: break + all_zeroed_keys = ",".join(zeroed_keys) + log.info(f"Zeroing out parameters: [cyan italic]\[{all_zeroed_keys}][/].")