From dd4575ced31d0eb680fafb07331f3252e8779ce7 Mon Sep 17 00:00:00 2001 From: Mickus Timothee Date: Wed, 27 Sep 2023 16:31:30 +0300 Subject: [PATCH] spiralling pattern & 2d bucket arrays --- mammoth/inputters/dataloader.py | 117 +++++++++++++++++++++----------- 1 file changed, 78 insertions(+), 39 deletions(-) diff --git a/mammoth/inputters/dataloader.py b/mammoth/inputters/dataloader.py index 7e26987b..c85a0c36 100644 --- a/mammoth/inputters/dataloader.py +++ b/mammoth/inputters/dataloader.py @@ -1,5 +1,6 @@ import collections import itertools +import math import random import torch @@ -30,15 +31,18 @@ def numel_fn(_): elif batch_type == 'tokens': def bucket_fn(example_dict): + """map example dict to bucket index""" + # subtract two for bos/eos + src_len = min(len(example_dict['src']), n_buckets) - 2 if 'tgt' in example_dict: - # subtract four for bos/eos on both sides - true_size = len(example_dict['src']) + len(example_dict['tgt']) - 4 + tgt_len = min(len(example_dict['tgt']), n_buckets) - 2 else: - true_size = len(example_dict['src']) + 2 + tgt_len = src_len # maybe dump it in the last bucket if it's just too long - return min(n_buckets - 1, true_size) + return src_len, tgt_len def numel_fn(example_dict): + """count tokens in example"""" if 'tgt' in example_dict: true_size = len(example_dict['src']) + len(example_dict['tgt']) else: @@ -71,11 +75,28 @@ def __iter__(self): yield self.collate_fn(accum) + + + + class LookAheadBucketing(): def __init__(self, examples_stream, look_ahead_size, n_buckets, batch_size, bucket_fn, numel_fn, collate_fn): self.examples_stream = examples_stream - self._buckets = [[] for _ in range(n_buckets)] - self._lens = [0 for _ in range(n_buckets)] + sekf.n_buckets = n_buckets + self._buckets = [ + [ + [] + for _ in range(n_buckets) + ] + for _ in range(n_buckets) + ] + self._lens = [ + [ + 0 + for _ in range(n_buckets) + ] + for _ in range(n_buckets) + ] self.look_ahead_size = look_ahead_size self.batch_size = batch_size self.bucket_fn = bucket_fn @@ -86,40 +107,75 @@ def __init__(self, examples_stream, look_ahead_size, n_buckets, batch_size, buck def _init(self): logger.info('LookAheadBucketing: initialization start') for example in itertools.islice(self.examples_stream, self.look_ahead_size): - bucket_idx = self.bucket_fn(example) - self._buckets[bucket_idx].append(example) - self._lens[bucket_idx] += 1 + s_bucket, t_bucket = self.bucket_fn(example) + self._buckets[s_bucket][t_bucket].append(example) + self._lens[s_bucket][t_bucket] += 1 logger.info('LookAheadBucketing: initialization done') def maybe_replenish(self) -> bool: """look up one more example to add to this reservoir.""" try: example = next(self.examples_stream) - bucket_idx = self.bucket_fn(example) - creates_new_bucket = self._lens[bucket_idx] == 0 - self._buckets[bucket_idx].append(example) - self._lens[bucket_idx] += 1 + s_bucket, t_bucket = self.bucket_fn(example) + creates_new_bucket = self._lens[s_bucket][t_bucket] == 0 + self._buckets[s_bucket][t_bucket].append(example) + self._lens[s_bucket][t_bucket] += 1 return creates_new_bucket except StopIteration: return None def bucket_is_empty(self, bucket_idx) -> bool: - return self._lens[bucket_idx] == 0 + return self._lens[s_bucket][t_bucket] == 0 def _choose_and_prepare_bucket(self, bucket_idx=None): """pick a bucket (at random unless specified) and prepare examples for iteration""" if bucket_idx is None: - bucket_idx = random.choices(range(len(self._buckets)), weights=self._lens, k=1)[0] + buckets = [(s, t) for s in range(self.n_buckets) for t in range(self.n_buckets)] + weights = [self._lens[s][t] for s in range(self.n_buckets) for t in range(self.n_buckets)] + s_bucket, t_bucket = random.choices(buckets, weights=self._lens, k=1)[0] # if bucket_idx >= len(self._buckets): # import pdb; pdb.set_trace() # if len(self._prefetched[self._buckets[bucket_idx]]) == 0: # import pdb; pdb.set_trace() - random.shuffle(self._buckets[bucket_idx]) - return bucket_idx + random.shuffle(self._buckets[s_bucket][t_bucket]) + return s_bucket, t_bucket def is_empty(self): return all(size == 0 for size in self._lens) + def _spiralling(self, s_idx, t_idx): + def _seq(): + # from https://math.stackexchange.com/questions/163080/ + # on-a-two-dimensional-grid-is-there-a-formula-i-can-use-to-spiral-coordinates-in#answer-3448361 + for n in itertools.count(1): + k = math.ceil((math.sqrt(n) - 1) / 2.0) + t = 2 * k + 1 + m = t ** 2 + t = t - 1 + if n >= m - t: + yield x + k - (m - n), y - k + else: + m = m - t + if n >= m - t: + yield x + -k, y -k + (m - n) + else: + m = m - t + if n >= m - t: + yield x -k + (m - n), y + k + else: + yield x + k, y + k - (m - n - t) + offsets = _seq() + offsets = map(lambda tup: tup[0] + s_idx, tup[1] + t_idx, offsets) + offsets = filter( + lambda tup: (0 <= tup[0] < self.n_buckets) and (0 <= tup[1] < self.n_buckets), + offsets, + ) + offsets = filter( + lambda tup: self._lens[tup[0]][tup[1]] > 0, + offsets, + ) + yield from offsets + def __iter__(self): while True: # 1. maybe we've exhausted the stream and the buckets @@ -132,29 +188,12 @@ def __iter__(self): # 3. build batch batch_is_complete = False while not batch_is_complete: + assert not self.is_empty(), 'Stream should never end!' # maybe switch buckets - if self.bucket_is_empty(current_bucket_idx): - if self.is_empty(): - logger.info('Reached end of stream') # should not happen - if accum: - yield self.collate_fn(accum) - break - if not any(self._lens[current_bucket_idx:]): - # this was the largest bucket, so we'll need to pick the next smallest instead - smallest_bucket_idx = next( - bucket_idx - for bucket_idx in range(smallest_bucket_idx, -1, -1) - if self._lens[bucket_idx] != 0 - ) - current_bucket_idx = smallest_bucket_idx - else: - # there was a larger bucket, shift the index by one - current_bucket_idx = next( - bucket_idx - for bucket_idx in range(current_bucket_idx, len(self._buckets) + 1) - if self._lens[bucket_idx] != 0 - ) - _ = self._choose_and_prepare_bucket(bucket_idx=current_bucket_idx) + current_bucket_idx = smallest_bucket_idx + next_indices = self._spiralling(*current_bucket_idx) + while self.bucket_is_empty(current_bucket_idx): + current_bucket_idx = next(next_indices) # retrieve and process the example example = self._buckets[current_bucket_idx].pop() self._lens[current_bucket_idx] -= 1