From f32c8daa823cf3032f2f35feab360dcd5c17be9d Mon Sep 17 00:00:00 2001 From: nbrosse <31697743+nbrosse@users.noreply.github.com> Date: Fri, 22 Dec 2023 14:31:06 +0100 Subject: [PATCH] Wandb (#6) * wandb disable project id * wandb args * dict and lmdb * wandb watch * wandb pyproject.toml installation --------- Co-authored-by: nicolas.brosse --- .github/workflows/publish.yml | 8 +-- pyproject.toml | 6 ++ requirements.txt | 11 +++- setup.py | 18 ++++-- unicore/data/dictionary.py | 10 +--- unicore/data/lmdb_dataset.py | 6 +- unicore/data/tokenize_dataset.py | 2 +- unicore/logging/progress_bar.py | 99 +++++++++++++++++--------------- unicore/options.py | 15 ++--- unicore_cli/train.py | 51 +++++++++------- 10 files changed, 128 insertions(+), 98 deletions(-) create mode 100644 pyproject.toml diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 014daba..7e42486 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -62,8 +62,8 @@ jobs: - name: Install PyTorch ${{ matrix.torch-version }}+cu${{ matrix.cuda-version }} if: ${{ matrix.cuda-version != 'cpu' }} run: | - pip --no-cache-dir install numpy pyyaml scipy ipython mkl mkl-include ninja cython typing pandas typing-extensions dataclasses - pip --no-cache-dir install sympy networkx jinja2 cmake lit + pip --no-cache-dir install numpy pyyaml mkl mkl-include ninja cython + pip --no-cache-dir install jinja2 cmake lit pip --no-cache-dir install torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/cu${{ matrix.cuda-version }} python --version python -c "import torch; print('PyTorch:', torch.__version__)" @@ -76,7 +76,7 @@ jobs: if: ${{ matrix.cuda-version == 'cpu' }} run: | python --version - pip --no-cache-dir install numpy pyyaml scipy ipython mkl mkl-include ninja cython typing pandas typing-extensions dataclasses + pip --no-cache-dir install numpy pyyaml mkl mkl-include ninja cython pip --no-cache-dir install torch==${{ matrix.torch-version }} python -c "import torch; print('PyTorch:', torch.__version__)" shell: @@ -86,7 +86,7 @@ jobs: if: ${{ matrix.cuda-version != 'cpu' }} run: | pip install wheel - python setup.py bdist_wheel --dist-dir=dist + python setup.py --cross-compile bdist_wheel --dist-dir=dist tmpname=cu${{ matrix.cuda-version }}torch${{ matrix.torch-version }} wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2") ls dist/*whl |xargs -I {} mv {} ${wheel_name} diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..bb30b2d --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,6 @@ +[build-system] +build-backend = 'setuptools.build_meta' +requires = [ + 'setuptools >= 58.0.0', + 'torch', +] \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 8ce7832..9cc2fbf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,13 @@ -iopath lmdb -ml_collections numpy -scipy tensorboardX tqdm tokenizers +wandb +mkl +mkl-include +ninja +cython +jinja2 +cmake +lit \ No newline at end of file diff --git a/setup.py b/setup.py index 1219e2c..93553da 100644 --- a/setup.py +++ b/setup.py @@ -4,6 +4,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+
 import torch
 from torch.utils import cpp_extension
 from torch.utils.cpp_extension import CUDAExtension, BuildExtension
@@ -15,13 +16,20 @@
 from setuptools import find_packages, setup
 
 DISABLE_CUDA_EXTENSION = False
+CROSS_COMPILE = False
+CUDA_AVAILABLE = torch.cuda.is_available()
 filtered_args = []
 for i, arg in enumerate(sys.argv):
     if arg == '--disable-cuda-ext':
         DISABLE_CUDA_EXTENSION = True
+    elif arg == '--cross-compile':
+        CROSS_COMPILE = True
+    else:
+        filtered_args.append(arg)
 sys.argv = filtered_args
+# Without --cross-compile, skip the CUDA extension when no GPU is visible.
+if not CROSS_COMPILE and not CUDA_AVAILABLE:
+    DISABLE_CUDA_EXTENSION = True
 
 
 if sys.version_info < (3, 7):
@@ -114,7 +122,7 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
 
     if torch.utils.cpp_extension.CUDA_HOME is None:
         raise RuntimeError("Nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
-    # check_cuda_torch_binary_vs_bare_metal(torch.utils.cpp_extension.CUDA_HOME)
+    check_cuda_torch_binary_vs_bare_metal(torch.utils.cpp_extension.CUDA_HOME)
 
     generator_flag = []
     torch_dir = torch.__path__[0]
@@ -229,15 +237,13 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
         "Topic :: Scientific/Engineering :: Artificial Intelligence",
     ],
     setup_requires=[
-        "setuptools>=18.0",
+        "setuptools>=58.0.0",
    ],
    install_requires=[
        'numpy; python_version>="3.7"',
        "lmdb",
        "torch>=2.0.0",
        "tqdm",
-       "ml_collections",
-       "scipy",
        "tensorboardX",
        "tokenizers",
        "wandb",
diff --git a/unicore/data/dictionary.py b/unicore/data/dictionary.py
index db2b2c8..c92cdb5 100644
--- a/unicore/data/dictionary.py
+++ b/unicore/data/dictionary.py
@@ -17,11 +17,10 @@ def __init__(
         *,  # begin keyword-only arguments
         bos="[CLS]",
         pad="[PAD]",
-        eos="[SEP]",
         unk="[UNK]",
         extra_special_symbols=None,
     ):
-        self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos
+        self.bos_word, self.unk_word, self.pad_word = bos, unk, pad
         self.symbols = []
         self.count = []
         self.indices = {}
@@ -29,7 +28,6 @@ def __init__(
         self.specials.add(bos)
         self.specials.add(unk)
         self.specials.add(pad)
-        self.specials.add(eos)
 
     def __eq__(self, other):
         return self.indices == other.indices
@@ -82,10 +80,6 @@ def pad(self):
         """Helper to get index of pad symbol"""
         return self.index(self.pad_word)
 
-    def eos(self):
-        """Helper to get index of end-of-sentence symbol"""
-        return self.index(self.eos_word)
-
     def unk(self):
         """Helper to get index of unk symbol"""
         return self.index(self.unk_word)
@@ -128,7 +122,7 @@ def add_from_file(self, f):
             try:
                 splits = line.rstrip().rsplit(" ", 1)
                 line = splits[0]
-                field = splits[1] if len(splits) > 1 else str(len(lines) - line_idx)
+                field = splits[1] if len(splits) > 1 else "1"
                 if field == "#overwrite":
                     overwrite = True
                     line, field = line.rsplit(" ", 1)
diff --git a/unicore/data/lmdb_dataset.py b/unicore/data/lmdb_dataset.py
index 70741d4..34637f6 100644
--- a/unicore/data/lmdb_dataset.py
+++ b/unicore/data/lmdb_dataset.py
@@ -9,12 +9,16 @@
 import numpy as np
 import collections
 from functools import lru_cache
+
+from unicore.data import BaseWrapperDataset
+
 from . 
import data_utils import logging logger = logging.getLogger(__name__) -class LMDBDataset: +class LMDBDataset(BaseWrapperDataset): def __init__(self, db_path): + super(LMDBDataset, self).__init__(dataset=None) self.db_path = db_path assert os.path.isfile(self.db_path), "{} not found".format( self.db_path diff --git a/unicore/data/tokenize_dataset.py b/unicore/data/tokenize_dataset.py index 46887f8..0176293 100644 --- a/unicore/data/tokenize_dataset.py +++ b/unicore/data/tokenize_dataset.py @@ -24,5 +24,5 @@ def __init__( @lru_cache(maxsize=16) def __getitem__(self, index: int): raw_data = self.dataset[index] - assert len(raw_data) < self.max_seq_len and len(raw_data) > 0 + assert len(raw_data) <= self.max_seq_len and len(raw_data) > 0 return torch.from_numpy(self.dictionary.vec_index(raw_data)).long() \ No newline at end of file diff --git a/unicore/logging/progress_bar.py b/unicore/logging/progress_bar.py index ee05c77..382764d 100644 --- a/unicore/logging/progress_bar.py +++ b/unicore/logging/progress_bar.py @@ -16,8 +16,8 @@ from collections import OrderedDict from contextlib import contextmanager from numbers import Number -from pathlib import Path from typing import Optional +import wandb import torch @@ -33,11 +33,10 @@ def progress_bar( log_interval: int = 100, epoch: Optional[int] = None, prefix: Optional[str] = None, - tensorboard_logdir: Optional[str] = None, - wandb_logdir: Optional[str] = None, + save_dir: Optional[str] = "./save/", + tensorboard: bool = False, wandb_project: Optional[str] = None, default_log_format: str = "tqdm", - args=None, ): if log_format is None: log_format = default_log_format @@ -55,10 +54,13 @@ def progress_bar( else: raise ValueError("Unknown log format: {}".format(log_format)) - if tensorboard_logdir: - bar = TensorboardProgressBarWrapper( - bar, tensorboard_logdir, wandb_logdir, wandb_project, args - ) + if tensorboard: + tensorboard_logdir = os.path.join(save_dir, "tsb") + os.makedirs(tensorboard_logdir, exist_ok=True) + bar = TensorboardProgressBarWrapper(bar, tensorboard_logdir) + + if wandb_project: + bar = WandBProgressBarWrapper(bar) return bar @@ -279,23 +281,14 @@ def print(self, stats, tag=None, step=None): except ImportError: SummaryWriter = None -try: - _wandb_inited = False - import wandb - - wandb_available = True -except ImportError: - wandb_available = False - - def _close_writers(): for w in _tensorboard_writers.values(): w.close() - if _wandb_inited: - try: - wandb.finish() - except: - pass + # if _wandb_inited: + # try: + # wandb.finish() + # except: + # pass atexit.register(_close_writers) @@ -304,7 +297,7 @@ def _close_writers(): class TensorboardProgressBarWrapper(BaseProgressBar): """Log to tensorboard.""" - def __init__(self, wrapped_bar, tensorboard_logdir, wandb_logdir, wandb_project, args): + def __init__(self, wrapped_bar, tensorboard_logdir): self.wrapped_bar = wrapped_bar self.tensorboard_logdir = tensorboard_logdir @@ -312,26 +305,6 @@ def __init__(self, wrapped_bar, tensorboard_logdir, wandb_logdir, wandb_project, logger.warning( "tensorboard not found, please install with: pip install tensorboard" ) - global _wandb_inited - if not _wandb_inited and wandb_project and wandb_available: - wandb_name = args.wandb_name or wandb.util.generate_id() - if "/" in wandb_project: - entity, project = wandb_project.split("/") - else: - entity, project = None, wandb_project - wandb.init( - project=project, - entity=entity, - name=wandb_name, - dir=wandb_logdir, - config=vars(args), - # id=wandb_name, - resume="allow", - ) - 
wandb.define_metric("custom_step") - wandb.define_metric("train_*", step_metric="custom_step") - wandb.define_metric("valid_*", step_metric="custom_step") - _wandb_inited = True def _writer(self, key): if SummaryWriter is None: @@ -377,7 +350,41 @@ def _log_to_tensorboard(self, stats, tag=None, step=None): val = None if val: writer.add_scalar(key, val, step) - if _wandb_inited: - # wandb.log({"{}_{}".format(tag, key): val}, step=step) - wandb.log({"{}_{}".format(tag, key): val, "custom_step": step}) writer.flush() + + +class WandBProgressBarWrapper(BaseProgressBar): + """Log to Weights & Biases.""" + + def __init__(self, wrapped_bar: BaseProgressBar): + # global _wandb_inited + # _wandb_inited = True + self.wrapped_bar = wrapped_bar + + def __iter__(self): + return iter(self.wrapped_bar) + + def log(self, stats, tag=None, step=None): + self._log_to_wandb(stats, tag, step) + self.wrapped_bar.log(stats, tag=tag, step=step) + + def print(self, stats, tag=None, step=None): + self._log_to_wandb(stats, tag, step) + self.wrapped_bar.print(stats, tag=tag, step=step) + + def update_config(self, config): + """Log latest configuration.""" + wandb.config.update(config) + self.wrapped_bar.update_config(config) + + def _log_to_wandb(self, stats, tag=None, step=None): + if step is None: + step = stats["num_updates"] + + prefix = "" if tag is None else tag + "/" + + for key in stats.keys() - {"num_updates"}: + if isinstance(stats[key], AverageMeter): + wandb.log({prefix + key: stats[key].val}, step=step) + elif isinstance(stats[key], Number): + wandb.log({prefix + key: stats[key]}, step=step) diff --git a/unicore/options.py b/unicore/options.py index a355eca..72d8304 100644 --- a/unicore/options.py +++ b/unicore/options.py @@ -171,15 +171,12 @@ def get_parser(desc, default_task="test"): help='log progress every N batches (when progress bar is disabled)') parser.add_argument('--log-format', default=None, help='log format to use', choices=['json', 'none', 'simple', 'tqdm']) - parser.add_argument('--tensorboard-logdir', metavar='DIR', default='', - help='path to save logs for tensorboard, should match --logdir ' - 'of running tensorboard (default: no tensorboard logging)') - parser.add_argument('--wandb-logdir', metavar='DIR', default='', - help='path to save logs for wandb') - parser.add_argument('--wandb-project', metavar='DIR', default='', - help='name of wandb project, empty for no wandb logging, for wandb login, use env WANDB_API_KEY. 
You can also use team_name/project_name for project name.')
-    parser.add_argument('--wandb-name', metavar='DIR', default='',
-                        help='wandb run/id name, empty for no wandb logging, for wandb login, use env WANDB_API_KEY')
+    parser.add_argument('--tensorboard', action='store_true', help='activate tensorboard logging')
+    parser.add_argument('--wandb-project', default=None,
+                        help='name of the wandb project; leave unset to disable wandb logging (for wandb login, set the WANDB_API_KEY env variable)')
+    parser.add_argument('--wandb-run-name', default=None, help='wandb run name')
+    parser.add_argument('--wandb-run-id', default=None, help='wandb run id')
+    parser.add_argument("--wandb-watch", action="store_true", help="activate wandb watch to log weights and gradients")
     parser.add_argument('--seed', default=1, type=int, metavar='N',
                         help='pseudo random number generator seed')
     parser.add_argument('--cpu', action='store_true', help='use CPU instead of CUDA')
diff --git a/unicore_cli/train.py b/unicore_cli/train.py
index 7e54a8b..f15313c 100644
--- a/unicore_cli/train.py
+++ b/unicore_cli/train.py
@@ -17,6 +17,8 @@
 import numpy as np
 import torch
 
+import wandb
+
 from unicore import (
     checkpoint_utils,
     options,
@@ -72,6 +74,27 @@ def main(args) -> None:
     model = task.build_model(args)
     loss = task.build_loss(args)
 
+    # Initialize wandb before training when a project is configured.
+    if args.wandb_project is not None:
+        logger.info("Wandb init")
+        wandb_logdir = os.path.join(args.save_dir, "wandb")
+        os.makedirs(wandb_logdir, exist_ok=True)
+        wandb.init(
+            project=args.wandb_project,
+            name=args.wandb_run_name,
+            id=args.wandb_run_id,
+            resume="allow",
+            config=vars(args),
+            dir=wandb_logdir,
+        )
+        if args.wandb_watch:
+            logger.info("Wandb watch")
+            wandb.watch(
+                models=model,
+                log="all",
+                log_freq=args.log_interval,
+            )
+
     # Load valid dataset (we load training data below, based on the latest checkpoint)
     for valid_sub_split in args.valid_subset.split(","):
         task.load_dataset(valid_sub_split, combine=False, epoch=1)
@@ -194,17 +217,10 @@ def train(
         log_format=args.log_format,
         log_interval=args.log_interval,
         epoch=epoch_itr.epoch,
-        tensorboard_logdir=(
-            args.tensorboard_logdir if distributed_utils.is_master(args) else None
-        ),
-        wandb_logdir=(
-            args.wandb_logdir if distributed_utils.is_master(args) else None
-        ),
-        wandb_project=(
-            args.wandb_project if distributed_utils.is_master(args) else None
-        ),
+        save_dir=args.save_dir,
+        tensorboard=args.tensorboard if distributed_utils.is_master(args) else False,
+        wandb_project=args.wandb_project if distributed_utils.is_master(args) else None,
         default_log_format=("tqdm" if not args.no_progress_bar else "simple"),
-        args=args,
     )
 
     trainer.begin_epoch(epoch_itr.epoch)
@@ -366,16 +382,11 @@ def validate(
             log_interval=args.log_interval,
             epoch=epoch_itr.epoch,
             prefix=f"valid on '{subset}' subset",
-            tensorboard_logdir=(
-                args.tensorboard_logdir
-                if distributed_utils.is_master(args)
-                else None
-            ),
-            wandb_logdir=(
-                args.wandb_logdir
-                if distributed_utils.is_master(args)
-                else None
-            ),
+            save_dir=args.save_dir,
+            tensorboard=args.tensorboard if distributed_utils.is_master(
+                args) else False,
+            wandb_project=args.wandb_project if distributed_utils.is_master(
+                args) else None,
             default_log_format=("tqdm" if not args.no_progress_bar else "simple"),
         )
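
Note on the setup.py changes above: --disable-cuda-ext still forces a CPU-only build, the new --cross-compile flag builds the CUDA extension on a machine without a visible GPU (publish.yml passes it when building wheels), and when neither flag is given the CUDA extension is now skipped automatically whenever torch.cuda.is_available() is False. Illustrative invocations (the install example is hypothetical; the wheel example mirrors publish.yml):

    python setup.py install --disable-cuda-ext                     # CPU-only build, no fused CUDA kernels
    python setup.py --cross-compile bdist_wheel --dist-dir=dist    # CUDA wheel on a GPU-less CI runner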
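
And a minimal sketch of the logging path this patch wires up, assuming the patch is applied and wandb is installed. The "demo-project" name, the toy range(10) iterator, and the stats values are illustrative assumptions; unicore_cli/train.py normally performs the wandb.init() call itself before training, as shown in the diff above.

    import wandb
    from unicore.logging.progress_bar import progress_bar

    # Stands in for the wandb.init() block added to unicore_cli/train.py.
    wandb.init(project="demo-project", resume="allow", dir="./save/wandb")

    # progress_bar() now chains wrappers: tqdm bar -> TensorboardProgressBarWrapper
    # (only if tensorboard=True) -> WandBProgressBarWrapper (only if wandb_project is set).
    bar = progress_bar(
        range(10),                    # stands in for the epoch iterator
        log_format="tqdm",
        save_dir="./save/",
        tensorboard=False,
        wandb_project="demo-project",
    )

    for i, _ in enumerate(bar):
        # WandBProgressBarWrapper._log_to_wandb() prefixes keys with "<tag>/"
        # and falls back to stats["num_updates"] when step is None.
        bar.log({"loss": 1.0 / (i + 1), "num_updates": i}, tag="train", step=i)

    # The patch no longer calls wandb.finish() at exit (_close_writers is
    # commented out), so finish the run explicitly.
    wandb.finish()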