Commit ce2858a

fixing validation stats aggregation

Mickus Timothee committed Sep 18, 2023
1 parent 7933c0c commit ce2858a

Showing 3 changed files with 71 additions and 49 deletions.
71 changes: 38 additions & 33 deletions onmt/inputters/dataloader.py
@@ -12,57 +12,57 @@ def infinite_iterator(iterable):
     return itertools.chain.from_iterable(itertools.repeat(iterable))
 
 
-def build_dataloader(dataset, batch_size, batch_type, pool_size=None, n_buckets=None, cycle=True):
+def build_dataloader(dataset, batch_size, batch_type, pool_size=None, n_buckets=None, cycle=True, as_iter=True):
     """Convert an onmt.inputters.ParallelCorpus into an infinite iterator of batches"""
     if not cycle:
-        return iter(InferenceBatcher(dataset, batch_size))
+        loader = InferenceBatcher(dataset, batch_size)
+    else:
+        examples_stream = infinite_iterator(dataset)
+        if batch_type == 'sents':
+            n_buckets = 1
 
-    examples_stream = infinite_iterator(dataset)
-    if batch_type == 'sents':
-        n_buckets = 1
+            def bucket_fn(_):
+                return 0
 
-        def bucket_fn(_):
-            return 0
+            def numel_fn(_):
+                return 1
 
-        def numel_fn(_):
-            return 1
+        elif batch_type == 'tokens':
 
-    elif batch_type == 'tokens':
+            def bucket_fn(example_dict):
+                if 'tgt' in example_dict:
+                    # subtract four for bos/eos on both sides
+                    true_size = len(example_dict['src']) + len(example_dict['tgt']) - 4
+                else:
+                    true_size = len(example_dict['src']) + 2
+                # maybe dump it in the last bucket if it's just too long
+                return min(n_buckets - 1, true_size)
 
-        def bucket_fn(example_dict):
-            if 'tgt' in example_dict:
-                # subtract four for bos/eos on both sides
-                true_size = len(example_dict['src']) + len(example_dict['tgt']) - 4
-            else:
-                true_size = len(example_dict['src']) + 2
-            # maybe dump it in the last bucket if it's just too long
-            return min(n_buckets - 1, true_size)
+            def numel_fn(example_dict):
+                if 'tgt' in example_dict:
+                    true_size = len(example_dict['src']) + len(example_dict['tgt'])
+                else:
+                    true_size = len(example_dict['src'])
+                return true_size
 
-        def numel_fn(example_dict):
-            if 'tgt' in example_dict:
-                true_size = len(example_dict['src']) + len(example_dict['tgt'])
-            else:
-                true_size = len(example_dict['src'])
-            return true_size
 
-    collate_fn = dataset.collate_fn
-    lab = LookAheadBucketing(examples_stream, pool_size, n_buckets, batch_size, bucket_fn, numel_fn, collate_fn)
-    return iter(lab)
+        collate_fn = dataset.collate_fn
+        loader = LookAheadBucketing(examples_stream, pool_size, n_buckets, batch_size, bucket_fn, numel_fn, collate_fn)
+    return iter(loader) if as_iter else loader
 
 
 DatasetMetadata = collections.namedtuple('DatasetMetadata', 'src_lang tgt_lang encoder_id decoder_id corpus_id')
 
 
 class InferenceBatcher():
     """Iterator for inference"""
-    def __init__(self, dataset, batch_size):
-        self.examples_stream = iter(dataset)
+    def __init__(self, dataset, batch_size, as_iter=False):
+        self.examples_stream = dataset
         self.collate_fn = dataset.collate_fn
         self.batch_size = batch_size
 
     def __iter__(self):
         accum = []
-        for example in self.examples_stream:
+        for example in iter(self.examples_stream):
             accum.append(example)
             if len(accum) >= self.batch_size:
                 yield self.collate_fn(accum)
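Note on the InferenceBatcher change: taking iter(dataset) once in __init__ produces a stream that is exhausted after the first pass, so any later validation pass silently sees no data; keeping the raw iterable and calling iter() inside __iter__ restarts the stream on every pass. A minimal, self-contained sketch of the idea (a toy list stands in for an onmt corpus):

    class ToyLoader:
        def __init__(self, data):
            self.examples_stream = data            # keep the iterable itself
        def __iter__(self):
            yield from iter(self.examples_stream)  # fresh iterator per pass

    loader = ToyLoader([1, 2, 3])
    assert list(loader) == [1, 2, 3]
    assert list(loader) == [1, 2, 3]  # a second pass still sees the data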
@@ -299,6 +299,7 @@ def _init_datasets(self):
                 self.bucket_size,
                 n_buckets=self.n_buckets,
                 cycle=self.is_train,
+                as_iter=self.is_train,
             )
 
             self.dataset_iterators[task.corpus_id] = (ordered_iter, metadata)
@@ -310,8 +311,12 @@ def __iter__(self):
         self._init_datasets()
 
         if not self.is_train:
-            for ordered_iter, metadata in self.dataset_iterators.values():
-                yield from zip(ordered_iter, itertools.repeat(metadata), itertools.repeat(0))
+            # to be absolutely clear: all the validation data is read per validation loop
+            all_val_data = [
+                zip(ordered_iter, itertools.repeat(metadata), itertools.repeat(0))
+                for ordered_iter, metadata in self.dataset_iterators.values()
+            ]
+            yield from itertools.chain.from_iterable(all_val_data)
 
         else:
             # All minibatches with the same communication_batch_id should be trained on
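Taken together, the dataloader changes mean every task's full validation set is re-read on each validation loop. A hedged usage sketch of the new as_iter flag (the corpus object and the sizes here are assumptions, not values from this repo):

    # cycle=True, as_iter=True: an endless bucketed iterator for training.
    train_iter = build_dataloader(corpus, batch_size=4096, batch_type='tokens',
                                  pool_size=2048, n_buckets=1024,
                                  cycle=True, as_iter=True)

    # cycle=False, as_iter=False: a re-iterable InferenceBatcher for validation;
    # each call to iter(valid_loader) re-reads the corpus from the start.
    valid_loader = build_dataloader(corpus, batch_size=64, batch_type='sents',
                                    cycle=False, as_iter=False)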
20 changes: 10 additions & 10 deletions onmt/train_single.py
@@ -106,13 +106,13 @@ def init_distributed(model, task_queue_manager):
         logger.debug(f'{task_queue_manager.node_rank}:{task_queue_manager.local_rank} {name}: {p.flatten()[:10]}')
 
 
-def iter_on_device(iterator, device_context):
-    if device_context.is_gpu():
-        device = torch.device(f'cuda:{device_context.local_rank}')
-    else:
-        device = torch.device('cpu')
-    for batch, meta, comm_batch_id in iterator:
-        yield batch.to(device), meta, comm_batch_id
+# def iter_on_device(iterator, device_context):
+#     if device_context.is_gpu():
+#         device = torch.device(f'cuda:{device_context.local_rank}')
+#     else:
+#         device = torch.device('cpu')
+#     for batch, meta, comm_batch_id in iterator:
+#         yield batch.to(device), meta, comm_batch_id
 
 
 def main(
@@ -205,11 +205,11 @@ def _train_iter():
             yield batch, metadata, communication_batch_id
 
     train_iter = _train_iter()
-    train_iter = iter_on_device(train_iter, device_context)
+    # train_iter = iter_on_device(train_iter, device_context)
     logger.info("Device {} - Valid iter".format(device_context.id))
     valid_iter = _build_valid_iter(opt, vocabs_dict, transforms_cls, task_queue_manager)
-    if valid_iter is not None:
-        valid_iter = iter_on_device(valid_iter, device_context)
+    # if valid_iter is not None:
+    #     valid_iter = iter_on_device(valid_iter, device_context)
 
     if len(opt.gpu_ranks):
         if device_context.is_master():
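The helper commented out above is not dropped: it reappears in onmt/trainer.py below, so device placement now happens inside the trainer. That is what lets the trainer re-wrap the re-iterable validation loader with a fresh iter_on_device(...) at every validation step, instead of wrapping a single one-shot iterator up front.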
29 changes: 23 additions & 6 deletions onmt/trainer.py
@@ -20,6 +20,15 @@
 from onmt.utils.logging import logger
 
 
+def iter_on_device(iterator, device_context):
+    if device_context.is_gpu():
+        device = torch.device(f'cuda:{device_context.local_rank}')
+    else:
+        device = torch.device('cpu')
+    for batch, meta, comm_batch_id in iterator:
+        yield batch.to(device), meta, comm_batch_id
+
+
 def build_trainer(
     opt,
     device_context,
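Since iter_on_device is a generator, the wrapping is lazy: no batch touches the device until the loop pulls it. Training wraps train_iter once for the whole run (first train() hunk below), while validation builds a fresh wrapper per validation pass (second hunk). A small sketch exercising the function above (StubContext is an assumption standing in for the real device_context API):

    import torch

    class StubContext:
        local_rank = 0
        def is_gpu(self):
            return torch.cuda.is_available()

    data = [(torch.zeros(2), None, 0), (torch.ones(2), None, 0)]
    for batch, meta, comm_batch_id in iter_on_device(iter(data), StubContext()):
        print(batch.device)  # cpu unless a GPU is present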
@@ -245,6 +254,7 @@ def train(
         Returns:
             The gathered statistics.
         """
+        train_iter = iter_on_device(train_iter, device_context)
         if valid_iter is None:
             logger.info('Start training loop without validation...')
         else:
@@ -351,18 +361,25 @@ def train(
                     report_stats,
                 )
 
-            if valid_iter is not None and step % valid_steps == 0:
-                if self.gpu_verbose_level > 0:
+            if step % valid_steps == 0 and valid_iter is not None:
+                if True or self.gpu_verbose_level > 0:
                     logger.info(f'{device_context.node_rank}:{device_context.local_rank} validate step {step}')
                 valid_stats = self.validate(
-                    valid_iter, moving_average=self.moving_average)
+                    iter_on_device(valid_iter, device_context),
+                    moving_average=self.moving_average,
+                )
                 if self.gpu_verbose_level > 0:
                     logger.info(f'{device_context.node_rank}:{device_context.local_rank} gather valid stat step {step}')
                 valid_stats = self._maybe_gather_stats(valid_stats)
                 if self.gpu_verbose_level > 0:
                     logger.info(f'{device_context.node_rank}:{device_context.local_rank} report stat step {step}')
-                self._report_step(self.optim.learning_rate(),  # learning_rate_to_show, #self.optim.learning_rate(),
-                                  step, valid_stats=valid_stats)
+                if device_context.is_master():
+                    self._report_step(
+                        self.optim.learning_rate(),  # learning_rate_to_show, #self.optim.learning_rate(),
+                        step,
+                        valid_stats=valid_stats,
+                    )
 
             # # Run patience mechanism
             # if self.earlystopper is not None:
             #     self.earlystopper(valid_stats, step)
@@ -524,7 +541,7 @@ def _maybe_gather_stats(self, stat):
         Returns:
             stat: the updated (or unchanged) stat object
         """
-        if stat is not None and self.device_context.is_distributed() > 1:
+        if stat is not None and self.device_context.is_distributed():
            return onmt.utils.Statistics.all_gather_stats(stat)
         return stat

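This last one-liner is the heart of the fix. Assuming device_context.is_distributed() returns a bool: Python's bool is an int subclass with True == 1, so the old guard is_distributed() > 1 could never be true, and validation stats were never all-gathered across ranks. Dropping the > 1 restores the aggregation; together with the new is_master() guard around _report_step, each validation now logs the aggregated statistics exactly once. A two-line check of the claim:

    assert (True > 1) is False   # the old guard never fired, distributed or not
    assert (False > 1) is False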
