Skip to content

Commit

Permalink
spiralling pattern & 2d bucket arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
Mickus Timothee committed Sep 27, 2023
1 parent 61910fc commit dd4575c
Showing 1 changed file with 78 additions and 39 deletions.
117 changes: 78 additions & 39 deletions mammoth/inputters/dataloader.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import collections
import itertools
import math
import random

import torch
Expand Down Expand Up @@ -30,15 +31,18 @@ def numel_fn(_):
elif batch_type == 'tokens':

def bucket_fn(example_dict):
"""map example dict to bucket index"""
# subtract two for bos/eos
src_len = min(len(example_dict['src']), n_buckets) - 2
if 'tgt' in example_dict:
# subtract four for bos/eos on both sides
true_size = len(example_dict['src']) + len(example_dict['tgt']) - 4
tgt_len = min(len(example_dict['tgt']), n_buckets) - 2
else:
true_size = len(example_dict['src']) + 2
tgt_len = src_len
# maybe dump it in the last bucket if it's just too long
return min(n_buckets - 1, true_size)
return src_len, tgt_len

def numel_fn(example_dict):
"""count tokens in example""""
if 'tgt' in example_dict:
true_size = len(example_dict['src']) + len(example_dict['tgt'])
else:
Expand Down Expand Up @@ -71,11 +75,28 @@ def __iter__(self):
yield self.collate_fn(accum)






class LookAheadBucketing():
def __init__(self, examples_stream, look_ahead_size, n_buckets, batch_size, bucket_fn, numel_fn, collate_fn):
self.examples_stream = examples_stream
self._buckets = [[] for _ in range(n_buckets)]
self._lens = [0 for _ in range(n_buckets)]
sekf.n_buckets = n_buckets
self._buckets = [
[
[]
for _ in range(n_buckets)
]
for _ in range(n_buckets)
]
self._lens = [
[
0
for _ in range(n_buckets)
]
for _ in range(n_buckets)
]
self.look_ahead_size = look_ahead_size
self.batch_size = batch_size
self.bucket_fn = bucket_fn
Expand All @@ -86,40 +107,75 @@ def __init__(self, examples_stream, look_ahead_size, n_buckets, batch_size, buck
def _init(self):
logger.info('LookAheadBucketing: initialization start')
for example in itertools.islice(self.examples_stream, self.look_ahead_size):
bucket_idx = self.bucket_fn(example)
self._buckets[bucket_idx].append(example)
self._lens[bucket_idx] += 1
s_bucket, t_bucket = self.bucket_fn(example)
self._buckets[s_bucket][t_bucket].append(example)
self._lens[s_bucket][t_bucket] += 1
logger.info('LookAheadBucketing: initialization done')

def maybe_replenish(self) -> bool:
"""look up one more example to add to this reservoir."""
try:
example = next(self.examples_stream)
bucket_idx = self.bucket_fn(example)
creates_new_bucket = self._lens[bucket_idx] == 0
self._buckets[bucket_idx].append(example)
self._lens[bucket_idx] += 1
s_bucket, t_bucket = self.bucket_fn(example)
creates_new_bucket = self._lens[s_bucket][t_bucket] == 0
self._buckets[s_bucket][t_bucket].append(example)
self._lens[s_bucket][t_bucket] += 1
return creates_new_bucket
except StopIteration:
return None

def bucket_is_empty(self, bucket_idx) -> bool:
return self._lens[bucket_idx] == 0
return self._lens[s_bucket][t_bucket] == 0

def _choose_and_prepare_bucket(self, bucket_idx=None):
"""pick a bucket (at random unless specified) and prepare examples for iteration"""
if bucket_idx is None:
bucket_idx = random.choices(range(len(self._buckets)), weights=self._lens, k=1)[0]
buckets = [(s, t) for s in range(self.n_buckets) for t in range(self.n_buckets)]
weights = [self._lens[s][t] for s in range(self.n_buckets) for t in range(self.n_buckets)]
s_bucket, t_bucket = random.choices(buckets, weights=self._lens, k=1)[0]
# if bucket_idx >= len(self._buckets):
# import pdb; pdb.set_trace()
# if len(self._prefetched[self._buckets[bucket_idx]]) == 0:
# import pdb; pdb.set_trace()
random.shuffle(self._buckets[bucket_idx])
return bucket_idx
random.shuffle(self._buckets[s_bucket][t_bucket])
return s_bucket, t_bucket

def is_empty(self):
return all(size == 0 for size in self._lens)

def _spiralling(self, s_idx, t_idx):
def _seq():
# from https://math.stackexchange.com/questions/163080/
# on-a-two-dimensional-grid-is-there-a-formula-i-can-use-to-spiral-coordinates-in#answer-3448361
for n in itertools.count(1):
k = math.ceil((math.sqrt(n) - 1) / 2.0)
t = 2 * k + 1
m = t ** 2
t = t - 1
if n >= m - t:
yield x + k - (m - n), y - k
else:
m = m - t
if n >= m - t:
yield x + -k, y -k + (m - n)
else:
m = m - t
if n >= m - t:
yield x -k + (m - n), y + k
else:
yield x + k, y + k - (m - n - t)
offsets = _seq()
offsets = map(lambda tup: tup[0] + s_idx, tup[1] + t_idx, offsets)
offsets = filter(
lambda tup: (0 <= tup[0] < self.n_buckets) and (0 <= tup[1] < self.n_buckets),
offsets,
)
offsets = filter(
lambda tup: self._lens[tup[0]][tup[1]] > 0,
offsets,
)
yield from offsets

def __iter__(self):
while True:
# 1. maybe we've exhausted the stream and the buckets
Expand All @@ -132,29 +188,12 @@ def __iter__(self):
# 3. build batch
batch_is_complete = False
while not batch_is_complete:
assert not self.is_empty(), 'Stream should never end!'
# maybe switch buckets
if self.bucket_is_empty(current_bucket_idx):
if self.is_empty():
logger.info('Reached end of stream') # should not happen
if accum:
yield self.collate_fn(accum)
break
if not any(self._lens[current_bucket_idx:]):
# this was the largest bucket, so we'll need to pick the next smallest instead
smallest_bucket_idx = next(
bucket_idx
for bucket_idx in range(smallest_bucket_idx, -1, -1)
if self._lens[bucket_idx] != 0
)
current_bucket_idx = smallest_bucket_idx
else:
# there was a larger bucket, shift the index by one
current_bucket_idx = next(
bucket_idx
for bucket_idx in range(current_bucket_idx, len(self._buckets) + 1)
if self._lens[bucket_idx] != 0
)
_ = self._choose_and_prepare_bucket(bucket_idx=current_bucket_idx)
current_bucket_idx = smallest_bucket_idx
next_indices = self._spiralling(*current_bucket_idx)
while self.bucket_is_empty(current_bucket_idx):
current_bucket_idx = next(next_indices)
# retrieve and process the example
example = self._buckets[current_bucket_idx].pop()
self._lens[current_bucket_idx] -= 1
Expand Down

0 comments on commit dd4575c

Please sign in to comment.