Skip to content

Commit

Permalink
debugging
Browse files Browse the repository at this point in the history
  • Loading branch information
Mickus Timothee committed Sep 27, 2023
1 parent dd4575c commit 4476b26
Showing 1 changed file with 13 additions and 17 deletions.
30 changes: 13 additions & 17 deletions mammoth/inputters/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def bucket_fn(example_dict):
return src_len, tgt_len

def numel_fn(example_dict):
"""count tokens in example""""
"""count tokens in example"""
if 'tgt' in example_dict:
true_size = len(example_dict['src']) + len(example_dict['tgt'])
else:
Expand Down Expand Up @@ -75,14 +75,10 @@ def __iter__(self):
yield self.collate_fn(accum)






class LookAheadBucketing():
def __init__(self, examples_stream, look_ahead_size, n_buckets, batch_size, bucket_fn, numel_fn, collate_fn):
self.examples_stream = examples_stream
sekf.n_buckets = n_buckets
self.n_buckets = n_buckets
self._buckets = [
[
[]
Expand Down Expand Up @@ -116,23 +112,23 @@ def maybe_replenish(self) -> bool:
"""look up one more example to add to this reservoir."""
try:
example = next(self.examples_stream)
s_bucket, t_bucket = self.bucket_fn(example)
s_bucket, t_bucket = self.bucket_fn(example)
creates_new_bucket = self._lens[s_bucket][t_bucket] == 0
self._buckets[s_bucket][t_bucket].append(example)
self._lens[s_bucket][t_bucket] += 1
return creates_new_bucket
except StopIteration:
return None

def bucket_is_empty(self, bucket_idx) -> bool:
return self._lens[s_bucket][t_bucket] == 0
def bucket_is_empty(self, s_idx, t_idx) -> bool:
return self._lens[s_idx][t_idx] == 0

def _choose_and_prepare_bucket(self, bucket_idx=None):
"""pick a bucket (at random unless specified) and prepare examples for iteration"""
if bucket_idx is None:
buckets = [(s, t) for s in range(self.n_buckets) for t in range(self.n_buckets)]
weights = [self._lens[s][t] for s in range(self.n_buckets) for t in range(self.n_buckets)]
s_bucket, t_bucket = random.choices(buckets, weights=self._lens, k=1)[0]
s_bucket, t_bucket = random.choices(buckets, weights=weights, k=1)[0]
# if bucket_idx >= len(self._buckets):
# import pdb; pdb.set_trace()
# if len(self._prefetched[self._buckets[bucket_idx]]) == 0:
Expand All @@ -153,19 +149,19 @@ def _seq():
m = t ** 2
t = t - 1
if n >= m - t:
yield x + k - (m - n), y - k
yield s_idx + k - (m - n), t_idx - k
else:
m = m - t
if n >= m - t:
yield x + -k, y -k + (m - n)
yield s_idx - k, t_idx - k + (m - n)
else:
m = m - t
if n >= m - t:
yield x -k + (m - n), y + k
yield s_idx - k + (m - n), t_idx + k
else:
yield x + k, y + k - (m - n - t)
offsets = _seq()
offsets = map(lambda tup: tup[0] + s_idx, tup[1] + t_idx, offsets)
yield s_idx + k, t_idx + k - (m - n - t)

offsets = map(lambda tup: (tup[0] + s_idx, tup[1] + t_idx), _seq())
offsets = filter(
lambda tup: (0 <= tup[0] < self.n_buckets) and (0 <= tup[1] < self.n_buckets),
offsets,
Expand All @@ -192,7 +188,7 @@ def __iter__(self):
# maybe switch buckets
current_bucket_idx = smallest_bucket_idx
next_indices = self._spiralling(*current_bucket_idx)
while self.bucket_is_empty(current_bucket_idx):
while self.bucket_is_empty(*current_bucket_idx):
current_bucket_idx = next(next_indices)
# retrieve and process the example
example = self._buckets[current_bucket_idx].pop()
Expand Down

0 comments on commit 4476b26

Please sign in to comment.