diff --git a/mammoth/inputters/dataloader.py b/mammoth/inputters/dataloader.py
index 72dd0731..1258681c 100644
--- a/mammoth/inputters/dataloader.py
+++ b/mammoth/inputters/dataloader.py
@@ -1,5 +1,6 @@
 import collections
 import itertools
+import math
 import random
 
 import torch
@@ -8,16 +9,11 @@
 from mammoth.utils.logging import logger
 
 
-def infinite_iterator(iterable):
-    return itertools.chain.from_iterable(itertools.repeat(iterable))
-
-
 def build_dataloader(dataset, batch_size, batch_type, pool_size=None, n_buckets=None, cycle=True, as_iter=True):
     """Convert an mammoth.inputters.ParallelCorpus into an infinite iterator of batches"""
     if not cycle:
         loader = InferenceBatcher(dataset, batch_size)
     else:
-        examples_stream = infinite_iterator(dataset)
 
         if batch_type == 'sents':
             n_buckets = 1
@@ -30,23 +26,25 @@ def numel_fn(_):
 
         elif batch_type == 'tokens':
 
             def bucket_fn(example_dict):
+                """map example dict to a (src bucket, tgt bucket) pair"""
+                # subtract two for bos/eos; clamp so overlong examples get dumped in the last bucket
+                src_len = min(max(len(example_dict['src']) - 2, 0), n_buckets - 1)
                 if 'tgt' in example_dict:
-                    # subtract four for bos/eos on both sides
-                    true_size = len(example_dict['src']) + len(example_dict['tgt']) - 4
+                    tgt_len = min(max(len(example_dict['tgt']) - 2, 0), n_buckets - 1)
                 else:
-                    true_size = len(example_dict['src']) + 2
+                    tgt_len = src_len
-                # maybe dump it in the last bucket if it's just too long
-                return min(n_buckets - 1, true_size)
+                return src_len, tgt_len
 
             def numel_fn(example_dict):
+                """count tokens in example"""
                 if 'tgt' in example_dict:
                     true_size = len(example_dict['src']) + len(example_dict['tgt'])
                 else:
                     true_size = len(example_dict['src'])
                 return true_size
 
-        collate_fn = dataset.collate_fn
-        loader = LookAheadBucketing(examples_stream, pool_size, n_buckets, batch_size, bucket_fn, numel_fn, collate_fn)
+        loader = LookAheadBucketing(dataset, pool_size, n_buckets, batch_size, bucket_fn, numel_fn)
     return iter(loader) if as_iter else loader
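
Outside the patch itself, a toy sketch may help: with n_buckets = 4, the new
bucket_fn places every example on a 4x4 grid keyed by clamped source and
target lengths. The dicts below are made-up stand-ins for corpus examples:

    n_buckets = 4

    def bucket_fn(example_dict):
        # clamp each length minus bos/eos into [0, n_buckets - 1]
        src_len = min(max(len(example_dict['src']) - 2, 0), n_buckets - 1)
        if 'tgt' in example_dict:
            tgt_len = min(max(len(example_dict['tgt']) - 2, 0), n_buckets - 1)
        else:
            tgt_len = src_len
        return src_len, tgt_len

    assert bucket_fn({'src': 'abcd', 'tgt': 'ab'}) == (2, 0)
    assert bucket_fn({'src': 'a' * 40}) == (3, 3)  # overlong source: last bucket
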
@@ -72,117 +70,136 @@ def __iter__(self):
 
 
 class LookAheadBucketing():
-    def __init__(self, examples_stream, look_ahead_size, n_buckets, batch_size, bucket_fn, numel_fn, collate_fn):
-        self.examples_stream = examples_stream
-        self._buckets = [[] for _ in range(n_buckets)]
-        self._lens = [0 for _ in range(n_buckets)]
+    def __init__(self, dataset, look_ahead_size, n_buckets, batch_size, bucket_fn, numel_fn):
+        self.dataset = dataset
+        # actual generator of examples
+        self.examples_stream = iter([])
+        # tracks whether the stream needs to be restarted
+        self._is_exhausted = True
+        self.n_buckets = n_buckets
+        # 2D grid of buckets, indexed by (src bucket, tgt bucket)
+        self._buckets = [
+            [
+                []
+                for _ in range(n_buckets)
+            ]
+            for _ in range(n_buckets)
+        ]
         self.look_ahead_size = look_ahead_size
         self.batch_size = batch_size
         self.bucket_fn = bucket_fn
         self.numel_fn = numel_fn
-        self.collate_fn = collate_fn
+        self.collate_fn = dataset.collate_fn
         self._init()
 
     def _init(self):
         logger.info('LookAheadBucketing: initialization start')
-        for example in itertools.islice(self.examples_stream, self.look_ahead_size):
-            bucket_idx = self.bucket_fn(example)
-            self._buckets[bucket_idx].append(example)
-            self._lens[bucket_idx] += 1
+        self.examples_stream = iter(self.dataset)
+        for _ in range(self.look_ahead_size):
+            self.maybe_replenish()
+            if self._is_exhausted:
+                break
+        assert not self.is_empty(), 'Dataset contains no usable example!'
         logger.info('LookAheadBucketing: initialization done')
 
-    def maybe_replenish(self) -> bool:
-        """look up one more example to add to this reservoir."""
+    def maybe_replenish(self):
+        """try to look up one more example to add to this reservoir."""
         try:
             example = next(self.examples_stream)
-            bucket_idx = self.bucket_fn(example)
-            creates_new_bucket = self._lens[bucket_idx] == 0
-            self._buckets[bucket_idx].append(example)
-            self._lens[bucket_idx] += 1
-            return creates_new_bucket
+            s_idx, t_idx = self.bucket_fn(example)
+            self._buckets[s_idx][t_idx].append(example)
+            self._is_exhausted = False
         except StopIteration:
-            return None
-
-    def bucket_is_empty(self, bucket_idx) -> bool:
-        return self._lens[bucket_idx] == 0
-
-    def _choose_and_prepare_bucket(self, bucket_idx=None):
-        """pick a bucket (at random unless specified) and prepare examples for iteration"""
-        if bucket_idx is None:
-            bucket_idx = random.choices(range(len(self._buckets)), weights=self._lens, k=1)[0]
-        # if bucket_idx >= len(self._buckets):
-        #     import pdb; pdb.set_trace()
-        # if len(self._prefetched[self._buckets[bucket_idx]]) == 0:
-        #     import pdb; pdb.set_trace()
-        random.shuffle(self._buckets[bucket_idx])
+            self._is_exhausted = True
+
+    def bucket_is_empty(self, s_idx: int, t_idx: int) -> bool:
+        """check if this bucket is empty"""
+        return len(self._buckets[s_idx][t_idx]) == 0
+
+    def _choose_bucket(self):
+        """pick a bucket at random, weighted by bucket size"""
+        buckets = [(s, t) for s in range(self.n_buckets) for t in range(self.n_buckets)]
+        weights = [len(self._buckets[s][t]) for s in range(self.n_buckets) for t in range(self.n_buckets)]
+        bucket_idx = random.choices(buckets, weights=weights, k=1)[0]
         return bucket_idx
 
-    def is_empty(self):
-        return all(size == 0 for size in self._lens)
+    def _select_from_bucket(self, s_idx: int, t_idx: int) -> object:
+        """randomly select an item from a bucket"""
+        bucket = self._buckets[s_idx][t_idx]
+        obj_idx = random.randrange(len(bucket))
+        # swap to last to get O(1) deletion
+        bucket[obj_idx], bucket[-1] = bucket[-1], bucket[obj_idx]
+        return bucket.pop()
+
+    def is_empty(self) -> bool:
+        """check if all buckets are empty"""
+        return all(len(bucket) == 0 for bucket in itertools.chain.from_iterable(self._buckets))
+
+    def _spiralling(self, s_idx: int, t_idx: int):
+        """spiral outwards from bucket (s_idx, t_idx), yielding in-bounds coordinates, nearest rings first"""
+        def _seq():
+            # from https://math.stackexchange.com/questions/163080/on-a-two-dimensional-grid-is-there-a-formula-i-can-use-to-spiral-coordinates-in#answer-3448361  # noqa: E501
+            for n in itertools.count(1):
+                k = math.ceil((math.sqrt(n) - 1) / 2.0)
+                t = 2 * k + 1
+                m = t ** 2
+                t = t - 1
+                if n >= m - t:
+                    yield k - (m - n), k
+                else:
+                    m = m - t
+                    if n >= m - t:
+                        yield -k, k - (m - n)
+                    else:
+                        m = m - t
+                        if n >= m - t:
+                            yield -k + (m - n), -k
+                        else:
+                            yield k, -k + (m - n - t)
+
+        offsets = ((s_idx + x, t_idx + y) for x, y in _seq())
+        offsets = filter(
+            lambda tup: (0 <= tup[0] < self.n_buckets) and (0 <= tup[1] < self.n_buckets),
+            offsets,
+        )
+        # n_buckets ** 2 in-bounds coordinates are enough to visit every bucket once
+        offsets = itertools.islice(offsets, self.n_buckets ** 2)
+        yield from offsets
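
The spiral search deserves a quick illustration: starting from the chosen
bucket, _spiralling walks outwards ring by ring, so the first non-empty
bucket it reaches is among the closest in (src, tgt) length. A
self-contained sketch of the same closed-form sequence and its first ring:

    import itertools
    import math

    def spiral():
        # identical to _seq in the patch, centred on the origin
        for n in itertools.count(1):
            k = math.ceil((math.sqrt(n) - 1) / 2.0)
            t = 2 * k + 1
            m = t ** 2
            t = t - 1
            if n >= m - t:
                yield k - (m - n), k
            else:
                m = m - t
                if n >= m - t:
                    yield -k, k - (m - n)
                else:
                    m = m - t
                    if n >= m - t:
                        yield -k + (m - n), -k
                    else:
                        yield k, -k + (m - n - t)

    print(list(itertools.islice(spiral(), 9)))
    # [(0, 0), (1, 0), (1, -1), (0, -1), (-1, -1), (-1, 0), (-1, 1), (0, 1), (1, 1)]
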
 
     def __iter__(self):
         while True:
-            # 1. maybe we've exhausted the stream and the buckets
-            if self.is_empty():
-                break
+            # 1. maybe we've exhausted both the stream and the buckets:
+            # if so, we restart the example stream
+            if self.is_empty() and self._is_exhausted:
+                self._init()
             accum, cur_batch_size = [], 0
             # 2. pick a length at random
-            smallest_bucket_idx = self._choose_and_prepare_bucket()
+            smallest_bucket_idx = self._choose_bucket()
             current_bucket_idx = smallest_bucket_idx
             # 3. build batch
             batch_is_complete = False
-            while not batch_is_complete:
+            # stop either when batch is built or when it can't be built
+            while not (batch_is_complete or self.is_empty()):
                 # maybe switch buckets
-                if self.bucket_is_empty(current_bucket_idx):
-                    if self.is_empty():
-                        logger.info('Reached end of stream')  # should not happen
-                        if accum:
-                            yield self.collate_fn(accum)
-                        break
-                    try:
-                        if not any(self._lens[current_bucket_idx:]):
-                            # this was the largest bucket, so we'll need to pick the next smallest instead
-                            smallest_bucket_idx = next(
-                                bucket_idx
-                                for bucket_idx in range(smallest_bucket_idx, -1, -1)
-                                if self._lens[bucket_idx] != 0
-                            )
-                            current_bucket_idx = smallest_bucket_idx
-                        else:
-                            # there was a larger bucket, shift the index by one
-                            current_bucket_idx = next(
-                                bucket_idx
-                                for bucket_idx in range(current_bucket_idx, len(self._buckets) + 1)
-                                if self._lens[bucket_idx] != 0
-                            )
-                    except StopIteration:
-                        logger.warning(
-                            'StopIteration when trying to pick a bucket in a smart way. '
-                            'Doing something stupid instead. Please check me.'
-                        )
-                        current_bucket_idx = next(
-                            bucket_idx
-                            for bucket_idx in range(len(self._lens))
-                            if self._lens[bucket_idx] != 0
-                        )
-                    _ = self._choose_and_prepare_bucket(bucket_idx=current_bucket_idx)
+                current_bucket_idx = smallest_bucket_idx
+                next_indices = self._spiralling(*current_bucket_idx)
+                while self.bucket_is_empty(*current_bucket_idx):
+                    current_bucket_idx = next(next_indices)
                 # retrieve and process the example
-                example = self._buckets[current_bucket_idx].pop()
-                self._lens[current_bucket_idx] -= 1
+                example = self._select_from_bucket(*current_bucket_idx)
                 accum.append(example)
                 numel = self.numel_fn(example)
                 cur_batch_size += numel
                 batch_is_complete = cur_batch_size >= self.batch_size
             # 4. try to replenish reservoir if possible
+            # if not, this will also update self._is_exhausted
             self.maybe_replenish()
-            # if (new_bucket is not None) and (new_bucket <= bucket):
-            #     assert self._buckets[bucket_idx] != bucket
-            #     bucket_idx += 1
             yield self.collate_fn(accum)
-            # if self.bucket_is_empty(bucket_idx):
-            #     del self._buckets[bucket_idx]
 
 
 class DynamicDatasetIter(object):
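
Between the implementation and its tests, a usage sketch may be useful. The
parameter values below are illustrative only, and dataset stands for any
object exposing __iter__ and collate_fn (such as a ParallelCorpus):

    from mammoth.inputters.dataloader import build_dataloader

    loader = build_dataloader(
        dataset,             # e.g. a mammoth.inputters.ParallelCorpus
        batch_size=4096,     # measured in tokens, since batch_type='tokens'
        batch_type='tokens',
        pool_size=10000,     # size of the look-ahead reservoir
        n_buckets=512,       # per side of the (src_len, tgt_len) grid
        cycle=True,          # restart the stream whenever it runs dry
    )
    for step, batch in zip(range(1000), loader):
        ...                  # training step on batch
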
diff --git a/mammoth/tests/test_look_ahead_bucketing.py b/mammoth/tests/test_look_ahead_bucketing.py
new file mode 100644
index 00000000..f119924a
--- /dev/null
+++ b/mammoth/tests/test_look_ahead_bucketing.py
@@ -0,0 +1,78 @@
+from itertools import product
+
+import unittest
+from mammoth.inputters.dataloader import (
+    build_dataloader,
+    LookAheadBucketing,
+    InferenceBatcher,
+)
+
+
+class hashabledict(dict):
+    def __hash__(self):
+        return hash(tuple(sorted(self.items())))
+
+
+class MockStream():
+    def __init__(self, items):
+        self.items = items
+
+    def __len__(self):
+        return len(self.items)
+
+    def __getitem__(self, idx):
+        return self.items[idx]
+
+    def __iter__(self):
+        return iter(self.items)
+
+    def collate_fn(self, items):
+        return items
+
+
+class TestLookAheadBucketing(unittest.TestCase):
+
+    def test_all_read(self):
+        stream = MockStream([
+            hashabledict({
+                'src': tuple(letter for _ in range(i)),
+                'tgt': tuple(letter for _ in range(j)),
+            })
+            for letter in 'xyz'
+            for i, j in product(range(1, 11), range(1, 11))
+        ])
+        lab = build_dataloader(stream, 2, 'tokens', pool_size=4, n_buckets=4, cycle=True, as_iter=False)
+        examples_read = []
+        batches = iter(lab)
+        while not (lab._is_exhausted and lab.is_empty()):
+            examples_read.extend(next(batches))
+        sorted_src_ref = sorted(ex['src'] for ex in stream.items)
+        sorted_src_obs = sorted(ex['src'] for ex in examples_read)
+        self.assertEqual(sorted_src_ref, sorted_src_obs)
+        sorted_tgt_ref = sorted(ex['tgt'] for ex in stream.items)
+        sorted_tgt_obs = sorted(ex['tgt'] for ex in examples_read)
+        self.assertEqual(sorted_tgt_ref, sorted_tgt_obs)
+
+    def test_reroutes(self):
+        stream = MockStream([hashabledict({'src': '_', 'tgt': '_'})] * 10)
+        lab = build_dataloader(stream, 2, 'tokens', 4, 2, cycle=True, as_iter=False)
+        self.assertIsInstance(lab, LookAheadBucketing)
+        not_lab = build_dataloader(stream, 2, 'tokens', 4, 2, cycle=False, as_iter=False)
+        self.assertIsInstance(not_lab, InferenceBatcher)
+
+    def test_always_continues(self):
+        stream = MockStream([hashabledict({'src': '_', 'tgt': '_'})] * 10)
+        was_exhausted = False
+        stopped_exhaustion = False
+        lab = build_dataloader(stream, 2, 'tokens', pool_size=4, n_buckets=4, cycle=True, as_iter=False)
+        batches = iter(lab)
+        all_items = []
+        for _ in range(len(stream) * 3 // 2):
+            all_items.extend(next(batches))
+            was_exhausted = was_exhausted or lab._is_exhausted
+            if was_exhausted:
+                stopped_exhaustion = stopped_exhaustion or not lab._is_exhausted
+
+        self.assertTrue(was_exhausted)
+        self.assertTrue(stopped_exhaustion)
+        self.assertGreater(len(all_items), len(stream))
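
The exhaustion flags asserted in these tests also give a way to slice the
otherwise endless iterator into single passes; a sketch reusing the
MockStream helper above (note that _is_exhausted is an internal attribute,
so this is an illustration rather than a supported API):

    stream = MockStream([hashabledict({'src': ('a',) * 3, 'tgt': ('b',) * 4})] * 5)
    lab = build_dataloader(stream, 2, 'tokens', pool_size=4, n_buckets=4, cycle=True, as_iter=False)
    batches = iter(lab)
    one_pass = []
    # both flags hold exactly when one full pass over the data is done
    while not (lab._is_exhausted and lab.is_empty()):
        one_pass.extend(next(batches))
    assert len(one_pass) == len(stream)  # each example read exactly once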