Bugfixes and rename parameters of SimpleLookAheadBucketing
Due to a rounding bug in determining the size of the minibatch, it was
possible for minibatches to slightly exceed the specified size.
Now the estimation is slightly more pessimistic, and the guarantee
holds.
Waino committed May 20, 2024
1 parent 25a8be0 commit 8475e37
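
To illustrate the rounding bug, here is a minimal sketch of the old and new size checks (the numbers and the old_still_fits / new_still_fits names are illustrative, not taken from the code):

import math

batch_size = 12             # token budget per minibatch
lookahead_minibatches = 4   # minibatches yielded per maxi-batch (formerly n_buckets)
max_score = 5               # longest example seen so far, in tokens
n_examples = 8              # examples already collected in the maxi-batch

# Old check: compare against the combined budget of all minibatches.
# 5 * (8 + 1) = 45 < 4 * 12 = 48, so a ninth example is accepted, yet
# ceil(9 / 4) = 3 of these examples land in one minibatch: 3 * 5 = 15 > 12 tokens.
old_still_fits = max_score * (n_examples + 1) < lookahead_minibatches * batch_size

# New check: pessimistically estimate the largest minibatch the split can produce.
# ceil((8 + 1) / 4) = 3 and 5 * 3 = 15 >= 12, so the ninth example is rejected.
estimated_minibatch_size = math.ceil((n_examples + 1) / lookahead_minibatches)
new_still_fits = max_score * estimated_minibatch_size < batch_size

assert old_still_fits and not new_still_fits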
Showing 5 changed files with 111 additions and 77 deletions.
71 changes: 42 additions & 29 deletions mammoth/inputters/dataloader.py
@@ -8,7 +8,15 @@
from mammoth.utils.logging import logger


def build_dataloader(dataset, batch_size, batch_type, pool_size=None, n_buckets=None, cycle=True, as_iter=True):
def build_dataloader(
dataset,
batch_size,
batch_type,
max_look_ahead_sentences=None,
lookahead_minibatches=None,
cycle=True,
as_iter=True
):
"""Convert an mammoth.inputters.ParallelCorpus into an infinite iterator of batches"""
if not cycle:
loader = InferenceBatcher(dataset, batch_size)
@@ -18,8 +26,8 @@ def build_dataloader(dataset, batch_size, batch_type, pool_size=None, n_buckets=
elif batch_type == 'tokens':
loader = SimpleLookAheadBucketing(
dataset=dataset,
max_look_ahead_size=pool_size,
n_buckets=n_buckets,
max_look_ahead_sentences=max_look_ahead_sentences,
lookahead_minibatches=lookahead_minibatches,
batch_size=batch_size,
score_fn=SimpleLookAheadBucketing.max_of_lens,
)
@@ -75,13 +83,13 @@ class SimpleLookAheadBucketing():
"""
Arguments:
dataset: mammoth.inputters.ParallelCorpus
max_look_ahead_size:
max_look_ahead_sentences:
The maximum number of sentence pairs to read before yielding minibatches.
Limits the time spent looping if there is a corpus with unexpectedly short sentences.
n_buckets:
lookahead_minibatches:
The number of minibatches that will be yielded once bucketing is complete.
Recommended value: same as accum_count, or at least a multiple of it.
Setting n_buckets == accum_count means that each accumulated batch uses up the whole buffer.
Setting lookahead_minibatches == accum_count means that each accumulated batch uses up the whole buffer.
All tasks stay in sync concerning the length sorting: each task begins with the smallest
minibatch and ends with the largest just before accumulation ends.
batch_size:
@@ -91,13 +99,12 @@ class SimpleLookAheadBucketing():
score_fn:
Compute the size estimate (single integer) for sorting examples.
"""
def __init__(self, dataset, max_look_ahead_size, n_buckets, batch_size, score_fn=None):
def __init__(self, dataset, max_look_ahead_sentences, lookahead_minibatches, batch_size, score_fn=None):
score_fn = score_fn if score_fn else self.max_of_lens
self._sie = ScoredInfiniteExamples(dataset, score_fn)
self.max_look_ahead_size = max_look_ahead_size
self.max_look_ahead_sentences = max_look_ahead_sentences
self.batch_size = batch_size
self.n_buckets = n_buckets
self.multi_batch_size = n_buckets * batch_size
self.lookahead_minibatches = lookahead_minibatches
self.collate_fn = dataset.collate_fn

@staticmethod
@@ -110,32 +117,38 @@ def max_of_lens(example_dict) -> int:

def __iter__(self):
while True:
multibatch = []
maxi_batch = []
max_score = 0
for i in range(self.max_look_ahead_size):
for i in range(self.max_look_ahead_sentences):
score = self._sie.peek_at_score()
# Decide whether to add it or not
if len(multibatch) <= self.n_buckets:
if len(maxi_batch) < self.lookahead_minibatches:
# Always add at least one example per minibatch
still_fits = True
else:
still_fits = (max(max_score, score) * (len(multibatch) + 1)) < (self.multi_batch_size)
estimated_minibatch_size = math.ceil((len(maxi_batch) + 1) / self.lookahead_minibatches)
still_fits = (max(max_score, score) * estimated_minibatch_size) < (self.batch_size)
if still_fits:
score, example = self._sie.next()
multibatch.append((score, example))
maxi_batch.append((score, example))
max_score = max(max_score, score)
else:
break
# Sort by score to reduce padding
multibatch = list(sorted(multibatch, key=lambda x: x[0]))
maxi_batch = list(sorted(maxi_batch, key=lambda x: x[0]))
# Split into minibatches and yield
examples_per_batch = math.ceil(len(multibatch) / self.n_buckets)
multibatch_it = iter(multibatch)
for _ in range(self.n_buckets):
floor_examples_per_batch = math.floor(len(maxi_batch) / self.lookahead_minibatches)
examples_per_batch = [floor_examples_per_batch] * self.lookahead_minibatches
for i in range(len(maxi_batch) % self.lookahead_minibatches):
examples_per_batch[i] += 1
assert all(epb > 0 for epb in examples_per_batch)
assert sum(examples_per_batch) == len(maxi_batch)
maxi_batch_it = iter(maxi_batch)
for epb in examples_per_batch:
yield self.collate_fn(
[
example_dict for _, example_dict
in itertools.islice(multibatch_it, examples_per_batch)
in itertools.islice(maxi_batch_it, epb)
]
)

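For the splitting step above, a small sketch with assumed values shows how the remainder is now distributed (the old ceil-and-islice split of the same nine examples would have produced [3, 3, 3, 0], including an empty minibatch):

import math

maxi_batch_len = 9
lookahead_minibatches = 4

floor_examples_per_batch = math.floor(maxi_batch_len / lookahead_minibatches)  # 2
examples_per_batch = [floor_examples_per_batch] * lookahead_minibatches        # [2, 2, 2, 2]
for i in range(maxi_batch_len % lookahead_minibatches):                        # one leftover example
    examples_per_batch[i] += 1
print(examples_per_batch)  # [3, 2, 2, 2]
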
@@ -152,7 +165,7 @@ class DynamicDatasetIter(object):
batch_type (str): batching type to count on, choices=[tokens, sents];
batch_size (int): numbers of examples in a batch;
batch_size_multiple (int): make batch size multiply of this;
pool_size (int): accum this number of examples in a dynamic dataset;
max_look_ahead_sentences (int): accum this number of examples in a dynamic dataset;
skip_empty_level (str): security level when encountering an empty line;
stride (int): iterate data files with this stride;
offset (int): iterate data files with this offset.
@@ -173,8 +186,8 @@ def __init__(
batch_type,
batch_size,
batch_size_multiple,
pool_size=2048,
n_buckets=1024,
max_look_ahead_sentences=2048,
lookahead_minibatches=4,
):
self.task_queue_manager = task_queue_manager
self.opts = opts
@@ -188,8 +201,8 @@ def __init__(
self.batch_size = batch_size
self.batch_size_multiple = batch_size_multiple
self.device = 'cpu'
self.pool_size = pool_size
self.n_buckets = n_buckets
self.max_look_ahead_sentences = max_look_ahead_sentences
self.lookahead_minibatches = lookahead_minibatches

@classmethod
def from_opts(cls, task_queue_manager, transforms_cls, vocabs_dict, opts, is_train):
@@ -209,8 +222,8 @@ def from_opts(cls, task_queue_manager, transforms_cls, vocabs_dict, opts, is_tra
opts.batch_type,
batch_size,
batch_size_multiple,
pool_size=opts.pool_size,
n_buckets=opts.n_buckets,
max_look_ahead_sentences=opts.max_look_ahead_sentences,
lookahead_minibatches=opts.lookahead_minibatches,
)

def _init_datasets(self):
@@ -243,8 +256,8 @@ def _init_datasets(self):
corpus,
self.batch_size,
self.batch_type,
self.pool_size,
n_buckets=self.n_buckets,
self.max_look_ahead_sentences,
lookahead_minibatches=self.lookahead_minibatches,
cycle=self.is_train,
as_iter=self.is_train,
)
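A usage sketch with the renamed parameters (the corpus object and the concrete values are placeholders, not part of this commit):

# `corpus` stands for any mammoth.inputters.ParallelCorpus (hypothetical values below).
train_loader = build_dataloader(
    dataset=corpus,
    batch_size=4096,                # token budget per minibatch
    batch_type='tokens',
    max_look_ahead_sentences=2048,  # cap on sentence pairs read per maxi-batch
    lookahead_minibatches=8,        # ideally accum_count, or a multiple of it
    cycle=True,
    as_iter=True,
)
for batch in train_loader:          # infinite iterator of collated minibatches
    ...
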
28 changes: 18 additions & 10 deletions mammoth/opts.py
@@ -278,6 +278,13 @@ def model_opts(parser):

# Encoder-Decoder Options
group = parser.add_argument_group('Model- Encoder-Decoder')
group.add(
'--model_type',
'-model_type',
default='text',
choices=['text'],
help="Type of source model to use. Allows the system to incorporate non-text inputs. Options are [text].",
)
group.add('--model_dtype', '-model_dtype', default='fp32', choices=['fp32', 'fp16'], help='Data type of the model.')

group.add(
@@ -627,22 +634,23 @@ def _add_train_general_opts(parser):
"uses more memory. Set to 0 to disable.",
)
group.add(
"-pool_size",
"--pool_size",
"-lookahead_minibatches",
"--lookahead_minibatches",
type=int,
default=2048,
help="(Maximum) number of examples to dynamically pool before batching.",
default=4,
help="The number of minibatches that SimpleLookAheadBucketing will read into a maxibatch, "
"pessimisticly sort by length, split into minibatches, and yield in one go. "
"Recommended value: same as accum_count, or at least a multiple of it."
)
group.add(
"-n_buckets",
"--n_buckets",
"-max_look_ahead_sentences",
"--max_look_ahead_sentences",
type=int,
default=4,
help="The number of minibatches that will be yielded once bucketing is complete. "
"Recommended value: same as accum_count, or at least a multiple of it."
default=2048,
help="(Maximum) number of sentence pairs that SimpleLookAheadBucketing can attempt to add to the maxibatch. "
"This is mainly a failsafe in case some corpus contains very short examples.",
)

group = parser.add_argument_group('Optimization')
group.add(
'--optim',
'-optim',
81 changes: 46 additions & 35 deletions mammoth/tests/test_look_ahead_bucketing.py
@@ -1,6 +1,6 @@
import pytest
from itertools import product

import unittest
from mammoth.inputters.dataloader import build_dataloader


@@ -26,37 +26,48 @@ def collate_fn(self, items):
return items


class TestLookAheadBucketing(unittest.TestCase):

def test_all_read(self):
max_batch_size = 12
stream = MockStream([
hashabledict({
'src': tuple([letter for _ in range(i)]),
'tgt': tuple([letter for _ in range(j)]),
})
for letter in 'xyz'
for i, j in product(range(1, 11), range(1, 11))
])
lab = build_dataloader(
stream,
batch_size=max_batch_size,
batch_type='tokens',
pool_size=4,
n_buckets=4,
cycle=True,
as_iter=False
)
examples_read = []
batches = iter(lab)
for _ in range(1000):
batch = next(batches)
assert len(batch) > 0
src_toks = sum(len(ex['src']) for ex in batch)
tgt_toks = sum(len(ex['tgt']) for ex in batch)
# check that the batch size is respected
assert src_toks <= max_batch_size
assert tgt_toks <= max_batch_size, str(batch)
examples_read.extend(batch)
# Check that the stream was cycled
self.assertTrue(len(examples_read) > len(stream))
@pytest.mark.parametrize(
('max_batch_size', 'lookahead_minibatches'),
[
(12, 4),
(13, 4),
(14, 4),
(15, 4),
(12, 5),
(13, 5),
(14, 5),
(15, 5),
],
)
def test_simple_lookahead_bucketing(max_batch_size, lookahead_minibatches):
stream = MockStream([
hashabledict({
'src': tuple([letter for _ in range(i)]),
'tgt': tuple([letter for _ in range(j)]),
})
for letter in 'xyz'
for i, j in product(range(1, 11), range(1, 11))
])
lab = build_dataloader(
stream,
batch_size=max_batch_size,
batch_type='tokens',
max_look_ahead_sentences=512,
lookahead_minibatches=lookahead_minibatches,
cycle=True,
as_iter=False
)
examples_read = []
batches = iter(lab)
for _ in range(1000):
batch = next(batches)
print(batch)
assert len(batch) > 0
src_toks = sum(len(ex['src']) for ex in batch)
tgt_toks = sum(len(ex['tgt']) for ex in batch)
# check that the batch size is respected
assert src_toks <= max_batch_size
assert tgt_toks <= max_batch_size, str(batch)
examples_read.extend(batch)
# Check that the stream was cycled
assert len(examples_read) > len(stream)
4 changes: 3 additions & 1 deletion mammoth/train_single.py
@@ -101,7 +101,9 @@ def main(
transforms_cls = get_transforms_cls(opts._all_transform)
model_opts = _get_model_opts(opts, checkpoint=checkpoint)

task_queue_manager.create_all_distributed_components(use_attention_bridge=model_opts.bridge)
task_queue_manager.create_all_distributed_components(
use_attention_bridge=(model_opts.ab_layers is not None and len(model_opts.ab_layers) != 0),
)

# Build model.

4 changes: 2 additions & 2 deletions mammoth/translate/translator.py
@@ -479,8 +479,8 @@ def _translate(
corpus,
batch_size=batch_size,
batch_type=batch_type,
pool_size=512,
n_buckets=512,
max_look_ahead_sentences=512,
lookahead_minibatches=512,
cycle=False,
)

