Skip to content

Commit

Permalink
drop comments / debugs
Browse files Browse the repository at this point in the history
  • Loading branch information
Mickus Timothee committed Sep 18, 2023
1 parent ce2858a commit 86b8183
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 13 deletions.
2 changes: 1 addition & 1 deletion onmt/inputters/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def numel_fn(example_dict):

class InferenceBatcher():
"""Iterator for inference"""
def __init__(self, dataset, batch_size, as_iter=False):
def __init__(self, dataset, batch_size):
self.examples_stream = dataset
self.collate_fn = dataset.collate_fn
self.batch_size = batch_size
Expand Down
11 changes: 0 additions & 11 deletions onmt/train_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,15 +106,6 @@ def init_distributed(model, task_queue_manager):
logger.debug(f'{task_queue_manager.node_rank}:{task_queue_manager.local_rank} {name}: {p.flatten()[:10]}')


# def iter_on_device(iterator, device_context):
# if device_context.is_gpu():
# device = torch.device(f'cuda:{device_context.local_rank}')
# else:
# device = torch.device('cpu')
# for batch, meta, comm_batch_id in iterator:
# yield batch.to(device), meta, comm_batch_id


def main(
opt,
vocabs_dict,
Expand Down Expand Up @@ -208,8 +199,6 @@ def _train_iter():
# train_iter = iter_on_device(train_iter, device_context)
logger.info("Device {} - Valid iter".format(device_context.id))
valid_iter = _build_valid_iter(opt, vocabs_dict, transforms_cls, task_queue_manager)
# if valid_iter is not None:
# valid_iter = iter_on_device(valid_iter, device_context)

if len(opt.gpu_ranks):
if device_context.is_master():
Expand Down
2 changes: 1 addition & 1 deletion onmt/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ def train(
)

if step % valid_steps == 0 and valid_iter is not None:
if True or self.gpu_verbose_level > 0:
if self.gpu_verbose_level > 0:
logger.info(f'{device_context.node_rank}:{device_context.local_rank} validate step {step}')
valid_stats = self.validate(
iter_on_device(valid_iter, device_context),
Expand Down

0 comments on commit 86b8183

Please sign in to comment.