Commit: bucket_size > pool_size

Mickus Timothee committed Sep 25, 2023
1 parent adea52a, commit 1177ba8

Showing 4 changed files with 19 additions and 19 deletions.
10 changes: 5 additions & 5 deletions mammoth/inputters/dataloader.py

@@ -187,7 +187,7 @@ class DynamicDatasetIter(object):
         batch_size (int): numbers of examples in a batch;
         batch_size_multiple (int): make batch size multiply of this;
         data_type (str): input data type, currently only text;
-        bucket_size (int): accum this number of examples in a dynamic dataset;
+        pool_size (int): accum this number of examples in a dynamic dataset;
         skip_empty_level (str): security level when encouter empty line;
         stride (int): iterate data files with this stride;
         offset (int): iterate data files with this offset.
@@ -209,7 +209,7 @@ def __init__(
         batch_size,
         batch_size_multiple,
         data_type="text",
-        bucket_size=2048,
+        pool_size=2048,
         n_buckets=1024,
         skip_empty_level='warning',
     ):
@@ -225,7 +225,7 @@ def __init__(
         self.batch_size = batch_size
         self.batch_size_multiple = batch_size_multiple
         self.device = 'cpu'
-        self.bucket_size = bucket_size
+        self.pool_size = pool_size
         self.n_buckets = n_buckets
         if skip_empty_level not in ['silent', 'warning', 'error']:
             raise ValueError(f"Invalid argument skip_empty_level={skip_empty_level}")
@@ -250,7 +250,7 @@ def from_opts(cls, task_queue_manager, transforms_cls, vocabs_dict, opts, is_train):
            batch_size,
            batch_size_multiple,
            data_type=opts.data_type,
-           bucket_size=opts.bucket_size,
+           pool_size=opts.pool_size,
            n_buckets=opts.n_buckets,
            skip_empty_level=opts.skip_empty_level,
        )
@@ -285,7 +285,7 @@ def _init_datasets(self):
            corpus,
            self.batch_size,
            self.batch_type,
-           self.bucket_size,
+           self.pool_size,
            n_buckets=self.n_buckets,
            cycle=self.is_train,
            as_iter=self.is_train,
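For context, pool_size (formerly bucket_size) controls how many examples the dynamic iterator accumulates before sorting by length and cutting batches. A minimal sketch of that pooling pattern, assuming a stream of tokenized examples; the helper name pooled_batches is hypothetical and this is not mammoth's actual implementation:

    # Minimal sketch of dynamic pooling: accumulate pool_size examples,
    # sort by length to reduce padding, then emit fixed-size batches.
    from itertools import islice

    def pooled_batches(examples, pool_size=2048, batch_size=64):
        it = iter(examples)
        while True:
            pool = list(islice(it, pool_size))  # accumulate up to pool_size examples
            if not pool:
                return
            pool.sort(key=len)  # group similar lengths together
            for i in range(0, len(pool), batch_size):
                yield pool[i:i + batch_size]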
4 changes: 2 additions & 2 deletions mammoth/opts.py

@@ -993,8 +993,8 @@ def _add_train_general_opts(parser):
 def _add_train_dynamic_data(parser):
     group = parser.add_argument_group("Dynamic data")
     group.add(
-        "-bucket_size",
-        "--bucket_size",
+        "-pool_size",
+        "--pool_size",
         type=int,
         default=2048,
         help="Number of examples to dynamically pool before batching.",
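Since this renames a public training flag, existing command lines and configs must switch from --bucket_size to --pool_size. A quick sketch of the new flag's behavior, using plain argparse as a stand-in for mammoth's own parser wiring:

    # Hedged sketch: parse the renamed flag with plain argparse.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-pool_size", "--pool_size",
        type=int,
        default=2048,
        help="Number of examples to dynamically pool before batching.",
    )

    opts = parser.parse_args(["--pool_size", "4096"])
    assert opts.pool_size == 4096  # "--bucket_size" would now be rejected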
10 changes: 5 additions & 5 deletions tools/attention_bank.py

@@ -78,7 +78,7 @@ def _extract(sentences_file, model, vocabs_dict, transforms, enc_id, batch_size=
         yield (memory_bank, src_lengths)


-def extract(opts, vocabs_dict, model, model_opt, transforms):
+def extract(opts, vocabs_dict, model, model_opts, transforms):
     """Compute representations drawn from the encoder and save them to file."""
     sentence_reps = []
     for src, src_length in _extract(
@@ -95,7 +95,7 @@ def extract(opts, vocabs_dict, model, model_opt, transforms):
     torch.save(sentence_reps, opts.dump_file)


-def estimate(opts, vocabs_dict, model, model_opt, transforms):
+def estimate(opts, vocabs_dict, model, model_opts, transforms):
     """Estimate the matrix-variate distribution of representations drawn from the encoder."""
     try:
         import sklearn.covariance
@@ -134,7 +134,7 @@ def estimate(opts, vocabs_dict, model, model_opt, transforms):
     # return sampling_fn


-def classify(opts, vocabs_dict, model, model_opt, transforms):
+def classify(opts, vocabs_dict, model, model_opts, transforms):
     """Learn a simple SGD classifier using representations drawn from the encoder."""
     try:
         import sklearn.linear_model
@@ -224,7 +224,7 @@ def main():
     # ArgumentParser.validate_translate_opts_dynamic(opts)
     opts.enc_id = opts.enc_id or opts.src_lang

-    vocabs_dict, model, model_opt = load_test_multitask_model(opts, opts.model)
+    vocabs_dict, model, model_opts = load_test_multitask_model(opts, opts.model)
     command_fn = {
         fn.__name__: fn for fn in
         [extract, estimate, classify]
@@ -238,7 +238,7 @@ def main():
     ]
     transform = TransformPipe.build_from(data_transform)

-    command_fn(opts, vocabs_dict, model.to(opts.device), model_opt, transform)
+    command_fn(opts, vocabs_dict, model.to(opts.device), model_opts, transform)


 if __name__ == '__main__':
14 changes: 7 additions & 7 deletions tools/extract_embeddings.py

@@ -38,22 +38,22 @@ def main():

     # Add in default model arguments, possibly added since training.
     checkpoint = torch.load(opts.model, map_location=lambda storage, loc: storage)
-    model_opt = checkpoint['opts']
+    model_opts = checkpoint['opts']

     fields = checkpoint['vocab']
     src_dict = fields['src'].base_field.vocab  # assumes src is text
     tgt_dict = fields['tgt'].base_field.vocab

-    model_opt = checkpoint['opts']
+    model_opts = checkpoint['opts']
     for arg in dummy_opt.__dict__:
-        if arg not in model_opt:
-            model_opt.__dict__[arg] = dummy_opt.__dict__[arg]
+        if arg not in model_opts:
+            model_opts.__dict__[arg] = dummy_opt.__dict__[arg]

     # build_base_model expects updated and validated opts
-    ArgumentParser.update_model_opts(model_opt)
-    ArgumentParser.validate_model_opts(model_opt)
+    ArgumentParser.update_model_opts(model_opts)
+    ArgumentParser.validate_model_opts(model_opts)

-    model = mammoth.model_builder.build_base_model(model_opt, fields, use_gpu(opts), checkpoint)
+    model = mammoth.model_builder.build_base_model(model_opts, fields, use_gpu(opts), checkpoint)
     encoder = model.encoder  # no encoder for LM task
     decoder = model.decoder
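The hunk above also illustrates the idiom for reviving older checkpoints: backfill any options introduced after training with their current defaults before validating and building the model. A standalone sketch of that idiom, with argparse.Namespace standing in for the saved opts object (backfill_opts is a hypothetical helper, not part of mammoth):

    # Hedged sketch: copy in defaults for options the old checkpoint predates.
    from argparse import Namespace

    def backfill_opts(model_opts: Namespace, dummy_opt: Namespace) -> Namespace:
        for arg, default in dummy_opt.__dict__.items():
            if arg not in model_opts.__dict__:
                model_opts.__dict__[arg] = default
        return model_opts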
