Commit

Merge pull request #23 from Helsinki-NLP/feats/bucket-lord
Feats/bucket lord
TimotheeMickus authored Oct 23, 2023
2 parents af7774f + f981d26 commit ed7261c
Showing 2 changed files with 183 additions and 88 deletions.
193 changes: 105 additions & 88 deletions mammoth/inputters/dataloader.py
@@ -1,5 +1,6 @@
import collections
import itertools
import math
import random

import torch
@@ -8,16 +9,11 @@
from mammoth.utils.logging import logger


def infinite_iterator(iterable):
return itertools.chain.from_iterable(itertools.repeat(iterable))


def build_dataloader(dataset, batch_size, batch_type, pool_size=None, n_buckets=None, cycle=True, as_iter=True):
"""Convert an mammoth.inputters.ParallelCorpus into an infinite iterator of batches"""
if not cycle:
loader = InferenceBatcher(dataset, batch_size)
else:
examples_stream = infinite_iterator(dataset)
if batch_type == 'sents':
n_buckets = 1

@@ -30,23 +26,25 @@ def numel_fn(_):
elif batch_type == 'tokens':

def bucket_fn(example_dict):
"""map example dict to bucket index"""
# subtract two for bos/eos
src_len = min(len(example_dict['src']), n_buckets) - 2
if 'tgt' in example_dict:
# subtract four for bos/eos on both sides
true_size = len(example_dict['src']) + len(example_dict['tgt']) - 4
tgt_len = min(len(example_dict['tgt']), n_buckets) - 2
else:
true_size = len(example_dict['src']) + 2
tgt_len = src_len
# maybe dump it in the last bucket if it's just too long
return min(n_buckets - 1, true_size)
return src_len, tgt_len
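# worked example (editorial illustration, not part of the commit):
# with n_buckets = 4, an example where len(src) == 5 and len(tgt) == 3
# maps to src_len = min(5, 4) - 2 = 2 and tgt_len = min(3, 4) - 2 = 1,
# i.e. the example lands in bucket (2, 1)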

def numel_fn(example_dict):
"""count tokens in example"""
if 'tgt' in example_dict:
true_size = len(example_dict['src']) + len(example_dict['tgt'])
else:
true_size = len(example_dict['src'])
return true_size

collate_fn = dataset.collate_fn
loader = LookAheadBucketing(examples_stream, pool_size, n_buckets, batch_size, bucket_fn, numel_fn, collate_fn)
loader = LookAheadBucketing(dataset, pool_size, n_buckets, batch_size, bucket_fn, numel_fn)
return iter(loader) if as_iter else loader
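
For orientation, a minimal usage sketch (editorial addition, not part of the diff; corpus stands for a hypothetical ParallelCorpus-like object that yields {'src': ..., 'tgt': ...} dicts and exposes a collate_fn):

from mammoth.inputters.dataloader import build_dataloader

# training: an infinite, bucketed stream of token-sized batches
train_iter = build_dataloader(
    corpus,
    batch_size=4096,
    batch_type='tokens',
    pool_size=2048,  # size of the look-ahead reservoir
    n_buckets=128,   # bucket count per side (source and target)
    cycle=True,
)
batch = next(train_iter)  # never raises StopIteration

# inference: a single finite pass, batched by sentence count
test_batches = build_dataloader(corpus, 32, 'sents', cycle=False)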


@@ -72,117 +70,136 @@ def __iter__(self):


class LookAheadBucketing():
def __init__(self, examples_stream, look_ahead_size, n_buckets, batch_size, bucket_fn, numel_fn, collate_fn):
self.examples_stream = examples_stream
self._buckets = [[] for _ in range(n_buckets)]
self._lens = [0 for _ in range(n_buckets)]
def __init__(self, dataset, look_ahead_size, n_buckets, batch_size, bucket_fn, numel_fn):
self.dataset = dataset
# actual generator of examples
self.examples_stream = iter([])
# tracks whether the stream needs to be restarted
self._is_exhausted = True
self.n_buckets = n_buckets
self._buckets = [
[
[]
for _ in range(n_buckets)
]
for _ in range(n_buckets)
]
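# editorial note (not part of the commit): _buckets is an
# n_buckets x n_buckets grid; the first index is the source-length
# bucket, the second the target-length bucket returned by bucket_fn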
self.look_ahead_size = look_ahead_size
self.batch_size = batch_size
self.bucket_fn = bucket_fn
self.numel_fn = numel_fn
self.collate_fn = collate_fn
self.collate_fn = dataset.collate_fn
self._init()

def _init(self):
logger.info('LookAheadBucketing: initialization start')
for example in itertools.islice(self.examples_stream, self.look_ahead_size):
bucket_idx = self.bucket_fn(example)
self._buckets[bucket_idx].append(example)
self._lens[bucket_idx] += 1
self.examples_stream = iter(self.dataset)
for _ in range(self.look_ahead_size):
self.maybe_replenish()
if self._is_exhausted:
break
assert not self.is_empty(), 'Dataset contains no usable example!'
logger.info('LookAheadBucketing: initialization done')

def maybe_replenish(self) -> bool:
"""look up one more example to add to this reservoir."""
def maybe_replenish(self):
"""try to look up one more example to add to this reservoir."""
try:
example = next(self.examples_stream)
bucket_idx = self.bucket_fn(example)
creates_new_bucket = self._lens[bucket_idx] == 0
self._buckets[bucket_idx].append(example)
self._lens[bucket_idx] += 1
return creates_new_bucket
s_idx, t_idx = self.bucket_fn(example)
self._buckets[s_idx][t_idx].append(example)
self._is_exhausted = False
except StopIteration:
return None

def bucket_is_empty(self, bucket_idx) -> bool:
return self._lens[bucket_idx] == 0

def _choose_and_prepare_bucket(self, bucket_idx=None):
"""pick a bucket (at random unless specified) and prepare examples for iteration"""
if bucket_idx is None:
bucket_idx = random.choices(range(len(self._buckets)), weights=self._lens, k=1)[0]
# if bucket_idx >= len(self._buckets):
# import pdb; pdb.set_trace()
# if len(self._prefetched[self._buckets[bucket_idx]]) == 0:
# import pdb; pdb.set_trace()
random.shuffle(self._buckets[bucket_idx])
self._is_exhausted = True

def bucket_is_empty(self, s_idx: int, t_idx: int) -> bool:
"""check if this bucket is empty"""
return len(self._buckets[s_idx][t_idx]) == 0

def _choose_bucket(self):
"""pick a bucket at random"""
buckets = [(s, t) for s in range(self.n_buckets) for t in range(self.n_buckets)]
weights = [len(self._buckets[s][t]) for s in range(self.n_buckets) for t in range(self.n_buckets)]
bucket_idx = random.choices(buckets, weights=weights, k=1)[0]
return bucket_idx

def is_empty(self):
return all(size == 0 for size in self._lens)
def _select_from_bucket(self, s_idx: int, t_idx: int) -> object:
"""randomly select an item from a bucket"""
bucket = self._buckets[s_idx][t_idx]
obj_idx = random.randrange(len(bucket))
# swap to last to get O(1) deletion
bucket[obj_idx], bucket[-1] = bucket[-1], bucket[obj_idx]
return bucket.pop()
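# worked example (editorial illustration, not part of the commit):
# selecting index 0 from [a, b, c] swaps to [c, b, a] and pops a;
# no elements shift, unlike list.pop(0), hence O(1) removal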

def is_empty(self) -> bool:
"""check if all buckets are empty"""
return all(len(bucket) == 0 for bucket in itertools.chain.from_iterable(self._buckets))

def _spiralling(self, s_idx: int, t_idx: int):
def _seq():
# from https://math.stackexchange.com/questions/163080/on-a-two-dimensional-grid-is-there-a-formula-i-can-use-to-spiral-coordinates-in#answer-3448361 # noqa: E501
for n in itertools.count(1):
k = math.ceil((math.sqrt(n) - 1) / 2.0)
t = 2 * k + 1
m = t ** 2
t = t - 1
if n >= m - t:
yield k - (m - n), k
else:
m = m - t
if n >= m - t:
yield -k, k - (m - n)
else:
m = m - t
if n >= m - t:
yield -k + (m - n), -k
else:
yield k, -k + (m - n - t)

offsets = ((s_idx + x, t_idx + y) for x, y in _seq())
# offsets = itertools.takewhile(
# # this far out is obviously too far out
# lambda tup: (tup[0] < self.n_buckets * 2 + 1) and (tup[1] < self.n_buckets * 2 + 1),
# offsets,
# )
offsets = filter(
lambda tup: (0 <= tup[0] < self.n_buckets) and (0 <= tup[1] < self.n_buckets),
offsets,
)
# maybe more brittle than the takewhile a few lines above
offsets = itertools.islice(offsets, self.n_buckets ** 2)
yield from offsets
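# editorial illustration (not part of the commit): starting from
# (s_idx, t_idx), the raw offsets come out in spiral order:
# (0, 0), (1, 0), (1, -1), (0, -1), (-1, -1), (-1, 0), (-1, 1), (0, 1), (1, 1), ...
# coordinates falling outside the grid are filtered out above, so
# __iter__ below visits the chosen bucket first, then the buckets with
# the most similar source/target lengths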

def __iter__(self):
while True:
# 1. maybe we've exhausted the stream and the buckets
if self.is_empty():
break
# 1. maybe we've exhausted both the stream and the buckets:
# if so, we restart the example stream
if self.is_empty() and self._is_exhausted:
self._init()
accum, cur_batch_size = [], 0
# 2. pick a length at random
smallest_bucket_idx = self._choose_and_prepare_bucket()
smallest_bucket_idx = self._choose_bucket()
current_bucket_idx = smallest_bucket_idx
# 3. build batch
batch_is_complete = False
while not batch_is_complete:
# stop either when batch is built or when it can't be built
while not (batch_is_complete or self.is_empty()):
# maybe switch buckets
if self.bucket_is_empty(current_bucket_idx):
if self.is_empty():
logger.info('Reached end of stream') # should not happen
if accum:
yield self.collate_fn(accum)
break
try:
if not any(self._lens[current_bucket_idx:]):
# this was the largest bucket, so we'll need to pick the next smallest instead
smallest_bucket_idx = next(
bucket_idx
for bucket_idx in range(smallest_bucket_idx, -1, -1)
if self._lens[bucket_idx] != 0
)
current_bucket_idx = smallest_bucket_idx
else:
# there was a larger bucket, shift the index by one
current_bucket_idx = next(
bucket_idx
for bucket_idx in range(current_bucket_idx, len(self._buckets) + 1)
if self._lens[bucket_idx] != 0
)
except StopIteration:
logger.warning(
'StopIteration when trying to pick a bucket in a smart way. '
'Doing something stupid instead. Please check me.'
)
current_bucket_idx = next(
bucket_idx
for bucket_idx in range(len(self._lens))
if self._lens[bucket_idx] != 0
)
_ = self._choose_and_prepare_bucket(bucket_idx=current_bucket_idx)
current_bucket_idx = smallest_bucket_idx
next_indices = self._spiralling(*current_bucket_idx)
while self.bucket_is_empty(*current_bucket_idx):
current_bucket_idx = next(next_indices)
# retrieve and process the example
example = self._buckets[current_bucket_idx].pop()
self._lens[current_bucket_idx] -= 1
example = self._select_from_bucket(*current_bucket_idx)
accum.append(example)
numel = self.numel_fn(example)
cur_batch_size += numel
batch_is_complete = cur_batch_size >= self.batch_size

# 4. try to replenish reservoir if possible
# if not, this will also update self._is_exhausted
self.maybe_replenish()
# if (new_bucket is not None) and (new_bucket <= bucket):
# assert self._buckets[bucket_idx] != bucket
# bucket_idx += 1

yield self.collate_fn(accum)
# if self.bucket_is_empty(bucket_idx):
# del self._buckets[bucket_idx]
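
Taken together, the reworked __iter__ never terminates on its own: when both the stream and all buckets are exhausted, it re-primes the reservoir from a fresh pass over the dataset. Each batch is then built by picking a starting bucket at random (weighted by occupancy), popping examples while spiralling outwards to the nearest non-empty neighbouring buckets, topping the reservoir up from the stream as it goes, and yielding once the accumulated token count reaches batch_size.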


class DynamicDatasetIter(object):
78 changes: 78 additions & 0 deletions mammoth/tests/test_look_ahead_bucketing.py
@@ -0,0 +1,78 @@
from itertools import product

import unittest
from mammoth.inputters.dataloader import (
build_dataloader,
LookAheadBucketing,
InferenceBatcher,
)


class hashabledict(dict):
def __hash__(self):
return hash(tuple(sorted(self.items())))


class MockStream():
def __init__(self, items):
self.items = items

def __len__(self):
return len(self.items)

def __getitem__(self, idx):
return self.items[idx]

def __iter__(self):
return iter(self.items)

def collate_fn(self, items):
return items
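
MockStream stands in for a ParallelCorpus here: the loader only requires that the dataset be iterable over example dicts and expose a collate_fn, which is the identity in these tests so that raw examples can be inspected.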


class TestLookAheadBucketing(unittest.TestCase):

def test_all_read(self):
stream = MockStream([
hashabledict({
'src': tuple([letter for _ in range(i)]),
'tgt': tuple([letter for _ in range(j)]),
})
for letter in 'xyz'
for i, j in product(range(1, 11), range(1, 11))
])
lab = build_dataloader(stream, 2, 'tokens', pool_size=4, n_buckets=4, cycle=True, as_iter=False)
examples_read = []
batches = iter(lab)
while not (lab._is_exhausted and lab.is_empty()):
examples_read.extend(next(batches))
sorted_src_ref = sorted([ex['src'] for ex in stream.items])
sorted_src_obs = sorted([ex['src'] for ex in examples_read])
self.assertTrue(sorted_src_ref == sorted_src_obs)
sorted_tgt_ref = sorted([ex['tgt'] for ex in stream.items])
sorted_tgt_obs = sorted([ex['tgt'] for ex in examples_read])
self.assertTrue(sorted_tgt_ref == sorted_tgt_obs)

def test_reroutes(self):
stream = MockStream([hashabledict({'src': '_', 'tgt': '_'})] * 10)
lab = build_dataloader(stream, 2, 'tokens', 4, 2, cycle=True, as_iter=False)
self.assertTrue(type(lab) is LookAheadBucketing)
not_lab = build_dataloader(stream, 2, 'tokens', 4, 2, cycle=False, as_iter=False)
self.assertTrue(type(not_lab) is InferenceBatcher)

def test_always_continues(self):
stream = MockStream([hashabledict({'src': '_', 'tgt': '_'})] * 10)
was_exhausted = False
stopped_exhaustion = False
lab = build_dataloader(stream, 2, 'tokens', pool_size=4, n_buckets=4, cycle=True, as_iter=False)
batches = iter(lab)
all_items = []
for _ in range(len(stream) * 3 // 2):
all_items.extend(next(batches))
was_exhausted = was_exhausted or lab._is_exhausted
if was_exhausted:
stopped_exhaustion = stopped_exhaustion or not lab._is_exhausted

self.assertTrue(was_exhausted)
self.assertTrue(stopped_exhaustion)
self.assertTrue(len(all_items) > len(stream))
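
The new tests run under the standard unittest runner, e.g. python -m unittest mammoth.tests.test_look_ahead_bucketing.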
