From 1177ba89c069beb89bd04880cbcc7cfa56464680 Mon Sep 17 00:00:00 2001 From: Mickus Timothee Date: Mon, 25 Sep 2023 17:28:41 +0300 Subject: [PATCH] bucket_size > pool_size --- mammoth/inputters/dataloader.py | 10 +++++----- mammoth/opts.py | 4 ++-- tools/attention_bank.py | 10 +++++----- tools/extract_embeddings.py | 14 +++++++------- 4 files changed, 19 insertions(+), 19 deletions(-) diff --git a/mammoth/inputters/dataloader.py b/mammoth/inputters/dataloader.py index 78cff698..7e26987b 100644 --- a/mammoth/inputters/dataloader.py +++ b/mammoth/inputters/dataloader.py @@ -187,7 +187,7 @@ class DynamicDatasetIter(object): batch_size (int): numbers of examples in a batch; batch_size_multiple (int): make batch size multiply of this; data_type (str): input data type, currently only text; - bucket_size (int): accum this number of examples in a dynamic dataset; + pool_size (int): accum this number of examples in a dynamic dataset; skip_empty_level (str): security level when encouter empty line; stride (int): iterate data files with this stride; offset (int): iterate data files with this offset. @@ -209,7 +209,7 @@ def __init__( batch_size, batch_size_multiple, data_type="text", - bucket_size=2048, + pool_size=2048, n_buckets=1024, skip_empty_level='warning', ): @@ -225,7 +225,7 @@ def __init__( self.batch_size = batch_size self.batch_size_multiple = batch_size_multiple self.device = 'cpu' - self.bucket_size = bucket_size + self.pool_size = pool_size self.n_buckets = n_buckets if skip_empty_level not in ['silent', 'warning', 'error']: raise ValueError(f"Invalid argument skip_empty_level={skip_empty_level}") @@ -250,7 +250,7 @@ def from_opts(cls, task_queue_manager, transforms_cls, vocabs_dict, opts, is_tra batch_size, batch_size_multiple, data_type=opts.data_type, - bucket_size=opts.bucket_size, + pool_size=opts.pool_size, n_buckets=opts.n_buckets, skip_empty_level=opts.skip_empty_level, ) @@ -285,7 +285,7 @@ def _init_datasets(self): corpus, self.batch_size, self.batch_type, - self.bucket_size, + self.pool_size, n_buckets=self.n_buckets, cycle=self.is_train, as_iter=self.is_train, diff --git a/mammoth/opts.py b/mammoth/opts.py index 2788d0f6..22a3dfd6 100644 --- a/mammoth/opts.py +++ b/mammoth/opts.py @@ -993,8 +993,8 @@ def _add_train_general_opts(parser): def _add_train_dynamic_data(parser): group = parser.add_argument_group("Dynamic data") group.add( - "-bucket_size", - "--bucket_size", + "-pool_size", + "--pool_size", type=int, default=2048, help="Number of examples to dynamically pool before batching.", diff --git a/tools/attention_bank.py b/tools/attention_bank.py index 3cd7cf13..38bdb866 100644 --- a/tools/attention_bank.py +++ b/tools/attention_bank.py @@ -78,7 +78,7 @@ def _extract(sentences_file, model, vocabs_dict, transforms, enc_id, batch_size= yield (memory_bank, src_lengths) -def extract(opts, vocabs_dict, model, model_opt, transforms): +def extract(opts, vocabs_dict, model, model_opts, transforms): """Compute representations drawn from the encoder and save them to file.""" sentence_reps = [] for src, src_length in _extract( @@ -95,7 +95,7 @@ def extract(opts, vocabs_dict, model, model_opt, transforms): torch.save(sentence_reps, opts.dump_file) -def estimate(opts, vocabs_dict, model, model_opt, transforms): +def estimate(opts, vocabs_dict, model, model_opts, transforms): """Estimate the matrix-variate distribution of representations drawn from the encoder.""" try: import sklearn.covariance @@ -134,7 +134,7 @@ def estimate(opts, vocabs_dict, model, model_opt, transforms): # return sampling_fn -def classify(opts, vocabs_dict, model, model_opt, transforms): +def classify(opts, vocabs_dict, model, model_opts, transforms): """Learn a simple SGD classifier using representations drawn from the encoder.""" try: import sklearn.linear_model @@ -224,7 +224,7 @@ def main(): # ArgumentParser.validate_translate_opts_dynamic(opts) opts.enc_id = opts.enc_id or opts.src_lang - vocabs_dict, model, model_opt = load_test_multitask_model(opts, opts.model) + vocabs_dict, model, model_opts = load_test_multitask_model(opts, opts.model) command_fn = { fn.__name__: fn for fn in [extract, estimate, classify] @@ -238,7 +238,7 @@ def main(): ] transform = TransformPipe.build_from(data_transform) - command_fn(opts, vocabs_dict, model.to(opts.device), model_opt, transform) + command_fn(opts, vocabs_dict, model.to(opts.device), model_opts, transform) if __name__ == '__main__': diff --git a/tools/extract_embeddings.py b/tools/extract_embeddings.py index b439ee06..49ecbca9 100644 --- a/tools/extract_embeddings.py +++ b/tools/extract_embeddings.py @@ -38,22 +38,22 @@ def main(): # Add in default model arguments, possibly added since training. checkpoint = torch.load(opts.model, map_location=lambda storage, loc: storage) - model_opt = checkpoint['opts'] + model_opts = checkpoint['opts'] fields = checkpoint['vocab'] src_dict = fields['src'].base_field.vocab # assumes src is text tgt_dict = fields['tgt'].base_field.vocab - model_opt = checkpoint['opts'] + model_opts = checkpoint['opts'] for arg in dummy_opt.__dict__: - if arg not in model_opt: - model_opt.__dict__[arg] = dummy_opt.__dict__[arg] + if arg not in model_opts: + model_opts.__dict__[arg] = dummy_opt.__dict__[arg] # build_base_model expects updated and validated opts - ArgumentParser.update_model_opts(model_opt) - ArgumentParser.validate_model_opts(model_opt) + ArgumentParser.update_model_opts(model_opts) + ArgumentParser.validate_model_opts(model_opts) - model = mammoth.model_builder.build_base_model(model_opt, fields, use_gpu(opts), checkpoint) + model = mammoth.model_builder.build_base_model(model_opts, fields, use_gpu(opts), checkpoint) encoder = model.encoder # no encoder for LM task decoder = model.decoder