Dev (#9)
* WIP

* WIP

* setup check cuda versions

* WIP

* WIP

---------

Co-authored-by: nicolas.brosse <[email protected]>
nbrosse and nicolas.brosse authored Jan 25, 2024
1 parent f32c8da commit b2dc36d
Showing 8 changed files with 131 additions and 42 deletions.
8 changes: 0 additions & 8 deletions .github/workflows/cuda/cu116-Linux-env.sh

This file was deleted.

15 changes: 0 additions & 15 deletions .github/workflows/cuda/cu116-Linux.sh

This file was deleted.

5 changes: 1 addition & 4 deletions .github/workflows/publish.yml
@@ -16,16 +16,13 @@ jobs:
       matrix:
         os: [ubuntu-22.04]
         python-version: ['3.9']
-        torch-version: [1.13.1, 2.0.0]
+        torch-version: [2.0.0]
         cuda-version: ['117', '118']
         include:
           - os: macos-13
             torch-version: 2.0.0
             cuda-version: 'cpu'
             python-version: '3.9'
-        exclude:
-          - torch-version: 1.13.1
-            cuda-version: '118'
 
     steps:
       - name: Checkout
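Note on the CI matrix: with the torch 1.13.1 row gone, the old exclude block is obsolete, since its only purpose was to drop the 1.13.1 + CUDA 118 combination. A small Python sketch (assumed GitHub Actions matrix semantics, names illustrative) of the jobs the trimmed matrix expands to:

from itertools import product

# Cross-product of the base matrix, plus the explicit macOS include entry.
base = [
    {"os": o, "python": p, "torch": t, "cuda": c}
    for o, p, t, c in product(["ubuntu-22.04"], ["3.9"], ["2.0.0"], ["117", "118"])
]
include = [{"os": "macos-13", "python": "3.9", "torch": "2.0.0", "cuda": "cpu"}]

for job in base + include:
    print(job)  # three jobs: two CUDA wheels on Linux, one CPU wheel on macOS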
5 changes: 2 additions & 3 deletions setup.py
@@ -80,9 +80,7 @@ def get_cuda_bare_metal_version(cuda_dir):
 TORCH_MAJOR = int(torch.__version__.split('.')[0])
 TORCH_MINOR = int(torch.__version__.split('.')[1])
 
-if not ( (TORCH_MAJOR >= 1 and TORCH_MINOR >= 4)
-         or (TORCH_MAJOR > 1)
-       ):
+if not ((TORCH_MAJOR >= 1 and TORCH_MINOR >= 4) or (TORCH_MAJOR > 1)):
     raise RuntimeError("Requires Pytorch 1.4 or newer.\n" +
                        "The latest stable release can be obtained from https://pytorch.org/")
 
@@ -266,6 +264,7 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
     entry_points={
         "console_scripts": [
             "unicore-train = unicore_cli.train:cli_main",
+            "unicore-infer = unicore_cli.infer:cli_main",
         ],
     },
     zip_safe=False,
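The added unicore-infer console script follows the standard setuptools entry-point convention: pip install . generates an executable that imports unicore_cli.infer and calls cli_main. Roughly, the generated wrapper behaves like this sketch:

import sys

from unicore_cli.infer import cli_main

# Approximate behavior of the auto-generated `unicore-infer` executable.
if __name__ == "__main__":
    sys.exit(cli_main())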
19 changes: 10 additions & 9 deletions unicore/data/pad_dataset.py
@@ -10,29 +10,30 @@
 
 
 class PadDataset(BaseWrapperDataset):
-    def __init__(self, dataset, pad_idx, left_pad):
+    def __init__(self, dataset, pad_idx, left_pad, pad_to_multiple: int = 8):
         super().__init__(dataset)
         self.pad_idx = pad_idx
         self.left_pad = left_pad
+        self.pad_to_multiple = pad_to_multiple
 
     def collater(self, samples):
-        return data_utils.collate_tokens(samples, self.pad_idx, left_pad=self.left_pad, pad_to_multiple=8)
+        return data_utils.collate_tokens(samples, self.pad_idx, left_pad=self.left_pad, pad_to_multiple=self.pad_to_multiple)
 
 
 class LeftPadDataset(PadDataset):
-    def __init__(self, dataset, pad_idx):
-        super().__init__(dataset, pad_idx, left_pad=True)
+    def __init__(self, dataset, pad_idx, pad_to_multiple: int = 8):
+        super().__init__(dataset, pad_idx, left_pad=True, pad_to_multiple=pad_to_multiple)
 
 
 class RightPadDataset(PadDataset):
-    def __init__(self, dataset, pad_idx):
-        super().__init__(dataset, pad_idx, left_pad=False)
+    def __init__(self, dataset, pad_idx, pad_to_multiple: int = 8):
+        super().__init__(dataset, pad_idx, left_pad=False, pad_to_multiple=pad_to_multiple)
 
 
 class RightPadDataset2D(BaseWrapperDataset):
-    def __init__(self, dataset, pad_idx,left_pad=False):
+    def __init__(self, dataset, pad_idx, pad_to_multiple: int = 8):
         super().__init__(dataset)
         self.pad_idx = pad_idx
-        self.left_pad = left_pad
+        self.pad_to_multiple = pad_to_multiple
     def collater(self, samples):
-        return data_utils.collate_tokens_2d(samples, self.pad_idx, left_pad=self.left_pad, pad_to_multiple=8)
+        return data_utils.collate_tokens_2d(samples, self.pad_idx, left_pad=False, pad_to_multiple=self.pad_to_multiple)
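The refactor turns the hard-coded pad_to_multiple=8 into a constructor argument with the same default, so existing callers keep their behavior. The rounding that collate_tokens applies to batch lengths can be illustrated with a minimal sketch (padded_len is a hypothetical helper, not part of unicore):

def padded_len(max_len: int, pad_to_multiple: int = 8) -> int:
    # Round the longest sequence length in a batch up to the nearest multiple.
    return ((max_len + pad_to_multiple - 1) // pad_to_multiple) * pad_to_multiple

assert padded_len(13) == 16  # a 13-token batch is padded to 16
assert padded_len(16) == 16  # aligned lengths are unchanged
assert padded_len(13, pad_to_multiple=1) == 13  # multiple of 1 disables alignment

Keeping shapes aligned to a multiple of 8 is a common way to stay on fast tensor-core kernel paths on GPU, which is presumably why 8 remains the default.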
2 changes: 1 addition & 1 deletion unicore/tasks/unicore_task.py
@@ -327,4 +327,4 @@ def load_state_dict(self, state_dict: Dict[str, Any]):
         self.state.merge_state_dict(state_dict)
 
     def disable_shuffling(self) -> bool:
-        return False
+        return True
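Flipping disable_shuffling to True makes tasks iterate data in dataset order by default, which suits deterministic inference. A minimal sketch of the intended effect, assuming the batch iterator consults this hook (this is not the actual unicore iterator code):

import random

class Task:
    def disable_shuffling(self) -> bool:
        return True  # mirrors this commit

def epoch_order(task, indices, seed=1):
    order = list(indices)
    if not task.disable_shuffling():
        random.Random(seed).shuffle(order)
    return order

assert epoch_order(Task(), [0, 1, 2, 3]) == [0, 1, 2, 3]  # dataset order preserved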
114 changes: 114 additions & 0 deletions unicore_cli/infer.py
@@ -0,0 +1,114 @@
#!/usr/bin/env python3 -u
# Copyright (c) DP Technology, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import pickle
import sys

import torch
from unicore import checkpoint_utils, distributed_utils, options, tasks, utils
from unicore.logging import progress_bar

logging.basicConfig(
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
level=os.environ.get("LOGLEVEL", "INFO").upper(),
stream=sys.stdout,
)
logger = logging.getLogger("unicore_cli.inference")


def main(args):
    assert (
        args.batch_size is not None
    ), "Must specify batch size with --batch-size"

    use_fp16 = args.fp16
    use_cuda = torch.cuda.is_available() and not args.cpu

    if use_cuda:
        torch.cuda.set_device(args.device_id)

    # Shard the data across workers when running data-parallel inference
    if args.distributed_world_size > 1:
        data_parallel_world_size = distributed_utils.get_data_parallel_world_size()
        data_parallel_rank = distributed_utils.get_data_parallel_rank()
    else:
        data_parallel_world_size = 1
        data_parallel_rank = 0

    # Load model
    logger.info("loading model(s) from {}".format(args.path))
    state = checkpoint_utils.load_checkpoint_to_cpu(args.path)
    task = tasks.setup_task(args)
    model = task.build_model(args)
    model.load_state_dict(state["model"], strict=False)

    # Move model to GPU
    if use_cuda:
        model.cuda()
        # fp16 only supported on CUDA for fused kernels
        if use_fp16:
            model.half()

    # Print args
    logger.info(args)

    # Build loss
    loss = task.build_loss(args)
    loss.eval()

    for subset in args.valid_subset.split(","):
        try:
            task.load_dataset(subset, combine=False, epoch=1)
            dataset = task.dataset(subset)
        except KeyError:
            raise Exception("Cannot find dataset: " + subset)

        if not os.path.exists(args.results_path):
            os.makedirs(args.results_path)
        save_path = os.path.join(args.results_path, f"{subset}.out.pkl")
        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=dataset,
            batch_size=args.batch_size,
            ignore_invalid_inputs=True,
            required_batch_size_multiple=args.required_batch_size_multiple,
            seed=args.seed,
            num_shards=data_parallel_world_size,
            shard_id=data_parallel_rank,
            num_workers=args.num_workers,
            data_buffer_size=args.data_buffer_size,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.progress_bar(
            itr,
            log_format=args.log_format,
            log_interval=args.log_interval,
            prefix=f"valid on '{subset}' subset",
            default_log_format=("tqdm" if not args.no_progress_bar else "simple"),
        )
        log_outputs = []
        for i, sample in enumerate(progress):
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if len(sample) == 0:
                continue
            _, _, log_output = task.valid_step(sample, model, loss, test=True)
            progress.log({}, step=i)
            log_outputs.append(log_output)
        with open(save_path, "wb") as f:
            pickle.dump(log_outputs, f)
        logger.info("Done inference!")
    return None


def cli_main():
    parser = options.get_validation_parser()
    options.add_model_args(parser)
    args = options.parse_args_and_arch(parser)
    distributed_utils.call_main(args, main)


if __name__ == "__main__":
    cli_main()
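For each validation subset, the script pickles the list of per-batch log_output dicts to <results_path>/<subset>.out.pkl. Reading the results back is straightforward (the path below is an example):

import pickle

# Load the per-batch outputs written by unicore-infer.
with open("results/valid.out.pkl", "rb") as f:
    log_outputs = pickle.load(f)

print(f"collected {len(log_outputs)} batches of log outputs")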
5 changes: 3 additions & 2 deletions unicore_cli/train.py
@@ -96,8 +96,9 @@ def main(args) -> None:
     )
 
     # Load valid dataset (we load training data below, based on the latest checkpoint)
-    for valid_sub_split in args.valid_subset.split(","):
-        task.load_dataset(valid_sub_split, combine=False, epoch=1)
+    if not args.disable_validation:
+        for valid_sub_split in args.valid_subset.split(","):
+            task.load_dataset(valid_sub_split, combine=False, epoch=1)
 
     logger.info(model)
     logger.info("task: {}".format(task.__class__.__name__))
