Wandb (#6)
* wandb disable project id

* wandb args

* dict and lmdb

* wandb watch

* wandb pyproject.toml installation

---------

Co-authored-by: nicolas.brosse <[email protected]>
nbrosse and nicolas.brosse authored Dec 22, 2023
1 parent b6cd9f3 commit f32c8da
Showing 10 changed files with 128 additions and 98 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/publish.yml
@@ -62,8 +62,8 @@ jobs:
- name: Install PyTorch ${{ matrix.torch-version }}+cu${{ matrix.cuda-version }}
if: ${{ matrix.cuda-version != 'cpu' }}
run: |
pip --no-cache-dir install numpy pyyaml scipy ipython mkl mkl-include ninja cython typing pandas typing-extensions dataclasses
pip --no-cache-dir install sympy networkx jinja2 cmake lit
pip --no-cache-dir install numpy pyyaml mkl mkl-include ninja cython
pip --no-cache-dir install jinja2 cmake lit
pip --no-cache-dir install torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/cu${{ matrix.cuda-version }}
python --version
python -c "import torch; print('PyTorch:', torch.__version__)"
@@ -76,7 +76,7 @@ jobs:
if: ${{ matrix.cuda-version == 'cpu' }}
run: |
python --version
pip --no-cache-dir install numpy pyyaml scipy ipython mkl mkl-include ninja cython typing pandas typing-extensions dataclasses
pip --no-cache-dir install numpy pyyaml mkl mkl-include ninja cython
pip --no-cache-dir install torch==${{ matrix.torch-version }}
python -c "import torch; print('PyTorch:', torch.__version__)"
shell:
@@ -86,7 +86,7 @@ jobs:
if: ${{ matrix.cuda-version != 'cpu' }}
run: |
pip install wheel
python setup.py bdist_wheel --dist-dir=dist
python setup.py --cross-compile bdist_wheel --dist-dir=dist
tmpname=cu${{ matrix.cuda-version }}torch${{ matrix.torch-version }}
wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2")
ls dist/*whl |xargs -I {} mv {} ${wheel_name}
6 changes: 6 additions & 0 deletions pyproject.toml
@@ -0,0 +1,6 @@
[build-system]
build-backend = 'setuptools.build_meta'
requires = [
'setuptools >= 58.0.0',
'torch',
]
11 changes: 8 additions & 3 deletions requirements.txt
@@ -1,8 +1,13 @@
iopath
lmdb
ml_collections
numpy
scipy
tensorboardX
tqdm
tokenizers
wandb
mkl
mkl-include
ninja
cython
jinja2
cmake
lit
18 changes: 12 additions & 6 deletions setup.py
@@ -4,6 +4,7 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch.utils import cpp_extension
from torch.utils.cpp_extension import CUDAExtension, BuildExtension
@@ -15,13 +16,20 @@
from setuptools import find_packages, setup

DISABLE_CUDA_EXTENSION = False
CROSS_COMPILE = False
CUDA_AVAILABLE = torch.cuda.is_available()
filtered_args = []
for i, arg in enumerate(sys.argv):
if arg == '--disable-cuda-ext':
DISABLE_CUDA_EXTENSION = True
continue
filtered_args.append(arg)
elif arg == '--cross-compile':
CROSS_COMPILE = True
else:
filtered_args.append(arg)
sys.argv = filtered_args
#
if not CROSS_COMPILE and not CUDA_AVAILABLE:
DISABLE_CUDA_EXTENSION = True


if sys.version_info < (3, 7):
@@ -114,7 +122,7 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
if torch.utils.cpp_extension.CUDA_HOME is None:
raise RuntimeError("Nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")

# check_cuda_torch_binary_vs_bare_metal(torch.utils.cpp_extension.CUDA_HOME)
check_cuda_torch_binary_vs_bare_metal(torch.utils.cpp_extension.CUDA_HOME)

generator_flag = []
torch_dir = torch.__path__[0]
@@ -229,15 +237,13 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
"Topic :: Scientific/Engineering :: Artificial Intelligence",
],
setup_requires=[
"setuptools>=18.0",
"setuptools>=58.0.0",
],
install_requires=[
'numpy; python_version>="3.7"',
"lmdb",
"torch>=2.0.0",
"tqdm",
"ml_collections",
"scipy",
"tensorboardX",
"tokenizers",
"wandb",
10 changes: 2 additions & 8 deletions unicore/data/dictionary.py
@@ -17,19 +17,17 @@ def __init__(
*, # begin keyword-only arguments
bos="[CLS]",
pad="[PAD]",
eos="[SEP]",
unk="[UNK]",
extra_special_symbols=None,
):
self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos
self.bos_word, self.unk_word, self.pad_word = bos, unk, pad
self.symbols = []
self.count = []
self.indices = {}
self.specials = set()
self.specials.add(bos)
self.specials.add(unk)
self.specials.add(pad)
self.specials.add(eos)

def __eq__(self, other):
return self.indices == other.indices
@@ -82,10 +80,6 @@ def pad(self):
"""Helper to get index of pad symbol"""
return self.index(self.pad_word)

def eos(self):
"""Helper to get index of end-of-sentence symbol"""
return self.index(self.eos_word)

def unk(self):
"""Helper to get index of unk symbol"""
return self.index(self.unk_word)
@@ -128,7 +122,7 @@ def add_from_file(self, f):
try:
splits = line.rstrip().rsplit(" ", 1)
line = splits[0]
field = splits[1] if len(splits) > 1 else str(len(lines) - line_idx)
field = splits[1] if len(splits) > 1 else str(1)
if field == "#overwrite":
overwrite = True
line, field = line.rsplit(" ", 1)
6 changes: 5 additions & 1 deletion unicore/data/lmdb_dataset.py
@@ -9,12 +9,16 @@
import numpy as np
import collections
from functools import lru_cache

from unicore.data import BaseWrapperDataset

from . import data_utils
import logging
logger = logging.getLogger(__name__)

class LMDBDataset:
class LMDBDataset(BaseWrapperDataset):
def __init__(self, db_path):
super(LMDBDataset, self).__init__(dataset=None)
self.db_path = db_path
assert os.path.isfile(self.db_path), "{} not found".format(
self.db_path
2 changes: 1 addition & 1 deletion unicore/data/tokenize_dataset.py
@@ -24,5 +24,5 @@ def __init__(
@lru_cache(maxsize=16)
def __getitem__(self, index: int):
raw_data = self.dataset[index]
assert len(raw_data) < self.max_seq_len and len(raw_data) > 0
assert len(raw_data) <= self.max_seq_len and len(raw_data) > 0
return torch.from_numpy(self.dictionary.vec_index(raw_data)).long()
99 changes: 53 additions & 46 deletions unicore/logging/progress_bar.py
@@ -16,8 +16,8 @@
from collections import OrderedDict
from contextlib import contextmanager
from numbers import Number
from pathlib import Path
from typing import Optional
import wandb

import torch

@@ -33,11 +33,10 @@ def progress_bar(
log_interval: int = 100,
epoch: Optional[int] = None,
prefix: Optional[str] = None,
tensorboard_logdir: Optional[str] = None,
wandb_logdir: Optional[str] = None,
save_dir: Optional[str] = "./save/",
tensorboard: bool = False,
wandb_project: Optional[str] = None,
default_log_format: str = "tqdm",
args=None,
):
if log_format is None:
log_format = default_log_format
@@ -55,10 +54,13 @@
else:
raise ValueError("Unknown log format: {}".format(log_format))

if tensorboard_logdir:
bar = TensorboardProgressBarWrapper(
bar, tensorboard_logdir, wandb_logdir, wandb_project, args
)
if tensorboard:
tensorboard_logdir = os.path.join(save_dir, "tsb")
os.makedirs(tensorboard_logdir, exist_ok=True)
bar = TensorboardProgressBarWrapper(bar, tensorboard_logdir)

if wandb_project:
bar = WandBProgressBarWrapper(bar)

return bar
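Note (not part of this commit): a minimal usage sketch of the updated progress_bar entry point, assuming it is importable from unicore.logging.progress_bar and fed an arbitrary iterable of batches; the save_dir, project name, and loss value below are placeholders.

    from unicore.logging.progress_bar import progress_bar  # assumed import path

    # tensorboard=True wraps the bar in TensorboardProgressBarWrapper (logs land in ./save/tsb);
    # wandb_project wraps it in WandBProgressBarWrapper, which assumes wandb.init has already
    # been called elsewhere (see the sketch after the class below).
    bar = progress_bar(
        range(100),                 # stand-in for an epoch iterator of batches
        log_format="simple",
        log_interval=100,
        epoch=1,
        save_dir="./save/",
        tensorboard=True,
        wandb_project="my-wandb-project",
    )
    for step, batch in enumerate(bar):
        bar.log({"loss": 0.5}, tag="train", step=step)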

@@ -279,23 +281,14 @@ def print(self, stats, tag=None, step=None):
except ImportError:
SummaryWriter = None

try:
_wandb_inited = False
import wandb

wandb_available = True
except ImportError:
wandb_available = False


def _close_writers():
for w in _tensorboard_writers.values():
w.close()
if _wandb_inited:
try:
wandb.finish()
except:
pass
# if _wandb_inited:
# try:
# wandb.finish()
# except:
# pass


atexit.register(_close_writers)
@@ -304,34 +297,14 @@ def _close_writers():
class TensorboardProgressBarWrapper(BaseProgressBar):
"""Log to tensorboard."""

def __init__(self, wrapped_bar, tensorboard_logdir, wandb_logdir, wandb_project, args):
def __init__(self, wrapped_bar, tensorboard_logdir):
self.wrapped_bar = wrapped_bar
self.tensorboard_logdir = tensorboard_logdir

if SummaryWriter is None:
logger.warning(
"tensorboard not found, please install with: pip install tensorboard"
)
global _wandb_inited
if not _wandb_inited and wandb_project and wandb_available:
wandb_name = args.wandb_name or wandb.util.generate_id()
if "/" in wandb_project:
entity, project = wandb_project.split("/")
else:
entity, project = None, wandb_project
wandb.init(
project=project,
entity=entity,
name=wandb_name,
dir=wandb_logdir,
config=vars(args),
# id=wandb_name,
resume="allow",
)
wandb.define_metric("custom_step")
wandb.define_metric("train_*", step_metric="custom_step")
wandb.define_metric("valid_*", step_metric="custom_step")
_wandb_inited = True

def _writer(self, key):
if SummaryWriter is None:
@@ -377,7 +350,41 @@ def _log_to_tensorboard(self, stats, tag=None, step=None):
val = None
if val:
writer.add_scalar(key, val, step)
if _wandb_inited:
# wandb.log({"{}_{}".format(tag, key): val}, step=step)
wandb.log({"{}_{}".format(tag, key): val, "custom_step": step})
writer.flush()


class WandBProgressBarWrapper(BaseProgressBar):
"""Log to Weights & Biases."""

def __init__(self, wrapped_bar: BaseProgressBar):
# global _wandb_inited
# _wandb_inited = True
self.wrapped_bar = wrapped_bar

def __iter__(self):
return iter(self.wrapped_bar)

def log(self, stats, tag=None, step=None):
self._log_to_wandb(stats, tag, step)
self.wrapped_bar.log(stats, tag=tag, step=step)

def print(self, stats, tag=None, step=None):
self._log_to_wandb(stats, tag, step)
self.wrapped_bar.print(stats, tag=tag, step=step)

def update_config(self, config):
"""Log latest configuration."""
wandb.config.update(config)
self.wrapped_bar.update_config(config)

def _log_to_wandb(self, stats, tag=None, step=None):
if step is None:
step = stats["num_updates"]

prefix = "" if tag is None else tag + "/"

for key in stats.keys() - {"num_updates"}:
if isinstance(stats[key], AverageMeter):
wandb.log({prefix + key: stats[key].val}, step=step)
elif isinstance(stats[key], Number):
wandb.log({prefix + key: stats[key]}, step=step)
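WandBProgressBarWrapper only calls wandb.log and wandb.config.update; it never initialises the run itself, and the wandb.init call removed from TensorboardProgressBarWrapper above does not appear in the diffs shown here (it presumably moved to one of the remaining changed files). As a hypothetical sketch only, an init step consistent with the new CLI options could look roughly like this:

    import wandb

    def maybe_init_wandb(args):
        # Hypothetical helper, not taken from this commit: start the wandb run once,
        # before the first WandBProgressBarWrapper.log() call, using the new options.
        if not args.wandb_project:
            return
        wandb.init(
            project=args.wandb_project,
            name=args.wandb_run_name or None,
            id=args.wandb_run_id or None,
            resume="allow",
            config=vars(args),
        )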
15 changes: 6 additions & 9 deletions unicore/options.py
@@ -171,15 +171,12 @@ def get_parser(desc, default_task="test"):
help='log progress every N batches (when progress bar is disabled)')
parser.add_argument('--log-format', default=None, help='log format to use',
choices=['json', 'none', 'simple', 'tqdm'])
parser.add_argument('--tensorboard-logdir', metavar='DIR', default='',
help='path to save logs for tensorboard, should match --logdir '
'of running tensorboard (default: no tensorboard logging)')
parser.add_argument('--wandb-logdir', metavar='DIR', default='',
help='path to save logs for wandb')
parser.add_argument('--wandb-project', metavar='DIR', default='',
help='name of wandb project, empty for no wandb logging, for wandb login, use env WANDB_API_KEY. You can also use team_name/project_name for project name.')
parser.add_argument('--wandb-name', metavar='DIR', default='',
help='wandb run/id name, empty for no wandb logging, for wandb login, use env WANDB_API_KEY')
parser.add_argument('--tensorboard', action='store_true', help='Activate tensorboard logging')
parser.add_argument('--wandb-project', default=None,
help='name of wandb project, empty for no wandb logging, for wandb login, use env WANDB_API_KEY.')
parser.add_argument('--wandb-run-name', default=None, help='wandb run name')
parser.add_argument('--wandb-run-id', default=None, help='wandb run id')
parser.add_argument("--wandb-watch", action="store_true", help="Activate wandb watch to log weights and gradients")
parser.add_argument('--seed', default=1, type=int, metavar='N',
help='pseudo random number generator seed')
parser.add_argument('--cpu', action='store_true', help='use CPU instead of CUDA')
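The code that acts on --wandb-watch is not among the diffs shown above; per the help text it logs weights and gradients, which maps onto wandb.watch. A hypothetical sketch of that hookup (the helper name and log_freq choice are assumptions):

    import wandb

    def maybe_watch_model(args, model):
        # Hypothetical helper, not taken from this commit: ask wandb to record
        # parameter and gradient histograms when --wandb-watch is set.
        if getattr(args, "wandb_watch", False):
            wandb.watch(model, log="all", log_freq=args.log_interval)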